[router] Refactor StopSequenceDecoder to Use Sequence for Incremental Decoding (#11676)
This commit is contained in:
@@ -16,6 +16,9 @@ pub struct Sequence {
|
|||||||
|
|
||||||
/// Current position in the sequence
|
/// Current position in the sequence
|
||||||
read_offset: usize,
|
read_offset: usize,
|
||||||
|
|
||||||
|
/// Whether to skip special tokens when decoding
|
||||||
|
skip_special_tokens: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for Sequence {
|
impl std::fmt::Debug for Sequence {
|
||||||
@@ -45,22 +48,38 @@ impl std::fmt::Debug for Sequence {
|
|||||||
impl Sequence {
|
impl Sequence {
|
||||||
/// Create a new empty sequence
|
/// Create a new empty sequence
|
||||||
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
|
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
|
||||||
|
Self::new_with_options(tokenizer, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a new empty sequence with skip_special_tokens option
|
||||||
|
pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
token_ids: Vec::new(),
|
token_ids: Vec::new(),
|
||||||
prefix_offset: 0,
|
prefix_offset: 0,
|
||||||
read_offset: 0,
|
read_offset: 0,
|
||||||
|
skip_special_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a sequence with initial tokens
|
/// Create a sequence with initial tokens
|
||||||
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
|
||||||
|
Self::with_tokens_and_options(tokenizer, token_ids, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a sequence with initial tokens and skip_special_tokens option
|
||||||
|
pub fn with_tokens_and_options(
|
||||||
|
tokenizer: Arc<dyn TokenizerTrait>,
|
||||||
|
token_ids: Vec<TokenIdType>,
|
||||||
|
skip_special_tokens: bool,
|
||||||
|
) -> Self {
|
||||||
let len = token_ids.len();
|
let len = token_ids.len();
|
||||||
Self {
|
Self {
|
||||||
tokenizer,
|
tokenizer,
|
||||||
token_ids,
|
token_ids,
|
||||||
prefix_offset: 0,
|
prefix_offset: 0,
|
||||||
read_offset: len,
|
read_offset: len,
|
||||||
|
skip_special_tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,7 +118,9 @@ impl Sequence {
|
|||||||
|
|
||||||
// If this is the first token or we're at the beginning, decode everything
|
// If this is the first token or we're at the beginning, decode everything
|
||||||
if self.prefix_offset == 0 && old_read_offset == 0 {
|
if self.prefix_offset == 0 && old_read_offset == 0 {
|
||||||
let text = self.tokenizer.decode(&self.token_ids, false)?;
|
let text = self
|
||||||
|
.tokenizer
|
||||||
|
.decode(&self.token_ids, self.skip_special_tokens)?;
|
||||||
if text.ends_with("<EFBFBD>") {
|
if text.ends_with("<EFBFBD>") {
|
||||||
// Incomplete UTF-8 sequence, wait for more tokens
|
// Incomplete UTF-8 sequence, wait for more tokens
|
||||||
return Ok(String::new());
|
return Ok(String::new());
|
||||||
@@ -109,14 +130,16 @@ impl Sequence {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Decode the text up to the previous position
|
// Decode the text up to the previous position
|
||||||
let prefix_text = self
|
let prefix_text = self.tokenizer.decode(
|
||||||
.tokenizer
|
&self.token_ids[self.prefix_offset..old_read_offset],
|
||||||
.decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
|
self.skip_special_tokens,
|
||||||
|
)?;
|
||||||
|
|
||||||
// Decode the text including the new token
|
// Decode the text including the new token
|
||||||
let new_text = self
|
let new_text = self.tokenizer.decode(
|
||||||
.tokenizer
|
&self.token_ids[self.prefix_offset..],
|
||||||
.decode(&self.token_ids[self.prefix_offset..], false)?;
|
self.skip_special_tokens,
|
||||||
|
)?;
|
||||||
|
|
||||||
// Handle multi-byte character boundaries
|
// Handle multi-byte character boundaries
|
||||||
let mut prefix_text_len = prefix_text.len();
|
let mut prefix_text_len = prefix_text.len();
|
||||||
@@ -151,7 +174,8 @@ impl Sequence {
|
|||||||
|
|
||||||
/// Decode the entire sequence to text
|
/// Decode the entire sequence to text
|
||||||
pub fn text(&self) -> Result<String> {
|
pub fn text(&self) -> Result<String> {
|
||||||
self.tokenizer.decode(&self.token_ids, false)
|
self.tokenizer
|
||||||
|
.decode(&self.token_ids, self.skip_special_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the prefix offset
|
/// Get the prefix offset
|
||||||
@@ -163,6 +187,11 @@ impl Sequence {
|
|||||||
pub fn read_offset(&self) -> usize {
|
pub fn read_offset(&self) -> usize {
|
||||||
self.read_offset
|
self.read_offset
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get whether special tokens are skipped during decoding
|
||||||
|
pub fn skip_special_tokens(&self) -> bool {
|
||||||
|
self.skip_special_tokens
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
use super::sequence::Sequence;
|
||||||
use super::traits::{self, TokenIdType};
|
use super::traits::{self, TokenIdType};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
@@ -57,19 +58,13 @@ impl StopSequenceConfig {
|
|||||||
|
|
||||||
/// Decoder that handles stop sequences
|
/// Decoder that handles stop sequences
|
||||||
pub struct StopSequenceDecoder {
|
pub struct StopSequenceDecoder {
|
||||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
/// Sequence for incremental decoding (replaces token_buffer + offsets)
|
||||||
|
sequence: Sequence,
|
||||||
config: StopSequenceConfig,
|
config: StopSequenceConfig,
|
||||||
/// Buffer for partial matches (the "jail")
|
/// Buffer for partial matches (the "jail")
|
||||||
jail_buffer: String,
|
jail_buffer: String,
|
||||||
/// Accumulated tokens
|
|
||||||
token_buffer: Vec<TokenIdType>,
|
|
||||||
/// Offset where the prefix text starts (for context)
|
|
||||||
prefix_offset: usize,
|
|
||||||
/// Offset marking the end of previously decoded text
|
|
||||||
read_offset: usize,
|
|
||||||
/// Whether we've stopped
|
/// Whether we've stopped
|
||||||
stopped: bool,
|
stopped: bool,
|
||||||
skip_special_tokens: bool,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl StopSequenceDecoder {
|
impl StopSequenceDecoder {
|
||||||
@@ -80,14 +75,10 @@ impl StopSequenceDecoder {
|
|||||||
skip_special_tokens: bool,
|
skip_special_tokens: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
StopSequenceDecoder {
|
StopSequenceDecoder {
|
||||||
tokenizer,
|
sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
|
||||||
config,
|
config,
|
||||||
jail_buffer: String::new(),
|
jail_buffer: String::new(),
|
||||||
token_buffer: Vec::new(),
|
|
||||||
prefix_offset: 0,
|
|
||||||
read_offset: 0,
|
|
||||||
stopped: false,
|
stopped: false,
|
||||||
skip_special_tokens,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,57 +106,24 @@ impl StopSequenceDecoder {
|
|||||||
|
|
||||||
// Include jailed text plus the stop token
|
// Include jailed text plus the stop token
|
||||||
let stop_text = self
|
let stop_text = self
|
||||||
.tokenizer
|
.sequence
|
||||||
.decode(&[token_id], self.skip_special_tokens)?;
|
.tokenizer()
|
||||||
|
.decode(&[token_id], self.sequence.skip_special_tokens())?;
|
||||||
let output = format!("{}{}", self.jail_buffer, stop_text);
|
let output = format!("{}{}", self.jail_buffer, stop_text);
|
||||||
self.jail_buffer.clear();
|
self.jail_buffer.clear();
|
||||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add token to buffer
|
// Use Sequence for incremental decoding
|
||||||
self.token_buffer.push(token_id);
|
let new_text = self.sequence.append_token(token_id)?;
|
||||||
|
|
||||||
// Use incremental decoding like DecodeStream
|
self.jail_buffer.push_str(&new_text);
|
||||||
// First decode the previous context (what we've already output)
|
|
||||||
let prefix_text = if self.read_offset > self.prefix_offset {
|
|
||||||
self.tokenizer.decode(
|
|
||||||
&self.token_buffer[self.prefix_offset..self.read_offset],
|
|
||||||
self.skip_special_tokens,
|
|
||||||
)?
|
|
||||||
} else {
|
|
||||||
String::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Now decode from prefix to current position
|
// Check for hidden stop sequences
|
||||||
let new_full_text = self.tokenizer.decode(
|
|
||||||
&self.token_buffer[self.prefix_offset..],
|
|
||||||
self.skip_special_tokens,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Check for incomplete UTF-8 sequence
|
|
||||||
if new_full_text.ends_with("<EFBFBD>") {
|
|
||||||
// Wait for more tokens to complete the sequence
|
|
||||||
return Ok(SequenceDecoderOutput::Held);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate only the NEW text since last successful decode
|
|
||||||
let new_text = if new_full_text.len() > prefix_text.len() {
|
|
||||||
&new_full_text[prefix_text.len()..]
|
|
||||||
} else {
|
|
||||||
// No new text produced (can happen with special tokens)
|
|
||||||
return Ok(SequenceDecoderOutput::Held);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Combine jail buffer with new text for checking
|
|
||||||
let check_text = format!("{}{}", self.jail_buffer, new_text);
|
|
||||||
|
|
||||||
// Check for complete stop sequences
|
|
||||||
for stop_seq in &self.config.stop_sequences {
|
for stop_seq in &self.config.stop_sequences {
|
||||||
if let Some(pos) = check_text.find(stop_seq) {
|
if let Some(pos) = self.jail_buffer.find(stop_seq) {
|
||||||
self.stopped = true;
|
self.stopped = true;
|
||||||
|
let output = self.jail_buffer[..pos].to_string();
|
||||||
// Output text before the stop sequence
|
|
||||||
let output = check_text[..pos].to_string();
|
|
||||||
self.jail_buffer.clear();
|
self.jail_buffer.clear();
|
||||||
return Ok(if output.is_empty() {
|
return Ok(if output.is_empty() {
|
||||||
SequenceDecoderOutput::Stopped
|
SequenceDecoderOutput::Stopped
|
||||||
@@ -177,58 +135,54 @@ impl StopSequenceDecoder {
|
|||||||
|
|
||||||
// Check for visible stop sequences
|
// Check for visible stop sequences
|
||||||
for stop_seq in &self.config.visible_stop_sequences {
|
for stop_seq in &self.config.visible_stop_sequences {
|
||||||
if let Some(pos) = check_text.find(stop_seq) {
|
if let Some(pos) = self.jail_buffer.find(stop_seq) {
|
||||||
self.stopped = true;
|
self.stopped = true;
|
||||||
|
|
||||||
// Include the stop sequence in output
|
|
||||||
let end_pos = pos + stop_seq.len();
|
let end_pos = pos + stop_seq.len();
|
||||||
let output = check_text[..end_pos].to_string();
|
let output = self.jail_buffer[..end_pos].to_string();
|
||||||
self.jail_buffer.clear();
|
self.jail_buffer.clear();
|
||||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for partial matches at the end of check_text
|
// Check for partial matches: is the end of jail_buffer the start of any stop_seq?
|
||||||
let mut partial_match_len = 0;
|
// This handles stop sequences split across tokens
|
||||||
|
let mut longest_partial = 0;
|
||||||
for stop_seq in self
|
for stop_seq in self
|
||||||
.config
|
.config
|
||||||
.stop_sequences
|
.stop_sequences
|
||||||
.iter()
|
.iter()
|
||||||
.chain(&self.config.visible_stop_sequences)
|
.chain(&self.config.visible_stop_sequences)
|
||||||
{
|
{
|
||||||
// Check all possible suffixes that could be a prefix of stop_seq
|
// Check suffixes of jail_buffer that match prefixes of stop_seq
|
||||||
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
|
// We check up to stop_seq.len() - 1 to avoid rechecking exact matches
|
||||||
let suffix = &check_text[check_text.len() - i..];
|
let max_len = self.jail_buffer.len().min(stop_seq.len() - 1);
|
||||||
|
for len in 1..=max_len {
|
||||||
|
let suffix = &self.jail_buffer[self.jail_buffer.len() - len..];
|
||||||
if stop_seq.starts_with(suffix) {
|
if stop_seq.starts_with(suffix) {
|
||||||
partial_match_len = partial_match_len.max(i);
|
longest_partial = longest_partial.max(len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if partial_match_len > 0 {
|
if longest_partial > 0 {
|
||||||
// Split: output safe text, jail the potential match
|
// Hold the partial match, flush the rest
|
||||||
let safe_end = check_text.len() - partial_match_len;
|
let split_pos = self.jail_buffer.len() - longest_partial;
|
||||||
let safe_text = &check_text[..safe_end];
|
let to_output = self.jail_buffer[..split_pos].to_string();
|
||||||
self.jail_buffer = check_text[safe_end..].to_string();
|
self.jail_buffer = self.jail_buffer[split_pos..].to_string();
|
||||||
|
|
||||||
// Update offsets for next iteration
|
if to_output.is_empty() {
|
||||||
self.prefix_offset = self.read_offset;
|
|
||||||
self.read_offset = self.token_buffer.len();
|
|
||||||
|
|
||||||
if safe_text.is_empty() {
|
|
||||||
Ok(SequenceDecoderOutput::Held)
|
Ok(SequenceDecoderOutput::Held)
|
||||||
} else {
|
} else {
|
||||||
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
|
Ok(SequenceDecoderOutput::Text(to_output))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No partial matches - output everything
|
// No partial matches - flush everything
|
||||||
self.jail_buffer.clear();
|
let output = std::mem::take(&mut self.jail_buffer);
|
||||||
|
if output.is_empty() {
|
||||||
// Update offsets for next iteration
|
Ok(SequenceDecoderOutput::Held)
|
||||||
self.prefix_offset = self.read_offset;
|
} else {
|
||||||
self.read_offset = self.token_buffer.len();
|
Ok(SequenceDecoderOutput::Text(output))
|
||||||
|
}
|
||||||
Ok(SequenceDecoderOutput::Text(check_text))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,9 +217,7 @@ impl StopSequenceDecoder {
|
|||||||
/// Reset the decoder state
|
/// Reset the decoder state
|
||||||
pub fn reset(&mut self) {
|
pub fn reset(&mut self) {
|
||||||
self.jail_buffer.clear();
|
self.jail_buffer.clear();
|
||||||
self.token_buffer.clear();
|
self.sequence.clear();
|
||||||
self.prefix_offset = 0;
|
|
||||||
self.read_offset = 0;
|
|
||||||
self.stopped = false;
|
self.stopped = false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user