[router] Refactor StopSequenceDecoder to Use Sequence for Incremental Decoding (#11676)
This commit is contained in:
@@ -16,6 +16,9 @@ pub struct Sequence {
|
||||
|
||||
/// Current position in the sequence
|
||||
read_offset: usize,
|
||||
|
||||
/// Whether to skip special tokens when decoding
|
||||
skip_special_tokens: bool,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Sequence {
|
||||
@@ -45,22 +48,38 @@ impl std::fmt::Debug for Sequence {
|
||||
impl Sequence {
|
||||
/// Create a new empty sequence
|
||||
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
|
||||
Self::new_with_options(tokenizer, false)
|
||||
}
|
||||
|
||||
/// Create a new empty sequence with skip_special_tokens option
|
||||
pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
|
||||
Self {
|
||||
tokenizer,
|
||||
token_ids: Vec::new(),
|
||||
prefix_offset: 0,
|
||||
read_offset: 0,
|
||||
skip_special_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a sequence with initial tokens
|
||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
||||
Self::with_tokens_and_options(tokenizer, token_ids, false)
|
||||
}
|
||||
|
||||
/// Create a sequence with initial tokens and skip_special_tokens option
|
||||
pub fn with_tokens_and_options(
|
||||
tokenizer: Arc<dyn TokenizerTrait>,
|
||||
token_ids: Vec<TokenIdType>,
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
let len = token_ids.len();
|
||||
Self {
|
||||
tokenizer,
|
||||
token_ids,
|
||||
prefix_offset: 0,
|
||||
read_offset: len,
|
||||
skip_special_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -99,7 +118,9 @@ impl Sequence {
|
||||
|
||||
// If this is the first token or we're at the beginning, decode everything
|
||||
if self.prefix_offset == 0 && old_read_offset == 0 {
|
||||
let text = self.tokenizer.decode(&self.token_ids, false)?;
|
||||
let text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids, self.skip_special_tokens)?;
|
||||
if text.ends_with("�") {
|
||||
// Incomplete UTF-8 sequence, wait for more tokens
|
||||
return Ok(String::new());
|
||||
@@ -109,14 +130,16 @@ impl Sequence {
|
||||
}
|
||||
|
||||
// Decode the text up to the previous position
|
||||
let prefix_text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
|
||||
let prefix_text = self.tokenizer.decode(
|
||||
&self.token_ids[self.prefix_offset..old_read_offset],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
// Decode the text including the new token
|
||||
let new_text = self
|
||||
.tokenizer
|
||||
.decode(&self.token_ids[self.prefix_offset..], false)?;
|
||||
let new_text = self.tokenizer.decode(
|
||||
&self.token_ids[self.prefix_offset..],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
// Handle multi-byte character boundaries
|
||||
let mut prefix_text_len = prefix_text.len();
|
||||
@@ -151,7 +174,8 @@ impl Sequence {
|
||||
|
||||
/// Decode the entire sequence to text
|
||||
pub fn text(&self) -> Result<String> {
|
||||
self.tokenizer.decode(&self.token_ids, false)
|
||||
self.tokenizer
|
||||
.decode(&self.token_ids, self.skip_special_tokens)
|
||||
}
|
||||
|
||||
/// Get the prefix offset
|
||||
@@ -163,6 +187,11 @@ impl Sequence {
|
||||
pub fn read_offset(&self) -> usize {
|
||||
self.read_offset
|
||||
}
|
||||
|
||||
/// Get whether special tokens are skipped during decoding
|
||||
pub fn skip_special_tokens(&self) -> bool {
|
||||
self.skip_special_tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use super::sequence::Sequence;
|
||||
use super::traits::{self, TokenIdType};
|
||||
use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
@@ -57,19 +58,13 @@ impl StopSequenceConfig {
|
||||
|
||||
/// Decoder that handles stop sequences
|
||||
pub struct StopSequenceDecoder {
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
/// Sequence for incremental decoding (replaces token_buffer + offsets)
|
||||
sequence: Sequence,
|
||||
config: StopSequenceConfig,
|
||||
/// Buffer for partial matches (the "jail")
|
||||
jail_buffer: String,
|
||||
/// Accumulated tokens
|
||||
token_buffer: Vec<TokenIdType>,
|
||||
/// Offset where the prefix text starts (for context)
|
||||
prefix_offset: usize,
|
||||
/// Offset marking the end of previously decoded text
|
||||
read_offset: usize,
|
||||
/// Whether we've stopped
|
||||
stopped: bool,
|
||||
skip_special_tokens: bool,
|
||||
}
|
||||
|
||||
impl StopSequenceDecoder {
|
||||
@@ -80,14 +75,10 @@ impl StopSequenceDecoder {
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
StopSequenceDecoder {
|
||||
tokenizer,
|
||||
sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
|
||||
config,
|
||||
jail_buffer: String::new(),
|
||||
token_buffer: Vec::new(),
|
||||
prefix_offset: 0,
|
||||
read_offset: 0,
|
||||
stopped: false,
|
||||
skip_special_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,57 +106,24 @@ impl StopSequenceDecoder {
|
||||
|
||||
// Include jailed text plus the stop token
|
||||
let stop_text = self
|
||||
.tokenizer
|
||||
.decode(&[token_id], self.skip_special_tokens)?;
|
||||
.sequence
|
||||
.tokenizer()
|
||||
.decode(&[token_id], self.sequence.skip_special_tokens())?;
|
||||
let output = format!("{}{}", self.jail_buffer, stop_text);
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
|
||||
// Add token to buffer
|
||||
self.token_buffer.push(token_id);
|
||||
// Use Sequence for incremental decoding
|
||||
let new_text = self.sequence.append_token(token_id)?;
|
||||
|
||||
// Use incremental decoding like DecodeStream
|
||||
// First decode the previous context (what we've already output)
|
||||
let prefix_text = if self.read_offset > self.prefix_offset {
|
||||
self.tokenizer.decode(
|
||||
&self.token_buffer[self.prefix_offset..self.read_offset],
|
||||
self.skip_special_tokens,
|
||||
)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
self.jail_buffer.push_str(&new_text);
|
||||
|
||||
// Now decode from prefix to current position
|
||||
let new_full_text = self.tokenizer.decode(
|
||||
&self.token_buffer[self.prefix_offset..],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
// Check for incomplete UTF-8 sequence
|
||||
if new_full_text.ends_with("�") {
|
||||
// Wait for more tokens to complete the sequence
|
||||
return Ok(SequenceDecoderOutput::Held);
|
||||
}
|
||||
|
||||
// Calculate only the NEW text since last successful decode
|
||||
let new_text = if new_full_text.len() > prefix_text.len() {
|
||||
&new_full_text[prefix_text.len()..]
|
||||
} else {
|
||||
// No new text produced (can happen with special tokens)
|
||||
return Ok(SequenceDecoderOutput::Held);
|
||||
};
|
||||
|
||||
// Combine jail buffer with new text for checking
|
||||
let check_text = format!("{}{}", self.jail_buffer, new_text);
|
||||
|
||||
// Check for complete stop sequences
|
||||
// Check for hidden stop sequences
|
||||
for stop_seq in &self.config.stop_sequences {
|
||||
if let Some(pos) = check_text.find(stop_seq) {
|
||||
if let Some(pos) = self.jail_buffer.find(stop_seq) {
|
||||
self.stopped = true;
|
||||
|
||||
// Output text before the stop sequence
|
||||
let output = check_text[..pos].to_string();
|
||||
let output = self.jail_buffer[..pos].to_string();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(if output.is_empty() {
|
||||
SequenceDecoderOutput::Stopped
|
||||
@@ -177,58 +135,54 @@ impl StopSequenceDecoder {
|
||||
|
||||
// Check for visible stop sequences
|
||||
for stop_seq in &self.config.visible_stop_sequences {
|
||||
if let Some(pos) = check_text.find(stop_seq) {
|
||||
if let Some(pos) = self.jail_buffer.find(stop_seq) {
|
||||
self.stopped = true;
|
||||
|
||||
// Include the stop sequence in output
|
||||
let end_pos = pos + stop_seq.len();
|
||||
let output = check_text[..end_pos].to_string();
|
||||
let output = self.jail_buffer[..end_pos].to_string();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial matches at the end of check_text
|
||||
let mut partial_match_len = 0;
|
||||
// Check for partial matches: is the end of jail_buffer the start of any stop_seq?
|
||||
// This handles stop sequences split across tokens
|
||||
let mut longest_partial = 0;
|
||||
for stop_seq in self
|
||||
.config
|
||||
.stop_sequences
|
||||
.iter()
|
||||
.chain(&self.config.visible_stop_sequences)
|
||||
{
|
||||
// Check all possible suffixes that could be a prefix of stop_seq
|
||||
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
|
||||
let suffix = &check_text[check_text.len() - i..];
|
||||
// Check suffixes of jail_buffer that match prefixes of stop_seq
|
||||
// We check up to stop_seq.len() - 1 to avoid rechecking exact matches
|
||||
let max_len = self.jail_buffer.len().min(stop_seq.len() - 1);
|
||||
for len in 1..=max_len {
|
||||
let suffix = &self.jail_buffer[self.jail_buffer.len() - len..];
|
||||
if stop_seq.starts_with(suffix) {
|
||||
partial_match_len = partial_match_len.max(i);
|
||||
longest_partial = longest_partial.max(len);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if partial_match_len > 0 {
|
||||
// Split: output safe text, jail the potential match
|
||||
let safe_end = check_text.len() - partial_match_len;
|
||||
let safe_text = &check_text[..safe_end];
|
||||
self.jail_buffer = check_text[safe_end..].to_string();
|
||||
if longest_partial > 0 {
|
||||
// Hold the partial match, flush the rest
|
||||
let split_pos = self.jail_buffer.len() - longest_partial;
|
||||
let to_output = self.jail_buffer[..split_pos].to_string();
|
||||
self.jail_buffer = self.jail_buffer[split_pos..].to_string();
|
||||
|
||||
// Update offsets for next iteration
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_buffer.len();
|
||||
|
||||
if safe_text.is_empty() {
|
||||
if to_output.is_empty() {
|
||||
Ok(SequenceDecoderOutput::Held)
|
||||
} else {
|
||||
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
|
||||
Ok(SequenceDecoderOutput::Text(to_output))
|
||||
}
|
||||
} else {
|
||||
// No partial matches - output everything
|
||||
self.jail_buffer.clear();
|
||||
|
||||
// Update offsets for next iteration
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_buffer.len();
|
||||
|
||||
Ok(SequenceDecoderOutput::Text(check_text))
|
||||
// No partial matches - flush everything
|
||||
let output = std::mem::take(&mut self.jail_buffer);
|
||||
if output.is_empty() {
|
||||
Ok(SequenceDecoderOutput::Held)
|
||||
} else {
|
||||
Ok(SequenceDecoderOutput::Text(output))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -263,9 +217,7 @@ impl StopSequenceDecoder {
|
||||
/// Reset the decoder state
|
||||
pub fn reset(&mut self) {
|
||||
self.jail_buffer.clear();
|
||||
self.token_buffer.clear();
|
||||
self.prefix_offset = 0;
|
||||
self.read_offset = 0;
|
||||
self.sequence.clear();
|
||||
self.stopped = false;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user