[router] Refactor StopSequenceDecoder to Use Sequence for Incremental Decoding (#11676)

This commit is contained in:
Simo Lin
2025-10-15 16:31:03 -07:00
committed by GitHub
parent 2479b89405
commit f5d30dae89
2 changed files with 76 additions and 95 deletions

View File

@@ -1,3 +1,4 @@
use super::sequence::Sequence;
use super::traits::{self, TokenIdType};
use anyhow::Result;
use std::collections::HashSet;
@@ -57,19 +58,13 @@ impl StopSequenceConfig {
/// Decoder that handles stop sequences
pub struct StopSequenceDecoder {
tokenizer: Arc<dyn traits::Tokenizer>,
/// Sequence for incremental decoding (replaces token_buffer + offsets)
sequence: Sequence,
config: StopSequenceConfig,
/// Buffer for partial matches (the "jail")
jail_buffer: String,
/// Accumulated tokens
token_buffer: Vec<TokenIdType>,
/// Offset where the prefix text starts (for context)
prefix_offset: usize,
/// Offset marking the end of previously decoded text
read_offset: usize,
/// Whether we've stopped
stopped: bool,
skip_special_tokens: bool,
}
impl StopSequenceDecoder {
@@ -80,14 +75,10 @@ impl StopSequenceDecoder {
skip_special_tokens: bool,
) -> Self {
StopSequenceDecoder {
tokenizer,
sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
config,
jail_buffer: String::new(),
token_buffer: Vec::new(),
prefix_offset: 0,
read_offset: 0,
stopped: false,
skip_special_tokens,
}
}
@@ -115,57 +106,24 @@ impl StopSequenceDecoder {
// Include jailed text plus the stop token
let stop_text = self
.tokenizer
.decode(&[token_id], self.skip_special_tokens)?;
.sequence
.tokenizer()
.decode(&[token_id], self.sequence.skip_special_tokens())?;
let output = format!("{}{}", self.jail_buffer, stop_text);
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
// Add token to buffer
self.token_buffer.push(token_id);
// Use Sequence for incremental decoding
let new_text = self.sequence.append_token(token_id)?;
// Use incremental decoding like DecodeStream
// First decode the previous context (what we've already output)
let prefix_text = if self.read_offset > self.prefix_offset {
self.tokenizer.decode(
&self.token_buffer[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?
} else {
String::new()
};
self.jail_buffer.push_str(&new_text);
// Now decode from prefix to current position
let new_full_text = self.tokenizer.decode(
&self.token_buffer[self.prefix_offset..],
self.skip_special_tokens,
)?;
// Check for incomplete UTF-8 sequence
if new_full_text.ends_with("<EFBFBD>") {
// Wait for more tokens to complete the sequence
return Ok(SequenceDecoderOutput::Held);
}
// Calculate only the NEW text since last successful decode
let new_text = if new_full_text.len() > prefix_text.len() {
&new_full_text[prefix_text.len()..]
} else {
// No new text produced (can happen with special tokens)
return Ok(SequenceDecoderOutput::Held);
};
// Combine jail buffer with new text for checking
let check_text = format!("{}{}", self.jail_buffer, new_text);
// Check for complete stop sequences
// Check for hidden stop sequences
for stop_seq in &self.config.stop_sequences {
if let Some(pos) = check_text.find(stop_seq) {
if let Some(pos) = self.jail_buffer.find(stop_seq) {
self.stopped = true;
// Output text before the stop sequence
let output = check_text[..pos].to_string();
let output = self.jail_buffer[..pos].to_string();
self.jail_buffer.clear();
return Ok(if output.is_empty() {
SequenceDecoderOutput::Stopped
@@ -177,58 +135,54 @@ impl StopSequenceDecoder {
// Check for visible stop sequences
for stop_seq in &self.config.visible_stop_sequences {
if let Some(pos) = check_text.find(stop_seq) {
if let Some(pos) = self.jail_buffer.find(stop_seq) {
self.stopped = true;
// Include the stop sequence in output
let end_pos = pos + stop_seq.len();
let output = check_text[..end_pos].to_string();
let output = self.jail_buffer[..end_pos].to_string();
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
}
// Check for partial matches at the end of check_text
let mut partial_match_len = 0;
// Check for partial matches: is the end of jail_buffer the start of any stop_seq?
// This handles stop sequences split across tokens
let mut longest_partial = 0;
for stop_seq in self
.config
.stop_sequences
.iter()
.chain(&self.config.visible_stop_sequences)
{
// Check all possible suffixes that could be a prefix of stop_seq
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
let suffix = &check_text[check_text.len() - i..];
// Check suffixes of jail_buffer that match prefixes of stop_seq
// We check up to stop_seq.len() - 1 to avoid rechecking exact matches
let max_len = self.jail_buffer.len().min(stop_seq.len() - 1);
for len in 1..=max_len {
let suffix = &self.jail_buffer[self.jail_buffer.len() - len..];
if stop_seq.starts_with(suffix) {
partial_match_len = partial_match_len.max(i);
longest_partial = longest_partial.max(len);
}
}
}
if partial_match_len > 0 {
// Split: output safe text, jail the potential match
let safe_end = check_text.len() - partial_match_len;
let safe_text = &check_text[..safe_end];
self.jail_buffer = check_text[safe_end..].to_string();
if longest_partial > 0 {
// Hold the partial match, flush the rest
let split_pos = self.jail_buffer.len() - longest_partial;
let to_output = self.jail_buffer[..split_pos].to_string();
self.jail_buffer = self.jail_buffer[split_pos..].to_string();
// Update offsets for next iteration
self.prefix_offset = self.read_offset;
self.read_offset = self.token_buffer.len();
if safe_text.is_empty() {
if to_output.is_empty() {
Ok(SequenceDecoderOutput::Held)
} else {
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
Ok(SequenceDecoderOutput::Text(to_output))
}
} else {
// No partial matches - output everything
self.jail_buffer.clear();
// Update offsets for next iteration
self.prefix_offset = self.read_offset;
self.read_offset = self.token_buffer.len();
Ok(SequenceDecoderOutput::Text(check_text))
// No partial matches - flush everything
let output = std::mem::take(&mut self.jail_buffer);
if output.is_empty() {
Ok(SequenceDecoderOutput::Held)
} else {
Ok(SequenceDecoderOutput::Text(output))
}
}
}
@@ -263,9 +217,7 @@ impl StopSequenceDecoder {
/// Reset the decoder state
pub fn reset(&mut self) {
self.jail_buffer.clear();
self.token_buffer.clear();
self.prefix_offset = 0;
self.read_offset = 0;
self.sequence.clear();
self.stopped = false;
}
}