//! Stop-sequence-aware incremental detokenization: decodes tokens one at a
//! time while watching for stop tokens and stop strings, holding ("jailing")
//! text that might be the prefix of a stop sequence until it is resolved.
use std::{collections::HashSet, sync::Arc};

use anyhow::Result;

use super::{
    sequence::Sequence,
    traits::{self, TokenIdType},
};
/// Output from the sequence decoder.
///
/// Every call to `StopSequenceDecoder::process_token` yields exactly one of
/// these variants, so streaming callers can forward text as it arrives while
/// still honoring stop conditions.
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
    /// Normal text output
    Text(String),
    /// Text is being held due to partial stop sequence match
    Held,
    /// Stop sequence matched (hidden - not included in output)
    Stopped,
    /// Stop sequence matched with text (visible - included in output)
    StoppedWithText(String),
}
/// Configuration for stop sequences.
///
/// "Hidden" stops terminate decoding and suppress the matched token/string;
/// "visible" stops terminate decoding but include the matched text in the
/// final output.
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
    /// Token IDs that trigger a stop
    pub stop_tokens: HashSet<TokenIdType>,
    /// String sequences that trigger a stop
    pub stop_sequences: Vec<String>,
    /// Token IDs for visible stops (included in output)
    pub visible_stop_tokens: HashSet<TokenIdType>,
    /// String sequences for visible stops (included in output)
    pub visible_stop_sequences: Vec<String>,
}
impl StopSequenceConfig {
|
||
/// Builder pattern - add a stop token
|
||
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||
self.stop_tokens.insert(token_id);
|
||
self
|
||
}
|
||
|
||
/// Builder pattern - add a stop sequence
|
||
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||
self.stop_sequences.push(sequence.into());
|
||
self
|
||
}
|
||
|
||
/// Builder pattern - add a visible stop token
|
||
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||
self.visible_stop_tokens.insert(token_id);
|
||
self
|
||
}
|
||
|
||
/// Builder pattern - add a visible stop sequence
|
||
pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||
self.visible_stop_sequences.push(sequence.into());
|
||
self
|
||
}
|
||
}
|
||
|
||
/// Decoder that handles stop sequences.
///
/// Wraps an incremental [`Sequence`] decoder and buffers ("jails") text that
/// could be the beginning of a stop sequence until later tokens either
/// complete the match or rule it out.
pub struct StopSequenceDecoder {
    /// Sequence for incremental decoding (replaces token_buffer + offsets)
    sequence: Sequence,
    /// Stop tokens/sequences to watch for.
    config: StopSequenceConfig,
    /// Buffer for partial matches (the "jail")
    jail_buffer: String,
    /// Whether we've stopped; once set, processing short-circuits.
    stopped: bool,
}
impl StopSequenceDecoder {
|
||
/// Create a new stop sequence decoder
|
||
pub fn new(
|
||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||
config: StopSequenceConfig,
|
||
skip_special_tokens: bool,
|
||
) -> Self {
|
||
StopSequenceDecoder {
|
||
sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
|
||
config,
|
||
jail_buffer: String::new(),
|
||
stopped: false,
|
||
}
|
||
}
|
||
|
||
/// Process a single token
|
||
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
|
||
if self.stopped {
|
||
return Ok(SequenceDecoderOutput::Stopped);
|
||
}
|
||
|
||
// Check for token-level stops first
|
||
if self.config.stop_tokens.contains(&token_id) {
|
||
self.stopped = true;
|
||
|
||
// Flush any jailed text before stopping
|
||
if !self.jail_buffer.is_empty() {
|
||
let output = self.jail_buffer.clone();
|
||
self.jail_buffer.clear();
|
||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||
}
|
||
return Ok(SequenceDecoderOutput::Stopped);
|
||
}
|
||
|
||
if self.config.visible_stop_tokens.contains(&token_id) {
|
||
self.stopped = true;
|
||
|
||
// Include jailed text plus the stop token
|
||
let stop_text = self
|
||
.sequence
|
||
.tokenizer()
|
||
.decode(&[token_id], self.sequence.skip_special_tokens())?;
|
||
let output = format!("{}{}", self.jail_buffer, stop_text);
|
||
self.jail_buffer.clear();
|
||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||
}
|
||
|
||
// Use Sequence for incremental decoding
|
||
let new_text = self.sequence.append_token(token_id)?;
|
||
|
||
self.jail_buffer.push_str(&new_text);
|
||
|
||
// Check for hidden stop sequences
|
||
for stop_seq in &self.config.stop_sequences {
|
||
if let Some(pos) = self.jail_buffer.find(stop_seq) {
|
||
self.stopped = true;
|
||
let output = self.jail_buffer[..pos].to_string();
|
||
self.jail_buffer.clear();
|
||
return Ok(if output.is_empty() {
|
||
SequenceDecoderOutput::Stopped
|
||
} else {
|
||
SequenceDecoderOutput::StoppedWithText(output)
|
||
});
|
||
}
|
||
}
|
||
|
||
// Check for visible stop sequences
|
||
for stop_seq in &self.config.visible_stop_sequences {
|
||
if let Some(pos) = self.jail_buffer.find(stop_seq) {
|
||
self.stopped = true;
|
||
let end_pos = pos + stop_seq.len();
|
||
let output = self.jail_buffer[..end_pos].to_string();
|
||
self.jail_buffer.clear();
|
||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||
}
|
||
}
|
||
|
||
// Check for partial matches: is the end of jail_buffer the start of any stop_seq?
|
||
// This handles stop sequences split across tokens
|
||
let buffer_len = self.jail_buffer.len();
|
||
let mut best_split_pos: Option<usize> = None;
|
||
|
||
for stop_seq in self
|
||
.config
|
||
.stop_sequences
|
||
.iter()
|
||
.chain(&self.config.visible_stop_sequences)
|
||
{
|
||
let stop_len = stop_seq.len();
|
||
|
||
if stop_len <= 1 || buffer_len == 0 {
|
||
continue;
|
||
}
|
||
|
||
let max_len = buffer_len.min(stop_len - 1);
|
||
|
||
for len in (1..=max_len).rev() {
|
||
let suffix_start = buffer_len - len;
|
||
|
||
if !self.jail_buffer.is_char_boundary(suffix_start) {
|
||
continue;
|
||
}
|
||
|
||
let suffix = &self.jail_buffer[suffix_start..];
|
||
|
||
if stop_seq.starts_with(suffix)
|
||
&& best_split_pos.is_none_or(|current| suffix_start < current)
|
||
{
|
||
best_split_pos = Some(suffix_start);
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
|
||
if let Some(split_pos) = best_split_pos {
|
||
// Hold the partial match, flush the rest
|
||
// Drain [0..split_pos] as output, keep [split_pos..] in jail_buffer
|
||
let to_output = self.jail_buffer.drain(..split_pos).collect::<String>();
|
||
|
||
if to_output.is_empty() {
|
||
Ok(SequenceDecoderOutput::Held)
|
||
} else {
|
||
Ok(SequenceDecoderOutput::Text(to_output))
|
||
}
|
||
} else {
|
||
// No partial matches - flush everything
|
||
let output = std::mem::take(&mut self.jail_buffer);
|
||
if output.is_empty() {
|
||
Ok(SequenceDecoderOutput::Held)
|
||
} else {
|
||
Ok(SequenceDecoderOutput::Text(output))
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Process multiple tokens
|
||
pub fn process_tokens(
|
||
&mut self,
|
||
token_ids: &[TokenIdType],
|
||
) -> Result<Vec<SequenceDecoderOutput>> {
|
||
let mut outputs = Vec::new();
|
||
for &token_id in token_ids {
|
||
outputs.push(self.process_token(token_id)?);
|
||
}
|
||
Ok(outputs)
|
||
}
|
||
|
||
/// Flush any held text
|
||
pub fn flush(&mut self) -> SequenceDecoderOutput {
|
||
if !self.jail_buffer.is_empty() {
|
||
let output = self.jail_buffer.clone();
|
||
self.jail_buffer.clear();
|
||
SequenceDecoderOutput::Text(output)
|
||
} else {
|
||
SequenceDecoderOutput::Text(String::new())
|
||
}
|
||
}
|
||
|
||
/// Check if decoding has stopped
|
||
pub fn is_stopped(&self) -> bool {
|
||
self.stopped
|
||
}
|
||
|
||
/// Reset the decoder state
|
||
pub fn reset(&mut self) {
|
||
self.jail_buffer.clear();
|
||
self.sequence.clear();
|
||
self.stopped = false;
|
||
}
|
||
}
|
||
|
||
/// Builder for StopSequenceDecoder
pub struct StopSequenceDecoderBuilder {
    /// Tokenizer handed to the decoder on `build()`.
    tokenizer: Arc<dyn traits::Tokenizer>,
    /// Accumulated stop configuration.
    config: StopSequenceConfig,
    /// Whether to skip special tokens when decoding (defaults to true).
    skip_special_tokens: bool,
}
impl StopSequenceDecoderBuilder {
|
||
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||
StopSequenceDecoderBuilder {
|
||
tokenizer,
|
||
config: StopSequenceConfig::default(),
|
||
skip_special_tokens: true,
|
||
}
|
||
}
|
||
|
||
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
|
||
self.config.stop_tokens.insert(token_id);
|
||
self
|
||
}
|
||
|
||
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||
self.config.stop_sequences.push(sequence.into());
|
||
self
|
||
}
|
||
|
||
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
|
||
self.config.visible_stop_tokens.insert(token_id);
|
||
self
|
||
}
|
||
|
||
pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||
self.config.visible_stop_sequences.push(sequence.into());
|
||
self
|
||
}
|
||
|
||
pub fn skip_special_tokens(mut self, skip: bool) -> Self {
|
||
self.skip_special_tokens = skip;
|
||
self
|
||
}
|
||
|
||
pub fn build(self) -> StopSequenceDecoder {
|
||
StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens before stop
        let result = decoder.process_token(1).unwrap(); // "Hello"
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // Process stop token
        let result = decoder.process_token(999).unwrap(); // <eos>
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Further tokens should also return Stopped
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        // This test verifies the critical fix: no repeated output
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens one by one and collect outputs
        let mut outputs = Vec::new();

        // Token 1: "Hello"
        let result = decoder.process_token(1).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            // `text` is owned here; no clone needed to store it.
            outputs.push(text);
        }

        // Token 2: "world"
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text);
        }

        // Token 3: "test"
        let result = decoder.process_token(3).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text);
        }

        // CRITICAL: Each output should be unique (no accumulation)
        // The fix ensures we only output NEW text, not accumulated text
        assert_eq!(outputs.len(), 3);

        for i in 0..outputs.len() {
            for j in i + 1..outputs.len() {
                // No output should contain another (no accumulation)
                assert!(!outputs[j].contains(&outputs[i]));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("test");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello world"
        decoder.process_token(1).unwrap(); // "Hello"
        decoder.process_token(2).unwrap(); // "world"

        // Process "test" which should trigger stop
        let result = decoder.process_token(3).unwrap(); // "test"

        // Should stop when we hit "test"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process a token
        decoder.process_token(1).unwrap(); // "Hello"

        // Flush should return any remaining text in jail
        let result = decoder.flush();

        // After processing, flush should work
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process and stop
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        // Reset should clear everything
        decoder.reset();
        assert!(!decoder.is_stopped());

        // Should be able to process again
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello"
        decoder.process_token(1).unwrap();

        // Process "world" - should include it in output
        let result = decoder.process_token(2).unwrap();

        if let SequenceDecoderOutput::StoppedWithText(text) = result {
            // Should include "world" in the output
            assert!(text.contains("world"));
        } else {
            panic!("Expected StoppedWithText with visible stop sequence");
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process multiple tokens at once
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        // Should get results for each token
        assert_eq!(results.len(), 3);

        // Each result should be Text (no stops configured)
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }

    #[test]
    fn test_utf8_multibyte_character_boundaries() {
        // This test verifies the fix for the UTF-8 boundary panic
        // The panic occurred when trying to slice jail_buffer at a byte index
        // that was in the middle of a multi-byte UTF-8 character (e.g., '×')
        let tokenizer = Arc::new(MockTokenizer::new());

        // Configure stop sequence with a multi-byte character
        let config = StopSequenceConfig::default().with_stop_sequence(" ×");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Simulate the scenario: jail_buffer will contain " ×" (space + multiplication sign)
        // The '×' character is UTF-8 encoded as bytes [0xC3, 0x97] (2 bytes)
        // When checking for partial matches, we must not slice in the middle of these bytes

        // This should not panic - the fix ensures we only slice at char boundaries
        let result = decoder.process_token(1); // Will add some text to jail_buffer
        assert!(result.is_ok());

        // Even with multi-byte UTF-8 characters in the buffer, processing should work
        let result = decoder.process_token(2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_delta_character() {
        // Test for: byte index 1 is not a char boundary; it is inside 'Δ' (bytes 0..2) of `Δ`
        // 'Δ' (U+0394 GREEK CAPITAL LETTER DELTA) is encoded as [0xCE, 0x94] (2 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("Δ");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_degree_character() {
        // Test for: byte index 1 is not a char boundary; it is inside '°' (bytes 0..2) of `°`
        // '°' (U+00B0 DEGREE SIGN) is encoded as [0xC2, 0xB0] (2 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("°");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_triangle_character() {
        // Test for: byte index 4 is not a char boundary; it is inside '∆' (bytes 2..5) of ` (∆`
        // '∆' (U+2206 INCREMENT) is encoded as [0xE2, 0x88, 0x86] (3 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence(" (∆");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
        let result = decoder.process_token(3);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_en_dash_character() {
        // Test for: byte index 3 is not a char boundary; it is inside '–' (bytes 1..4) of ` –`
        // '–' (U+2013 EN DASH) is encoded as [0xE2, 0x80, 0x93] (3 bytes)
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence(" –");

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens - should not panic when checking partial matches
        let result = decoder.process_token(1);
        assert!(result.is_ok());
        let result = decoder.process_token(2);
        assert!(result.is_ok());
        let result = decoder.process_token(3);
        assert!(result.is_ok());
    }

    #[test]
    fn test_utf8_multibyte_various_characters() {
        // Comprehensive test with multiple multi-byte UTF-8 characters
        // Tests 2-byte, 3-byte, and 4-byte UTF-8 sequences
        let test_cases = vec![
            ("×", "multiplication sign - 2 bytes"),
            ("Δ", "Greek Delta - 2 bytes"),
            ("°", "degree sign - 2 bytes"),
            ("∆", "increment - 3 bytes"),
            ("–", "en dash - 3 bytes"),
            ("€", "euro sign - 3 bytes"),
            ("中", "Chinese character - 3 bytes"),
            ("🚀", "rocket emoji - 4 bytes"),
            ("💡", "lightbulb emoji - 4 bytes"),
        ];

        for (stop_char, description) in test_cases {
            let tokenizer = Arc::new(MockTokenizer::new());
            let config = StopSequenceConfig::default().with_stop_sequence(stop_char);

            let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

            // Process multiple tokens - should not panic
            for token_id in 1..=5 {
                let result = decoder.process_token(token_id);
                assert!(
                    result.is_ok(),
                    "Failed on {} with token {}",
                    description,
                    token_id
                );
            }
        }
    }
}