diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml
index 71c2e7ccb..2460b635a 100644
--- a/sgl-router/Cargo.toml
+++ b/sgl-router/Cargo.toml
@@ -3,6 +3,10 @@ name = "sglang_router_rs"
 version = "0.0.0"
 edition = "2021"
 
+[features]
+default = ["huggingface"]
+huggingface = ["tokenizers"]
+
 [lib]
 name = "sglang_router_rs"
 # Pure Rust library: Just omit crate-type (defaults to rlib)
@@ -44,7 +48,7 @@ thiserror = "2.0.12"
 url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
-tokenizers = "0.21.4"
+tokenizers = { version = "0.21.4", optional = true }
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
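With tokenizers now an optional dependency, both configurations deserve a compile check, since the feature gate only earns its keep if the crate still builds without it:

    cargo build                         # default features: huggingface, pulls in tokenizers
    cargo build --no-default-features   # router without the tokenizers dependency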
diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs
new file mode 100644
index 000000000..6639f35a1
--- /dev/null
+++ b/sgl-router/src/tokenizer/factory.rs
@@ -0,0 +1,228 @@
+use super::traits;
+use anyhow::{Error, Result};
+use std::fs::File;
+use std::io::Read;
+use std::path::Path;
+use std::sync::Arc;
+
+#[cfg(feature = "huggingface")]
+use super::huggingface::HuggingFaceTokenizer;
+
+/// Represents the type of tokenizer being used
+#[derive(Debug, Clone)]
+pub enum TokenizerType {
+    HuggingFace(String),
+    Mock,
+    // Future: SentencePiece, GGUF, Tiktoken
+}
+
+/// Create a tokenizer from a file path to a tokenizer file.
+/// The file extension is used to determine the tokenizer type.
+/// Supported file types are:
+/// - json: HuggingFace tokenizer
+/// - For testing: the literal paths "mock" or "test" return the mock tokenizer
+pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+    // Special case for testing
+    if file_path == "mock" || file_path == "test" {
+        return Ok(Arc::new(super::mock::MockTokenizer::new()));
+    }
+
+    let path = Path::new(file_path);
+
+    // Check if file exists
+    if !path.exists() {
+        return Err(Error::msg(format!("File not found: {}", file_path)));
+    }
+
+    // Try to determine tokenizer type from extension
+    let extension = path
+        .extension()
+        .and_then(std::ffi::OsStr::to_str)
+        .map(|s| s.to_lowercase());
+
+    match extension.as_deref() {
+        Some("json") => {
+            #[cfg(feature = "huggingface")]
+            {
+                let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
+                Ok(Arc::new(tokenizer))
+            }
+            #[cfg(not(feature = "huggingface"))]
+            {
+                Err(Error::msg(
+                    "HuggingFace support not enabled. Enable the 'huggingface' feature.",
+                ))
+            }
+        }
+        Some("model") => {
+            // SentencePiece model file
+            Err(Error::msg("SentencePiece models not yet supported"))
+        }
+        Some("gguf") => {
+            // GGUF format
+            Err(Error::msg("GGUF format not yet supported"))
+        }
+        _ => {
+            // Try to auto-detect by reading file content
+            auto_detect_tokenizer(file_path)
+        }
+    }
+}
+
+/// Auto-detect tokenizer type by examining file content
+fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+    let mut file = File::open(file_path)?;
+    let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
+    let bytes_read = file.read(&mut buffer)?;
+    buffer.truncate(bytes_read);
+
+    // Check for JSON (HuggingFace format)
+    if is_likely_json(&buffer) {
+        #[cfg(feature = "huggingface")]
+        {
+            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
+            return Ok(Arc::new(tokenizer));
+        }
+        #[cfg(not(feature = "huggingface"))]
+        {
+            return Err(Error::msg(
+                "File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
+            ));
+        }
+    }
+
+    // Check for GGUF magic number
+    if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
+        return Err(Error::msg("GGUF format detected but not yet supported"));
+    }
+
+    // Check for SentencePiece model
+    if is_likely_sentencepiece(&buffer) {
+        return Err(Error::msg(
+            "SentencePiece model detected but not yet supported",
+        ));
+    }
+
+    Err(Error::msg(format!(
+        "Unable to determine tokenizer type for file: {}",
+        file_path
+    )))
+}
+
+/// Check if the buffer likely contains JSON data
+fn is_likely_json(buffer: &[u8]) -> bool {
+    // Skip UTF-8 BOM if present
+    let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
+        &buffer[3..]
+    } else {
+        buffer
+    };
+
+    // Find first non-whitespace character without allocation
+    if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
+        *first_byte == b'{' || *first_byte == b'['
+    } else {
+        false
+    }
+}
+
+/// Check if the buffer likely contains a SentencePiece model
+fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
+    // SentencePiece models often start with specific protobuf patterns, and
+    // their payload usually embeds the <unk> and </s> piece names.
+    // This is a simplified check
+    buffer.len() >= 12
+        && (buffer.starts_with(b"\x0a\x09")
+            || buffer.starts_with(b"\x08\x00")
+            || buffer.windows(5).any(|w| w == b"<unk>")
+            || buffer.windows(4).any(|w| w == b"</s>"))
+}
+
+/// Factory function to create tokenizer from a model name or path
+pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+    // Check if it's a file path
+    let path = Path::new(model_name_or_path);
+    if path.exists() {
+        return create_tokenizer_from_file(model_name_or_path);
+    }
+
+    // Otherwise, try to load from HuggingFace Hub
+    #[cfg(feature = "huggingface")]
+    {
+        // This would download from HF Hub - not implemented yet
+        Err(Error::msg(
+            "Loading from HuggingFace Hub not yet implemented",
+        ))
+    }
+
+    #[cfg(not(feature = "huggingface"))]
+    {
+        Err(Error::msg(format!(
+            "Model '{}' not found locally and HuggingFace support is not enabled",
+            model_name_or_path
+        )))
+    }
+}
+
+/// Get information about a tokenizer file
+pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
+    let path = Path::new(file_path);
+
+    if !path.exists() {
+        return Err(Error::msg(format!("File not found: {}", file_path)));
+    }
+
+    let extension = path
+        .extension()
+        .and_then(std::ffi::OsStr::to_str)
+        .map(|s| s.to_lowercase());
+
+    match extension.as_deref() {
+        Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
+        _ => {
+            // Try auto-detection (File and Read are already imported at module level)
+            let mut file = File::open(file_path)?;
+            let mut buffer = vec![0u8; 512];
+            let bytes_read = file.read(&mut buffer)?;
+            buffer.truncate(bytes_read);
+
+            if is_likely_json(&buffer) {
+                Ok(TokenizerType::HuggingFace(file_path.to_string()))
+            } else {
+                Err(Error::msg("Unknown tokenizer type"))
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_json_detection() {
+        assert!(is_likely_json(b"{\"test\": \"value\"}"));
+        assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
+        assert!(is_likely_json(b"[1, 2, 3]"));
+        assert!(!is_likely_json(b"not json"));
+        assert!(!is_likely_json(b""));
+    }
+
+    #[test]
+    fn test_mock_tokenizer_creation() {
+        let tokenizer = create_tokenizer_from_file("mock").unwrap();
+        assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
+    }
+
+    #[test]
+    fn test_file_not_found() {
+        let result = create_tokenizer_from_file("/nonexistent/file.json");
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert!(e.to_string().contains("File not found"));
+        }
+    }
+}
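A minimal sketch of the intended call path for the factory (hypothetical caller; both file paths are placeholders):

    use sglang_router_rs::tokenizer::factory::get_tokenizer_info;
    use sglang_router_rs::tokenizer::{create_tokenizer_from_file, TokenizerType};

    fn main() -> anyhow::Result<()> {
        // Extension-based dispatch: a .json file goes straight to HuggingFace.
        let tokenizer = create_tokenizer_from_file("tokenizer.json")?;
        println!("vocab size: {}", tokenizer.vocab_size());

        // No recognized extension: the first 512 bytes decide.
        match get_tokenizer_info("./some_tokenizer_file")? {
            TokenizerType::HuggingFace(path) => println!("HF tokenizer at {}", path),
            TokenizerType::Mock => println!("mock tokenizer"),
        }
        Ok(())
    }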
diff --git a/sgl-router/src/tokenizer/huggingface.rs b/sgl-router/src/tokenizer/huggingface.rs
new file mode 100644
index 000000000..70eabfc4a
--- /dev/null
+++ b/sgl-router/src/tokenizer/huggingface.rs
@@ -0,0 +1,189 @@
+use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
+use anyhow::{Error, Result};
+use std::collections::HashMap;
+use tokenizers::tokenizer::Tokenizer as HfTokenizer;
+
+/// HuggingFace tokenizer wrapper
+pub struct HuggingFaceTokenizer {
+    tokenizer: HfTokenizer,
+    special_tokens: SpecialTokens,
+    vocab: HashMap<String, u32>,
+    reverse_vocab: HashMap<u32, String>,
+}
+
+impl HuggingFaceTokenizer {
+    /// Create a tokenizer from a HuggingFace tokenizer JSON file
+    pub fn from_file(file_path: &str) -> Result<Self> {
+        let tokenizer = HfTokenizer::from_file(file_path)
+            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
+
+        // Extract special tokens
+        let special_tokens = Self::extract_special_tokens(&tokenizer);
+
+        // Build vocab mappings
+        let vocab = tokenizer.get_vocab(false);
+        let reverse_vocab: HashMap<u32, String> = vocab
+            .iter()
+            .map(|(token, &id)| (id, token.clone()))
+            .collect();
+
+        Ok(HuggingFaceTokenizer {
+            tokenizer,
+            special_tokens,
+            vocab,
+            reverse_vocab,
+        })
+    }
+
+    /// Create from an existing HuggingFace tokenizer
+    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
+        let special_tokens = Self::extract_special_tokens(&tokenizer);
+        let vocab = tokenizer.get_vocab(false);
+        let reverse_vocab: HashMap<u32, String> = vocab
+            .iter()
+            .map(|(token, &id)| (id, token.clone()))
+            .collect();
+
+        HuggingFaceTokenizer {
+            tokenizer,
+            special_tokens,
+            vocab,
+            reverse_vocab,
+        }
+    }
+
+    /// Extract special tokens from the tokenizer
+    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
+        // Try to get special tokens from the tokenizer
+        // This is a simplified version - actual implementation would need to handle various formats
+        let vocab = tokenizer.get_vocab(true);
+
+        let find_token = |patterns: &[&str]| -> Option<String> {
+            for pattern in patterns {
+                if vocab.contains_key(*pattern) {
+                    return Some(pattern.to_string());
+                }
+            }
+            None
+        };
+
+        SpecialTokens {
+            bos_token: find_token(&["<s>", "<|startoftext|>", "<bos>", "[CLS]"]),
+            eos_token: find_token(&["</s>", "<|endoftext|>", "<eos>", "[SEP]"]),
+            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
+            sep_token: find_token(&["[SEP]", "<sep>", "</s>"]),
+            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
+            cls_token: find_token(&["[CLS]", "<cls>", "<s>"]),
+            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
+            additional_special_tokens: vec![],
+        }
+    }
+
+    /// Apply chat template if available
+    pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
+        // This is a placeholder - actual implementation would handle templates
+        let mut result = String::new();
+        for msg in messages {
+            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
+        }
+        Ok(result)
+    }
+}
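Template application is expected to read roughly like the sketch below (ChatMessage is defined further down in this file; the flat "role: content" rendering reflects the placeholder above, not a real chat template):

    use sglang_router_rs::tokenizer::{ChatMessage, HuggingFaceTokenizer};

    fn render_prompt(tok: &HuggingFaceTokenizer) -> anyhow::Result<String> {
        let messages = vec![
            ChatMessage::system("You are a helpful assistant"),
            ChatMessage::user("Hello!"),
        ];
        // With the placeholder template this yields:
        // "system: You are a helpful assistant\nuser: Hello!\n"
        tok.apply_chat_template(&messages)
    }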
""]), + additional_special_tokens: vec![], + } + } + + /// Apply chat template if available + pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result { + // This is a placeholder - actual implementation would handle templates + let mut result = String::new(); + for msg in messages { + result.push_str(&format!("{}: {}\n", msg.role, msg.content)); + } + Ok(result) + } +} + +impl Encoder for HuggingFaceTokenizer { + fn encode(&self, input: &str) -> Result { + let encoding = self + .tokenizer + .encode(input, false) + .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))?; + + Ok(Encoding::Hf(Box::new(encoding))) + } + + fn encode_batch(&self, inputs: &[&str]) -> Result> { + let encodings = self + .tokenizer + .encode_batch(inputs.to_vec(), false) + .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?; + + Ok(encodings + .into_iter() + .map(|e| Encoding::Hf(Box::new(e))) + .collect()) + } +} + +impl Decoder for HuggingFaceTokenizer { + fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result { + self.tokenizer + .decode(token_ids, skip_special_tokens) + .map_err(|e| Error::msg(format!("Decoding failed: {}", e))) + } +} + +impl TokenizerTrait for HuggingFaceTokenizer { + fn vocab_size(&self) -> usize { + self.tokenizer.get_vocab_size(false) + } + + fn get_special_tokens(&self) -> &SpecialTokens { + &self.special_tokens + } + + fn token_to_id(&self, token: &str) -> Option { + self.vocab.get(token).copied() + } + + fn id_to_token(&self, id: u32) -> Option { + self.reverse_vocab.get(&id).cloned() + } +} + +/// Represents a chat message for template application +#[derive(Debug, Clone)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +impl ChatMessage { + pub fn new(role: impl Into, content: impl Into) -> Self { + ChatMessage { + role: role.into(), + content: content.into(), + } + } + + pub fn system(content: impl Into) -> Self { + Self::new("system", content) + } + + pub fn user(content: impl Into) -> Self { + Self::new("user", content) + } + + pub fn assistant(content: impl Into) -> Self { + Self::new("assistant", content) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_chat_message_creation() { + let msg = ChatMessage::system("You are a helpful assistant"); + assert_eq!(msg.role, "system"); + assert_eq!(msg.content, "You are a helpful assistant"); + + let user_msg = ChatMessage::user("Hello!"); + assert_eq!(user_msg.role, "user"); + + let assistant_msg = ChatMessage::assistant("Hi there!"); + assert_eq!(assistant_msg.role, "assistant"); + } + + // Note: Actual tokenizer tests would require a real tokenizer file + // These would be integration tests rather than unit tests +} diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs index a77884abe..c218dbecc 100644 --- a/sgl-router/src/tokenizer/mod.rs +++ b/sgl-router/src/tokenizer/mod.rs @@ -2,26 +2,36 @@ use anyhow::Result; use std::ops::Deref; use std::sync::Arc; +pub mod factory; pub mod mock; +pub mod stop; pub mod stream; pub mod traits; +// Feature-gated modules +#[cfg(feature = "huggingface")] +pub mod huggingface; + #[cfg(test)] mod tests; +// Re-exports +pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType}; +pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder}; pub use stream::DecodeStream; pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait}; +#[cfg(feature = "huggingface")] +pub use huggingface::{ChatMessage, 
diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs
index a77884abe..c218dbecc 100644
--- a/sgl-router/src/tokenizer/mod.rs
+++ b/sgl-router/src/tokenizer/mod.rs
@@ -2,26 +2,36 @@ use anyhow::Result;
 use std::ops::Deref;
 use std::sync::Arc;
 
+pub mod factory;
 pub mod mock;
+pub mod stop;
 pub mod stream;
 pub mod traits;
 
+// Feature-gated modules
+#[cfg(feature = "huggingface")]
+pub mod huggingface;
+
 #[cfg(test)]
 mod tests;
 
+// Re-exports
+pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
+pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
 
+#[cfg(feature = "huggingface")]
+pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
+
 /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
 #[derive(Clone)]
 pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
 
 impl Tokenizer {
     /// Create a tokenizer from a file path
-    /// Will be implemented in Phase 3 with factory pattern
-    pub fn from_file(_file_path: &str) -> Result<Self> {
-        // TODO: Implement factory pattern in Phase 3
-        unimplemented!("Factory pattern will be implemented in Phase 3")
+    pub fn from_file(file_path: &str) -> Result<Self> {
+        Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
     }
 
     /// Create a tokenizer from an Arc
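With the factory wired in, the public entry point is simply Tokenizer::from_file. A usage sketch, assuming the wrapper derefs to the trait object (std::ops::Deref is imported above for exactly that; the path is a placeholder):

    use sglang_router_rs::tokenizer::Tokenizer;

    fn main() -> anyhow::Result<()> {
        // Dispatches through factory::create_tokenizer_from_file under the hood.
        let tokenizer = Tokenizer::from_file("path/to/tokenizer.json")?;
        println!("vocab size: {}", tokenizer.vocab_size());
        Ok(())
    }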
diff --git a/sgl-router/src/tokenizer/stop.rs b/sgl-router/src/tokenizer/stop.rs
new file mode 100644
index 000000000..19dd60802
--- /dev/null
+++ b/sgl-router/src/tokenizer/stop.rs
@@ -0,0 +1,499 @@
+use super::traits;
+use anyhow::Result;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+/// Output from the sequence decoder
+#[derive(Debug, Clone, PartialEq)]
+pub enum SequenceDecoderOutput {
+    /// Normal text output
+    Text(String),
+    /// Text is being held due to partial stop sequence match
+    Held,
+    /// Stop sequence matched (hidden - not included in output)
+    Stopped,
+    /// Stop sequence matched with text (visible - included in output)
+    StoppedWithText(String),
+}
+
+/// Configuration for stop sequences
+#[derive(Debug, Clone, Default)]
+pub struct StopSequenceConfig {
+    /// Token IDs that trigger a stop
+    pub stop_tokens: HashSet<u32>,
+    /// String sequences that trigger a stop
+    pub stop_sequences: Vec<String>,
+    /// Token IDs for visible stops (included in output)
+    pub visible_stop_tokens: HashSet<u32>,
+    /// String sequences for visible stops (included in output)
+    pub visible_stop_sequences: Vec<String>,
+}
+
+impl StopSequenceConfig {
+    /// Builder pattern - add a stop token
+    pub fn with_stop_token(mut self, token_id: u32) -> Self {
+        self.stop_tokens.insert(token_id);
+        self
+    }
+
+    /// Builder pattern - add a stop sequence
+    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
+        self.stop_sequences.push(sequence.into());
+        self
+    }
+
+    /// Builder pattern - add a visible stop token
+    pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
+        self.visible_stop_tokens.insert(token_id);
+        self
+    }
+
+    /// Builder pattern - add a visible stop sequence
+    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
+        self.visible_stop_sequences.push(sequence.into());
+        self
+    }
+}
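The chained builder reads like this in practice (token id 999 is illustrative, standing in for whatever EOS id the model uses):

    use sglang_router_rs::tokenizer::StopSequenceConfig;

    fn example_config() -> StopSequenceConfig {
        StopSequenceConfig::default()
            .with_stop_token(999)                    // illustrative EOS id, stops silently
            .with_stop_sequence("\n\nUser:")         // hidden stop: cut before the next turn
            .with_visible_stop_sequence("</answer>") // visible stop: keep the tag in output
    }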
+
+/// Decoder that handles stop sequences
+pub struct StopSequenceDecoder {
+    tokenizer: Arc<dyn traits::Tokenizer>,
+    config: StopSequenceConfig,
+    /// Buffer for partial matches (the "jail")
+    jail_buffer: String,
+    /// Accumulated tokens
+    token_buffer: Vec<u32>,
+    /// Offset where the prefix text starts (for context)
+    prefix_offset: usize,
+    /// Offset marking the end of previously decoded text
+    read_offset: usize,
+    /// Whether we've stopped
+    stopped: bool,
+    skip_special_tokens: bool,
+}
+
+impl StopSequenceDecoder {
+    /// Create a new stop sequence decoder
+    pub fn new(
+        tokenizer: Arc<dyn traits::Tokenizer>,
+        config: StopSequenceConfig,
+        skip_special_tokens: bool,
+    ) -> Self {
+        StopSequenceDecoder {
+            tokenizer,
+            config,
+            jail_buffer: String::new(),
+            token_buffer: Vec::new(),
+            prefix_offset: 0,
+            read_offset: 0,
+            stopped: false,
+            skip_special_tokens,
+        }
+    }
+
+    /// Process a single token
+    pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
+        if self.stopped {
+            return Ok(SequenceDecoderOutput::Stopped);
+        }
+
+        // Check for token-level stops first
+        if self.config.stop_tokens.contains(&token_id) {
+            self.stopped = true;
+            // Flush any jailed text before stopping
+            if !self.jail_buffer.is_empty() {
+                let output = self.jail_buffer.clone();
+                self.jail_buffer.clear();
+                return Ok(SequenceDecoderOutput::StoppedWithText(output));
+            }
+            return Ok(SequenceDecoderOutput::Stopped);
+        }
+
+        if self.config.visible_stop_tokens.contains(&token_id) {
+            self.stopped = true;
+            // Include jailed text plus the stop token
+            let stop_text = self
+                .tokenizer
+                .decode(&[token_id], self.skip_special_tokens)?;
+            let output = format!("{}{}", self.jail_buffer, stop_text);
+            self.jail_buffer.clear();
+            return Ok(SequenceDecoderOutput::StoppedWithText(output));
+        }
+
+        // Add token to buffer
+        self.token_buffer.push(token_id);
+
+        // Use incremental decoding like DecodeStream
+        // First decode the previous context (what we've already output)
+        let prefix_text = if self.read_offset > self.prefix_offset {
+            self.tokenizer.decode(
+                &self.token_buffer[self.prefix_offset..self.read_offset],
+                self.skip_special_tokens,
+            )?
+        } else {
+            String::new()
+        };
+
+        // Now decode from prefix to current position
+        let new_full_text = self.tokenizer.decode(
+            &self.token_buffer[self.prefix_offset..],
+            self.skip_special_tokens,
+        )?;
+
+        // Check for incomplete UTF-8 sequence
+        if new_full_text.ends_with("�") {
+            // Wait for more tokens to complete the sequence
+            return Ok(SequenceDecoderOutput::Held);
+        }
+
+        // Calculate only the NEW text since last successful decode
+        let new_text = if new_full_text.len() > prefix_text.len() {
+            &new_full_text[prefix_text.len()..]
+        } else {
+            // No new text produced (can happen with special tokens)
+            return Ok(SequenceDecoderOutput::Held);
+        };
+
+        // Combine jail buffer with new text for checking
+        let check_text = format!("{}{}", self.jail_buffer, new_text);
+
+        // Check for complete stop sequences
+        for stop_seq in &self.config.stop_sequences {
+            if let Some(pos) = check_text.find(stop_seq) {
+                self.stopped = true;
+                // Output text before the stop sequence
+                let output = check_text[..pos].to_string();
+                self.jail_buffer.clear();
+                return Ok(if output.is_empty() {
+                    SequenceDecoderOutput::Stopped
+                } else {
+                    SequenceDecoderOutput::StoppedWithText(output)
+                });
+            }
+        }
+
+        // Check for visible stop sequences
+        for stop_seq in &self.config.visible_stop_sequences {
+            if let Some(pos) = check_text.find(stop_seq) {
+                self.stopped = true;
+                // Include the stop sequence in output
+                let end_pos = pos + stop_seq.len();
+                let output = check_text[..end_pos].to_string();
+                self.jail_buffer.clear();
+                return Ok(SequenceDecoderOutput::StoppedWithText(output));
+            }
+        }
+
+        // Check for partial matches at the end of check_text
+        let mut partial_match_len = 0;
+        for stop_seq in self
+            .config
+            .stop_sequences
+            .iter()
+            .chain(&self.config.visible_stop_sequences)
+        {
+            // Check all possible suffixes that could be a prefix of stop_seq
+            for i in 1..=check_text.len().min(stop_seq.len() - 1) {
+                let suffix = &check_text[check_text.len() - i..];
+                if stop_seq.starts_with(suffix) {
+                    partial_match_len = partial_match_len.max(i);
+                }
+            }
+        }
+
+        if partial_match_len > 0 {
+            // Split: output safe text, jail the potential match
+            let safe_end = check_text.len() - partial_match_len;
+            let safe_text = &check_text[..safe_end];
+            self.jail_buffer = check_text[safe_end..].to_string();
+
+            // Update offsets for next iteration
+            self.prefix_offset = self.read_offset;
+            self.read_offset = self.token_buffer.len();
+
+            if safe_text.is_empty() {
+                Ok(SequenceDecoderOutput::Held)
+            } else {
+                Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
+            }
+        } else {
+            // No partial matches - output everything
+            self.jail_buffer.clear();
+
+            // Update offsets for next iteration
+            self.prefix_offset = self.read_offset;
+            self.read_offset = self.token_buffer.len();
+
+            Ok(SequenceDecoderOutput::Text(check_text))
+        }
+    }
+
+    /// Process multiple tokens
+    pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
+        let mut outputs = Vec::new();
+        for &token_id in token_ids {
+            outputs.push(self.process_token(token_id)?);
+        }
+        Ok(outputs)
+    }
+
+    /// Flush any held text
+    pub fn flush(&mut self) -> SequenceDecoderOutput {
+        if !self.jail_buffer.is_empty() {
+            let output = self.jail_buffer.clone();
+            self.jail_buffer.clear();
+            SequenceDecoderOutput::Text(output)
+        } else {
+            SequenceDecoderOutput::Text(String::new())
+        }
+    }
+
+    /// Check if decoding has stopped
+    pub fn is_stopped(&self) -> bool {
+        self.stopped
+    }
+
+    /// Reset the decoder state
+    pub fn reset(&mut self) {
+        self.jail_buffer.clear();
+        self.token_buffer.clear();
+        self.prefix_offset = 0;
+        self.read_offset = 0;
+        self.stopped = false;
+    }
+}
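The jail buffer is what keeps a stop sequence that straddles token boundaries from leaking to the client. A sketch of the expected behavior using the mock tokenizer (token strings follow the comments in the tests below; the "world!" stop sequence is illustrative):

    use std::sync::Arc;
    use sglang_router_rs::tokenizer::{
        mock::MockTokenizer, SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
    };

    fn jail_demo() -> anyhow::Result<()> {
        // "world" on its own could be the start of the stop sequence "world!",
        // so it must be held back until more text rules the match in or out.
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("world!");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        println!("{:?}", decoder.process_token(1)?); // "Hello" -> Text
        println!("{:?}", decoder.process_token(2)?); // "world" -> jailed: Held, or Text of the safe prefix

        // End of stream: release whatever is still jailed.
        if let SequenceDecoderOutput::Text(tail) = decoder.flush() {
            println!("flushed: {:?}", tail);
        }
        Ok(())
    }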
+
+/// Builder for StopSequenceDecoder
+pub struct StopSequenceDecoderBuilder {
+    tokenizer: Arc<dyn traits::Tokenizer>,
+    config: StopSequenceConfig,
+    skip_special_tokens: bool,
+}
+
+impl StopSequenceDecoderBuilder {
+    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
+        StopSequenceDecoderBuilder {
+            tokenizer,
+            config: StopSequenceConfig::default(),
+            skip_special_tokens: true,
+        }
+    }
+
+    pub fn stop_token(mut self, token_id: u32) -> Self {
+        self.config.stop_tokens.insert(token_id);
+        self
+    }
+
+    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
+        self.config.stop_sequences.push(sequence.into());
+        self
+    }
+
+    pub fn visible_stop_token(mut self, token_id: u32) -> Self {
+        self.config.visible_stop_tokens.insert(token_id);
+        self
+    }
+
+    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
+        self.config.visible_stop_sequences.push(sequence.into());
+        self
+    }
+
+    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
+        self.skip_special_tokens = skip;
+        self
+    }
+
+    pub fn build(self) -> StopSequenceDecoder {
+        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tokenizer::mock::MockTokenizer;
+
+    #[test]
+    fn test_stop_token_detection() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token
+
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process tokens before stop
+        let result = decoder.process_token(1).unwrap(); // "Hello"
+        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
+
+        // Process stop token
+        let result = decoder.process_token(999).unwrap(); // <eos>
+        assert_eq!(result, SequenceDecoderOutput::Stopped);
+
+        // Further tokens should also return Stopped
+        let result = decoder.process_token(2).unwrap();
+        assert_eq!(result, SequenceDecoderOutput::Stopped);
+    }
+
+    #[test]
+    fn test_visible_stop_token() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default().with_visible_stop_token(999);
+
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        let result = decoder.process_token(999).unwrap();
+        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
+    }
+
+    #[test]
+    fn test_builder_pattern() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+
+        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
+            .stop_token(999)
+            .stop_sequence("STOP")
+            .visible_stop_token(1000)
+            .skip_special_tokens(true)
+            .build();
+
+        assert!(!decoder.is_stopped());
+    }
+
+    #[test]
+    fn test_incremental_decoding_no_repetition() {
+        // This test verifies the critical fix: no repeated output
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default();
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process tokens one by one and collect outputs
+        let mut outputs = Vec::new();
+
+        // Token 1: "Hello"
+        let result = decoder.process_token(1).unwrap();
+        if let SequenceDecoderOutput::Text(text) = result {
+            outputs.push(text.clone());
+        }
+
+        // Token 2: "world"
+        let result = decoder.process_token(2).unwrap();
+        if let SequenceDecoderOutput::Text(text) = result {
+            outputs.push(text.clone());
+        }
+
+        // Token 3: "test"
+        let result = decoder.process_token(3).unwrap();
+        if let SequenceDecoderOutput::Text(text) = result {
+            outputs.push(text.clone());
+        }
+
+        // CRITICAL: Each output should be unique (no accumulation)
+        // The fix ensures we only output NEW text, not accumulated text
+        assert_eq!(outputs.len(), 3);
+
+        // Verify no text is repeated
+        for i in 0..outputs.len() {
+            for j in i + 1..outputs.len() {
+                // No output should contain another (no accumulation)
+                assert!(!outputs[j].contains(&outputs[i]));
+            }
+        }
+    }
+
+    #[test]
+    fn test_stop_sequence_detection() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default().with_stop_sequence("test");
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process "Hello world"
+        decoder.process_token(1).unwrap(); // "Hello"
+        decoder.process_token(2).unwrap(); // "world"
+
+        // Process "test" which should trigger stop
+        let result = decoder.process_token(3).unwrap(); // "test"
+
+        // Should stop when we hit "test"
+        assert!(matches!(
+            result,
+            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
+        ));
+    }
+
+    #[test]
+    fn test_flush_after_partial() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process a token
+        decoder.process_token(1).unwrap(); // "Hello"
+
+        // Flush should return any remaining text in jail
+        let result = decoder.flush();
+
+        // After processing, flush should work
+        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
+    }
+
+    #[test]
+    fn test_reset_functionality() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default().with_stop_token(999);
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process and stop
+        decoder.process_token(1).unwrap();
+        decoder.process_token(999).unwrap();
+        assert!(decoder.is_stopped());
+
+        // Reset should clear everything
+        decoder.reset();
+        assert!(!decoder.is_stopped());
+
+        // Should be able to process again
+        let result = decoder.process_token(2).unwrap();
+        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
+    }
+
+    #[test]
+    fn test_visible_stop_sequence() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process "Hello"
+        decoder.process_token(1).unwrap();
+
+        // Process "world" - should include it in output
+        let result = decoder.process_token(2).unwrap();
+
+        if let SequenceDecoderOutput::StoppedWithText(text) = result {
+            // Should include "world" in the output
+            assert!(text.contains("world"));
+        } else {
+            panic!("Expected StoppedWithText with visible stop sequence");
+        }
+    }
+
+    #[test]
+    fn test_multiple_tokens_processing() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let config = StopSequenceConfig::default();
+        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
+
+        // Process multiple tokens at once
+        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
+
+        // Should get results for each token
+        assert_eq!(results.len(), 3);
+
+        // Each result should be Text or Held (no stops configured)
+        for result in results {
+            assert!(matches!(
+                result,
+                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
+            ));
+        }
+    }
+}
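Putting the pieces together, a server-side streaming loop over this API might look like the following sketch. The token source, the EOS id, and the "\n\nUser:" stop sequence are all illustrative; a real caller would feed ids from the engine's stream and pass a tokenizer.json path instead of "mock":

    use sglang_router_rs::tokenizer::{
        create_tokenizer_from_file, SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder,
    };

    fn stream_with_stops(token_ids: impl Iterator<Item = u32>) -> anyhow::Result<String> {
        // "mock" keeps the sketch self-contained.
        let tokenizer = create_tokenizer_from_file("mock")?;
        let config = StopSequenceConfig::default()
            .with_stop_token(999) // illustrative EOS id
            .with_stop_sequence("\n\nUser:");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, true);

        let mut emitted = String::new();
        for id in token_ids {
            match decoder.process_token(id)? {
                SequenceDecoderOutput::Text(t) => emitted.push_str(&t),
                SequenceDecoderOutput::StoppedWithText(t) => {
                    emitted.push_str(&t);
                    break;
                }
                SequenceDecoderOutput::Stopped => break,
                SequenceDecoderOutput::Held => {} // partial stop match; wait for more tokens
            }
        }
        // Release any text still held in the jail once the stream ends.
        if let SequenceDecoderOutput::Text(t) = decoder.flush() {
            emitted.push_str(&t);
        }
        Ok(emitted)
    }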