From f5d30dae89fd413cabd2d573c2eed9907d233dcb Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 15 Oct 2025 16:31:03 -0700 Subject: [PATCH] [router] Refactor StopSequenceDecoder to Use Sequence for Incremental Decoding (#11676) --- sgl-router/src/tokenizer/sequence.rs | 45 ++++++++-- sgl-router/src/tokenizer/stop.rs | 126 +++++++++------------------ 2 files changed, 76 insertions(+), 95 deletions(-) diff --git a/sgl-router/src/tokenizer/sequence.rs b/sgl-router/src/tokenizer/sequence.rs index 4a97e4975..f54f73437 100644 --- a/sgl-router/src/tokenizer/sequence.rs +++ b/sgl-router/src/tokenizer/sequence.rs @@ -16,6 +16,9 @@ pub struct Sequence { /// Current position in the sequence read_offset: usize, + + /// Whether to skip special tokens when decoding + skip_special_tokens: bool, } impl std::fmt::Debug for Sequence { @@ -45,22 +48,38 @@ impl std::fmt::Debug for Sequence { impl Sequence { /// Create a new empty sequence pub fn new(tokenizer: Arc) -> Self { + Self::new_with_options(tokenizer, false) + } + + /// Create a new empty sequence with skip_special_tokens option + pub fn new_with_options(tokenizer: Arc, skip_special_tokens: bool) -> Self { Self { tokenizer, token_ids: Vec::new(), prefix_offset: 0, read_offset: 0, + skip_special_tokens, } } /// Create a sequence with initial tokens pub fn with_tokens(tokenizer: Arc, token_ids: Vec) -> Self { + Self::with_tokens_and_options(tokenizer, token_ids, false) + } + + /// Create a sequence with initial tokens and skip_special_tokens option + pub fn with_tokens_and_options( + tokenizer: Arc, + token_ids: Vec, + skip_special_tokens: bool, + ) -> Self { let len = token_ids.len(); Self { tokenizer, token_ids, prefix_offset: 0, read_offset: len, + skip_special_tokens, } } @@ -99,7 +118,9 @@ impl Sequence { // If this is the first token or we're at the beginning, decode everything if self.prefix_offset == 0 && old_read_offset == 0 { - let text = self.tokenizer.decode(&self.token_ids, false)?; + let text = self + .tokenizer + .decode(&self.token_ids, self.skip_special_tokens)?; if text.ends_with("�") { // Incomplete UTF-8 sequence, wait for more tokens return Ok(String::new()); @@ -109,14 +130,16 @@ impl Sequence { } // Decode the text up to the previous position - let prefix_text = self - .tokenizer - .decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?; + let prefix_text = self.tokenizer.decode( + &self.token_ids[self.prefix_offset..old_read_offset], + self.skip_special_tokens, + )?; // Decode the text including the new token - let new_text = self - .tokenizer - .decode(&self.token_ids[self.prefix_offset..], false)?; + let new_text = self.tokenizer.decode( + &self.token_ids[self.prefix_offset..], + self.skip_special_tokens, + )?; // Handle multi-byte character boundaries let mut prefix_text_len = prefix_text.len(); @@ -151,7 +174,8 @@ impl Sequence { /// Decode the entire sequence to text pub fn text(&self) -> Result { - self.tokenizer.decode(&self.token_ids, false) + self.tokenizer + .decode(&self.token_ids, self.skip_special_tokens) } /// Get the prefix offset @@ -163,6 +187,11 @@ impl Sequence { pub fn read_offset(&self) -> usize { self.read_offset } + + /// Get whether special tokens are skipped during decoding + pub fn skip_special_tokens(&self) -> bool { + self.skip_special_tokens + } } #[cfg(test)] diff --git a/sgl-router/src/tokenizer/stop.rs b/sgl-router/src/tokenizer/stop.rs index 0a98f1fa9..3122a0e97 100644 --- a/sgl-router/src/tokenizer/stop.rs +++ b/sgl-router/src/tokenizer/stop.rs @@ -1,3 +1,4 @@ +use super::sequence::Sequence; use super::traits::{self, TokenIdType}; use anyhow::Result; use std::collections::HashSet; @@ -57,19 +58,13 @@ impl StopSequenceConfig { /// Decoder that handles stop sequences pub struct StopSequenceDecoder { - tokenizer: Arc, + /// Sequence for incremental decoding (replaces token_buffer + offsets) + sequence: Sequence, config: StopSequenceConfig, /// Buffer for partial matches (the "jail") jail_buffer: String, - /// Accumulated tokens - token_buffer: Vec, - /// Offset where the prefix text starts (for context) - prefix_offset: usize, - /// Offset marking the end of previously decoded text - read_offset: usize, /// Whether we've stopped stopped: bool, - skip_special_tokens: bool, } impl StopSequenceDecoder { @@ -80,14 +75,10 @@ impl StopSequenceDecoder { skip_special_tokens: bool, ) -> Self { StopSequenceDecoder { - tokenizer, + sequence: Sequence::new_with_options(tokenizer, skip_special_tokens), config, jail_buffer: String::new(), - token_buffer: Vec::new(), - prefix_offset: 0, - read_offset: 0, stopped: false, - skip_special_tokens, } } @@ -115,57 +106,24 @@ impl StopSequenceDecoder { // Include jailed text plus the stop token let stop_text = self - .tokenizer - .decode(&[token_id], self.skip_special_tokens)?; + .sequence + .tokenizer() + .decode(&[token_id], self.sequence.skip_special_tokens())?; let output = format!("{}{}", self.jail_buffer, stop_text); self.jail_buffer.clear(); return Ok(SequenceDecoderOutput::StoppedWithText(output)); } - // Add token to buffer - self.token_buffer.push(token_id); + // Use Sequence for incremental decoding + let new_text = self.sequence.append_token(token_id)?; - // Use incremental decoding like DecodeStream - // First decode the previous context (what we've already output) - let prefix_text = if self.read_offset > self.prefix_offset { - self.tokenizer.decode( - &self.token_buffer[self.prefix_offset..self.read_offset], - self.skip_special_tokens, - )? - } else { - String::new() - }; + self.jail_buffer.push_str(&new_text); - // Now decode from prefix to current position - let new_full_text = self.tokenizer.decode( - &self.token_buffer[self.prefix_offset..], - self.skip_special_tokens, - )?; - - // Check for incomplete UTF-8 sequence - if new_full_text.ends_with("�") { - // Wait for more tokens to complete the sequence - return Ok(SequenceDecoderOutput::Held); - } - - // Calculate only the NEW text since last successful decode - let new_text = if new_full_text.len() > prefix_text.len() { - &new_full_text[prefix_text.len()..] - } else { - // No new text produced (can happen with special tokens) - return Ok(SequenceDecoderOutput::Held); - }; - - // Combine jail buffer with new text for checking - let check_text = format!("{}{}", self.jail_buffer, new_text); - - // Check for complete stop sequences + // Check for hidden stop sequences for stop_seq in &self.config.stop_sequences { - if let Some(pos) = check_text.find(stop_seq) { + if let Some(pos) = self.jail_buffer.find(stop_seq) { self.stopped = true; - - // Output text before the stop sequence - let output = check_text[..pos].to_string(); + let output = self.jail_buffer[..pos].to_string(); self.jail_buffer.clear(); return Ok(if output.is_empty() { SequenceDecoderOutput::Stopped @@ -177,58 +135,54 @@ impl StopSequenceDecoder { // Check for visible stop sequences for stop_seq in &self.config.visible_stop_sequences { - if let Some(pos) = check_text.find(stop_seq) { + if let Some(pos) = self.jail_buffer.find(stop_seq) { self.stopped = true; - - // Include the stop sequence in output let end_pos = pos + stop_seq.len(); - let output = check_text[..end_pos].to_string(); + let output = self.jail_buffer[..end_pos].to_string(); self.jail_buffer.clear(); return Ok(SequenceDecoderOutput::StoppedWithText(output)); } } - // Check for partial matches at the end of check_text - let mut partial_match_len = 0; + // Check for partial matches: is the end of jail_buffer the start of any stop_seq? + // This handles stop sequences split across tokens + let mut longest_partial = 0; for stop_seq in self .config .stop_sequences .iter() .chain(&self.config.visible_stop_sequences) { - // Check all possible suffixes that could be a prefix of stop_seq - for i in 1..=check_text.len().min(stop_seq.len() - 1) { - let suffix = &check_text[check_text.len() - i..]; + // Check suffixes of jail_buffer that match prefixes of stop_seq + // We check up to stop_seq.len() - 1 to avoid rechecking exact matches + let max_len = self.jail_buffer.len().min(stop_seq.len() - 1); + for len in 1..=max_len { + let suffix = &self.jail_buffer[self.jail_buffer.len() - len..]; if stop_seq.starts_with(suffix) { - partial_match_len = partial_match_len.max(i); + longest_partial = longest_partial.max(len); } } } - if partial_match_len > 0 { - // Split: output safe text, jail the potential match - let safe_end = check_text.len() - partial_match_len; - let safe_text = &check_text[..safe_end]; - self.jail_buffer = check_text[safe_end..].to_string(); + if longest_partial > 0 { + // Hold the partial match, flush the rest + let split_pos = self.jail_buffer.len() - longest_partial; + let to_output = self.jail_buffer[..split_pos].to_string(); + self.jail_buffer = self.jail_buffer[split_pos..].to_string(); - // Update offsets for next iteration - self.prefix_offset = self.read_offset; - self.read_offset = self.token_buffer.len(); - - if safe_text.is_empty() { + if to_output.is_empty() { Ok(SequenceDecoderOutput::Held) } else { - Ok(SequenceDecoderOutput::Text(safe_text.to_string())) + Ok(SequenceDecoderOutput::Text(to_output)) } } else { - // No partial matches - output everything - self.jail_buffer.clear(); - - // Update offsets for next iteration - self.prefix_offset = self.read_offset; - self.read_offset = self.token_buffer.len(); - - Ok(SequenceDecoderOutput::Text(check_text)) + // No partial matches - flush everything + let output = std::mem::take(&mut self.jail_buffer); + if output.is_empty() { + Ok(SequenceDecoderOutput::Held) + } else { + Ok(SequenceDecoderOutput::Text(output)) + } } } @@ -263,9 +217,7 @@ impl StopSequenceDecoder { /// Reset the decoder state pub fn reset(&mut self) { self.jail_buffer.clear(); - self.token_buffer.clear(); - self.prefix_offset = 0; - self.read_offset = 0; + self.sequence.clear(); self.stopped = false; } }