[router] tokenizer factory, hf tokenizer, and stop sequence detector (#9293)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -3,6 +3,10 @@ name = "sglang_router_rs"
|
||||
version = "0.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[features]
|
||||
default = ["huggingface"]
|
||||
huggingface = ["tokenizers"]
|
||||
|
||||
[lib]
|
||||
name = "sglang_router_rs"
|
||||
# Pure Rust library: Just omit crate-type (defaults to rlib)
|
||||
@@ -44,7 +48,7 @@ thiserror = "2.0.12"
|
||||
url = "2.5.4"
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
anyhow = "1.0"
|
||||
tokenizers = "0.21.4"
|
||||
tokenizers = { version = "0.21.4", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
|
||||
228
sgl-router/src/tokenizer/factory.rs
Normal file
228
sgl-router/src/tokenizer/factory.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
use super::traits;
|
||||
use anyhow::{Error, Result};
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
use super::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
    /// HuggingFace `tokenizer.json`; the payload is the file path it was loaded from.
    HuggingFace(String),
    /// In-memory mock tokenizer used by tests (see `create_tokenizer_from_file("mock")`).
    Mock,
    // Future: SentencePiece, GGUF, Tiktoken
}
|
||||
|
||||
/// Create a tokenizer from a file path to a tokenizer file.
|
||||
/// The file extension is used to determine the tokenizer type.
|
||||
/// Supported file types are:
|
||||
/// - json: HuggingFace tokenizer
|
||||
/// - For testing: can return mock tokenizer
|
||||
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
// Special case for testing
|
||||
if file_path == "mock" || file_path == "test" {
|
||||
return Ok(Arc::new(super::mock::MockTokenizer::new()));
|
||||
}
|
||||
|
||||
let path = Path::new(file_path);
|
||||
|
||||
// Check if file exists
|
||||
if !path.exists() {
|
||||
return Err(Error::msg(format!("File not found: {}", file_path)));
|
||||
}
|
||||
|
||||
// Try to determine tokenizer type from extension
|
||||
let extension = path
|
||||
.extension()
|
||||
.and_then(std::ffi::OsStr::to_str)
|
||||
.map(|s| s.to_lowercase());
|
||||
|
||||
match extension.as_deref() {
|
||||
Some("json") => {
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
Ok(Arc::new(tokenizer))
|
||||
}
|
||||
#[cfg(not(feature = "huggingface"))]
|
||||
{
|
||||
Err(Error::msg(
|
||||
"HuggingFace support not enabled. Enable the 'huggingface' feature.",
|
||||
))
|
||||
}
|
||||
}
|
||||
Some("model") => {
|
||||
// SentencePiece model file
|
||||
Err(Error::msg("SentencePiece models not yet supported"))
|
||||
}
|
||||
Some("gguf") => {
|
||||
// GGUF format
|
||||
Err(Error::msg("GGUF format not yet supported"))
|
||||
}
|
||||
_ => {
|
||||
// Try to auto-detect by reading file content
|
||||
auto_detect_tokenizer(file_path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Auto-detect tokenizer type by examining file content
|
||||
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
let mut file = File::open(file_path)?;
|
||||
let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
buffer.truncate(bytes_read);
|
||||
|
||||
// Check for JSON (HuggingFace format)
|
||||
if is_likely_json(&buffer) {
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
#[cfg(not(feature = "huggingface"))]
|
||||
{
|
||||
return Err(Error::msg(
|
||||
"File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for GGUF magic number
|
||||
if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
|
||||
return Err(Error::msg("GGUF format detected but not yet supported"));
|
||||
}
|
||||
|
||||
// Check for SentencePiece model
|
||||
if is_likely_sentencepiece(&buffer) {
|
||||
return Err(Error::msg(
|
||||
"SentencePiece model detected but not yet supported",
|
||||
));
|
||||
}
|
||||
|
||||
Err(Error::msg(format!(
|
||||
"Unable to determine tokenizer type for file: {}",
|
||||
file_path
|
||||
)))
|
||||
}
|
||||
|
||||
/// Heuristic: does this buffer look like the start of a JSON document?
///
/// Skips an optional UTF-8 BOM and leading ASCII whitespace, then accepts
/// only `{` or `[` as the opening byte.
fn is_likely_json(buffer: &[u8]) -> bool {
    // Drop a UTF-8 BOM if present.
    let content = buffer.strip_prefix(b"\xEF\xBB\xBF").unwrap_or(buffer);

    // First non-whitespace byte decides; empty/blank input is not JSON.
    content
        .iter()
        .find(|b| !b.is_ascii_whitespace())
        .map_or(false, |&b| b == b'{' || b == b'[')
}
|
||||
|
||||
/// Heuristic: does this buffer look like a serialized SentencePiece model?
///
/// SentencePiece models are protobuf files that typically begin with the
/// field-tag bytes checked below and contain the conventional control
/// tokens `<unk>`, `<s>`, `</s>` early in the file. This is a best-effort
/// signature check, not a parser.
///
/// Bug fix: the previous implementation searched with `windows(4)` for the
/// 3-byte needle `b"<s>"`; a 4-byte window can never equal a 3-byte slice,
/// so that check was always false. The search is now length-aware.
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    // Substring search sized to the needle (windows(n) yields nothing when
    // n exceeds the haystack length, so short buffers are safe).
    fn contains(haystack: &[u8], needle: &[u8]) -> bool {
        haystack.windows(needle.len()).any(|w| w == needle)
    }

    buffer.len() >= 12
        && (buffer.starts_with(b"\x0a\x09")
            || buffer.starts_with(b"\x08\x00")
            || contains(buffer, b"<unk")
            || contains(buffer, b"<s>")
            || contains(buffer, b"</s>"))
}
|
||||
|
||||
/// Factory function to create tokenizer from a model name or path
///
/// If `model_name_or_path` exists on disk it is loaded as a tokenizer file;
/// otherwise a HuggingFace Hub download would be attempted (not yet
/// implemented).
///
/// # Errors
/// Returns an error when the path does not exist locally: Hub loading is
/// unimplemented with the `huggingface` feature, and unsupported without it.
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    // Check if it's a file path
    let path = Path::new(model_name_or_path);
    if path.exists() {
        return create_tokenizer_from_file(model_name_or_path);
    }

    // Otherwise, try to load from HuggingFace Hub
    // (exactly one of the cfg blocks below survives compilation and becomes
    // the function's tail expression)
    #[cfg(feature = "huggingface")]
    {
        // This would download from HF Hub - not implemented yet
        Err(Error::msg(
            "Loading from HuggingFace Hub not yet implemented",
        ))
    }

    #[cfg(not(feature = "huggingface"))]
    {
        Err(Error::msg(format!(
            "Model '{}' not found locally and HuggingFace support is not enabled",
            model_name_or_path
        )))
    }
}
|
||||
|
||||
/// Get information about a tokenizer file
|
||||
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
|
||||
let path = Path::new(file_path);
|
||||
|
||||
if !path.exists() {
|
||||
return Err(Error::msg(format!("File not found: {}", file_path)));
|
||||
}
|
||||
|
||||
let extension = path
|
||||
.extension()
|
||||
.and_then(std::ffi::OsStr::to_str)
|
||||
.map(|s| s.to_lowercase());
|
||||
|
||||
match extension.as_deref() {
|
||||
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
|
||||
_ => {
|
||||
// Try auto-detection
|
||||
use std::fs::File;
|
||||
use std::io::Read;
|
||||
|
||||
let mut file = File::open(file_path)?;
|
||||
let mut buffer = vec![0u8; 512];
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
buffer.truncate(bytes_read);
|
||||
|
||||
if is_likely_json(&buffer) {
|
||||
Ok(TokenizerType::HuggingFace(file_path.to_string()))
|
||||
} else {
|
||||
Err(Error::msg("Unknown tokenizer type"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // JSON sniffing: objects/arrays with optional leading whitespace are
    // accepted; non-JSON and empty input are rejected.
    #[test]
    fn test_json_detection() {
        assert!(is_likely_json(b"{\"test\": \"value\"}"));
        assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
        assert!(is_likely_json(b"[1, 2, 3]"));
        assert!(!is_likely_json(b"not json"));
        assert!(!is_likely_json(b""));
    }

    // The "mock" path must bypass the filesystem entirely.
    #[test]
    fn test_mock_tokenizer_creation() {
        let tokenizer = create_tokenizer_from_file("mock").unwrap();
        assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
    }

    // Missing files surface a descriptive error, not a panic.
    #[test]
    fn test_file_not_found() {
        let result = create_tokenizer_from_file("/nonexistent/file.json");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("File not found"));
        }
    }
}
|
||||
189
sgl-router/src/tokenizer/huggingface.rs
Normal file
189
sgl-router/src/tokenizer/huggingface.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
use anyhow::{Error, Result};
|
||||
use std::collections::HashMap;
|
||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||
|
||||
/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    /// The wrapped `tokenizers`-crate tokenizer that does the real work.
    tokenizer: HfTokenizer,
    /// Special tokens discovered heuristically from the vocabulary at load time.
    special_tokens: SpecialTokens,
    /// token string -> id, cached from the base vocabulary (added tokens excluded).
    vocab: HashMap<String, u32>,
    /// id -> token string; inverse of `vocab`, built once for O(1) lookups.
    reverse_vocab: HashMap<u32, String>,
}
|
||||
|
||||
impl HuggingFaceTokenizer {
|
||||
/// Create a tokenizer from a HuggingFace tokenizer JSON file
|
||||
pub fn from_file(file_path: &str) -> Result<Self> {
|
||||
let tokenizer = HfTokenizer::from_file(file_path)
|
||||
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
|
||||
|
||||
// Extract special tokens
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
|
||||
// Build vocab mappings
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
Ok(HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from an existing HuggingFace tokenizer
|
||||
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
|
||||
let special_tokens = Self::extract_special_tokens(&tokenizer);
|
||||
let vocab = tokenizer.get_vocab(false);
|
||||
let reverse_vocab: HashMap<u32, String> = vocab
|
||||
.iter()
|
||||
.map(|(token, &id)| (id, token.clone()))
|
||||
.collect();
|
||||
|
||||
HuggingFaceTokenizer {
|
||||
tokenizer,
|
||||
special_tokens,
|
||||
vocab,
|
||||
reverse_vocab,
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract special tokens from the tokenizer
|
||||
fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
|
||||
// Try to get special tokens from the tokenizer
|
||||
// This is a simplified version - actual implementation would need to handle various formats
|
||||
let vocab = tokenizer.get_vocab(true);
|
||||
|
||||
let find_token = |patterns: &[&str]| -> Option<String> {
|
||||
for pattern in patterns {
|
||||
if vocab.contains_key(*pattern) {
|
||||
return Some(pattern.to_string());
|
||||
}
|
||||
}
|
||||
None
|
||||
};
|
||||
|
||||
SpecialTokens {
|
||||
bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
|
||||
eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
|
||||
unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
|
||||
sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
|
||||
pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
|
||||
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
|
||||
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
|
||||
additional_special_tokens: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply chat template if available
|
||||
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
|
||||
// This is a placeholder - actual implementation would handle templates
|
||||
let mut result = String::new();
|
||||
for msg in messages {
|
||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder for HuggingFaceTokenizer {
|
||||
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||
let encoding = self
|
||||
.tokenizer
|
||||
.encode(input, false)
|
||||
.map_err(|e| Error::msg(format!("Encoding failed: {}", e)))?;
|
||||
|
||||
Ok(Encoding::Hf(Box::new(encoding)))
|
||||
}
|
||||
|
||||
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||
let encodings = self
|
||||
.tokenizer
|
||||
.encode_batch(inputs.to_vec(), false)
|
||||
.map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
|
||||
|
||||
Ok(encodings
|
||||
.into_iter()
|
||||
.map(|e| Encoding::Hf(Box::new(e)))
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for HuggingFaceTokenizer {
|
||||
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
|
||||
self.tokenizer
|
||||
.decode(token_ids, skip_special_tokens)
|
||||
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenizerTrait for HuggingFaceTokenizer {
    /// Size of the base vocabulary (added special tokens excluded).
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    /// Special tokens detected when this wrapper was constructed.
    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    /// Look up a token's id in the cached forward map.
    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    /// Look up the token string for an id in the cached reverse map.
    fn id_to_token(&self, id: u32) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }
}
|
||||
|
||||
/// A single chat message (role plus content) used when applying chat templates.
#[derive(Debug, Clone)]
pub struct ChatMessage {
    /// Speaker role, e.g. "system", "user", or "assistant".
    pub role: String,
    /// The message text.
    pub content: String,
}

impl ChatMessage {
    /// Build a message with an arbitrary role.
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: role.into(),
            content: content.into(),
        }
    }

    /// Shorthand for a "system" message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    /// Shorthand for a "user" message.
    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    /// Shorthand for an "assistant" message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // The role shorthands must set the expected role and pass content through.
    #[test]
    fn test_chat_message_creation() {
        let msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "You are a helpful assistant");

        let user_msg = ChatMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");

        let assistant_msg = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    // Note: Actual tokenizer tests would require a real tokenizer file
    // These would be integration tests rather than unit tests
}
|
||||
@@ -2,26 +2,36 @@ use anyhow::Result;
|
||||
use std::ops::Deref;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub mod factory;
|
||||
pub mod mock;
|
||||
pub mod stop;
|
||||
pub mod stream;
|
||||
pub mod traits;
|
||||
|
||||
// Feature-gated modules
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub mod huggingface;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
// Re-exports
|
||||
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
|
||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||
pub use stream::DecodeStream;
|
||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
|
||||
|
||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
///
/// Cloning is cheap: only the inner `Arc`'s reference count is bumped.
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
|
||||
|
||||
impl Tokenizer {
|
||||
/// Create a tokenizer from a file path
|
||||
/// Will be implemented in Phase 3 with factory pattern
|
||||
pub fn from_file(_file_path: &str) -> Result<Tokenizer> {
|
||||
// TODO: Implement factory pattern in Phase 3
|
||||
unimplemented!("Factory pattern will be implemented in Phase 3")
|
||||
pub fn from_file(file_path: &str) -> Result<Tokenizer> {
|
||||
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
|
||||
}
|
||||
|
||||
/// Create a tokenizer from an Arc<dyn Tokenizer>
|
||||
|
||||
499
sgl-router/src/tokenizer/stop.rs
Normal file
499
sgl-router/src/tokenizer/stop.rs
Normal file
@@ -0,0 +1,499 @@
|
||||
use super::traits;
|
||||
use anyhow::Result;
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Output from the sequence decoder
///
/// Each processed token yields exactly one of these variants; once a
/// `Stopped*` variant is returned the decoder stays stopped until `reset`.
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
    /// Normal text output
    Text(String),
    /// Text is being held due to partial stop sequence match
    Held,
    /// Stop sequence matched (hidden - not included in output)
    Stopped,
    /// Stop sequence matched with text (visible - included in output)
    StoppedWithText(String),
}
|
||||
|
||||
/// Configuration for stop sequences
///
/// "Hidden" stops terminate decoding without appearing in the output;
/// "visible" stops terminate decoding and are included in the output.
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
    /// Token IDs that trigger a stop
    pub stop_tokens: HashSet<u32>,
    /// String sequences that trigger a stop
    pub stop_sequences: Vec<String>,
    /// Token IDs for visible stops (included in output)
    pub visible_stop_tokens: HashSet<u32>,
    /// String sequences for visible stops (included in output)
    pub visible_stop_sequences: Vec<String>,
}

impl StopSequenceConfig {
    /// Add a hidden stop token (builder style).
    pub fn with_stop_token(mut self, token_id: u32) -> Self {
        self.stop_tokens.insert(token_id);
        self
    }

    /// Add a hidden stop sequence (builder style).
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }

    /// Add a visible stop token (builder style).
    pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
        self.visible_stop_tokens.insert(token_id);
        self
    }

    /// Add a visible stop sequence (builder style).
    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.visible_stop_sequences.push(sequence.into());
        self
    }
}
|
||||
|
||||
/// Decoder that handles stop sequences
///
/// Streams tokens through incremental decoding and withholds ("jails") any
/// trailing text that could be the start of a configured stop sequence.
pub struct StopSequenceDecoder {
    // Tokenizer used for all incremental decode calls.
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    /// Buffer for partial matches (the "jail")
    jail_buffer: String,
    /// Accumulated tokens
    token_buffer: Vec<u32>,
    /// Offset where the prefix text starts (for context)
    prefix_offset: usize,
    /// Offset marking the end of previously decoded text
    read_offset: usize,
    /// Whether we've stopped
    stopped: bool,
    // Forwarded to every tokenizer.decode() call.
    skip_special_tokens: bool,
}
|
||||
|
||||
impl StopSequenceDecoder {
|
||||
/// Create a new stop sequence decoder
|
||||
pub fn new(
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
config: StopSequenceConfig,
|
||||
skip_special_tokens: bool,
|
||||
) -> Self {
|
||||
StopSequenceDecoder {
|
||||
tokenizer,
|
||||
config,
|
||||
jail_buffer: String::new(),
|
||||
token_buffer: Vec::new(),
|
||||
prefix_offset: 0,
|
||||
read_offset: 0,
|
||||
stopped: false,
|
||||
skip_special_tokens,
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single token
|
||||
pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
|
||||
if self.stopped {
|
||||
return Ok(SequenceDecoderOutput::Stopped);
|
||||
}
|
||||
|
||||
// Check for token-level stops first
|
||||
if self.config.stop_tokens.contains(&token_id) {
|
||||
self.stopped = true;
|
||||
// Flush any jailed text before stopping
|
||||
if !self.jail_buffer.is_empty() {
|
||||
let output = self.jail_buffer.clone();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
return Ok(SequenceDecoderOutput::Stopped);
|
||||
}
|
||||
|
||||
if self.config.visible_stop_tokens.contains(&token_id) {
|
||||
self.stopped = true;
|
||||
// Include jailed text plus the stop token
|
||||
let stop_text = self
|
||||
.tokenizer
|
||||
.decode(&[token_id], self.skip_special_tokens)?;
|
||||
let output = format!("{}{}", self.jail_buffer, stop_text);
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
|
||||
// Add token to buffer
|
||||
self.token_buffer.push(token_id);
|
||||
|
||||
// Use incremental decoding like DecodeStream
|
||||
// First decode the previous context (what we've already output)
|
||||
let prefix_text = if self.read_offset > self.prefix_offset {
|
||||
self.tokenizer.decode(
|
||||
&self.token_buffer[self.prefix_offset..self.read_offset],
|
||||
self.skip_special_tokens,
|
||||
)?
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Now decode from prefix to current position
|
||||
let new_full_text = self.tokenizer.decode(
|
||||
&self.token_buffer[self.prefix_offset..],
|
||||
self.skip_special_tokens,
|
||||
)?;
|
||||
|
||||
// Check for incomplete UTF-8 sequence
|
||||
if new_full_text.ends_with("<EFBFBD>") {
|
||||
// Wait for more tokens to complete the sequence
|
||||
return Ok(SequenceDecoderOutput::Held);
|
||||
}
|
||||
|
||||
// Calculate only the NEW text since last successful decode
|
||||
let new_text = if new_full_text.len() > prefix_text.len() {
|
||||
&new_full_text[prefix_text.len()..]
|
||||
} else {
|
||||
// No new text produced (can happen with special tokens)
|
||||
return Ok(SequenceDecoderOutput::Held);
|
||||
};
|
||||
|
||||
// Combine jail buffer with new text for checking
|
||||
let check_text = format!("{}{}", self.jail_buffer, new_text);
|
||||
|
||||
// Check for complete stop sequences
|
||||
for stop_seq in &self.config.stop_sequences {
|
||||
if let Some(pos) = check_text.find(stop_seq) {
|
||||
self.stopped = true;
|
||||
// Output text before the stop sequence
|
||||
let output = check_text[..pos].to_string();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(if output.is_empty() {
|
||||
SequenceDecoderOutput::Stopped
|
||||
} else {
|
||||
SequenceDecoderOutput::StoppedWithText(output)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Check for visible stop sequences
|
||||
for stop_seq in &self.config.visible_stop_sequences {
|
||||
if let Some(pos) = check_text.find(stop_seq) {
|
||||
self.stopped = true;
|
||||
// Include the stop sequence in output
|
||||
let end_pos = pos + stop_seq.len();
|
||||
let output = check_text[..end_pos].to_string();
|
||||
self.jail_buffer.clear();
|
||||
return Ok(SequenceDecoderOutput::StoppedWithText(output));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial matches at the end of check_text
|
||||
let mut partial_match_len = 0;
|
||||
for stop_seq in self
|
||||
.config
|
||||
.stop_sequences
|
||||
.iter()
|
||||
.chain(&self.config.visible_stop_sequences)
|
||||
{
|
||||
// Check all possible suffixes that could be a prefix of stop_seq
|
||||
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
|
||||
let suffix = &check_text[check_text.len() - i..];
|
||||
if stop_seq.starts_with(suffix) {
|
||||
partial_match_len = partial_match_len.max(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if partial_match_len > 0 {
|
||||
// Split: output safe text, jail the potential match
|
||||
let safe_end = check_text.len() - partial_match_len;
|
||||
let safe_text = &check_text[..safe_end];
|
||||
self.jail_buffer = check_text[safe_end..].to_string();
|
||||
|
||||
// Update offsets for next iteration
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_buffer.len();
|
||||
|
||||
if safe_text.is_empty() {
|
||||
Ok(SequenceDecoderOutput::Held)
|
||||
} else {
|
||||
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
|
||||
}
|
||||
} else {
|
||||
// No partial matches - output everything
|
||||
self.jail_buffer.clear();
|
||||
|
||||
// Update offsets for next iteration
|
||||
self.prefix_offset = self.read_offset;
|
||||
self.read_offset = self.token_buffer.len();
|
||||
|
||||
Ok(SequenceDecoderOutput::Text(check_text))
|
||||
}
|
||||
}
|
||||
|
||||
/// Process multiple tokens
|
||||
pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
|
||||
let mut outputs = Vec::new();
|
||||
for &token_id in token_ids {
|
||||
outputs.push(self.process_token(token_id)?);
|
||||
}
|
||||
Ok(outputs)
|
||||
}
|
||||
|
||||
/// Flush any held text
|
||||
pub fn flush(&mut self) -> SequenceDecoderOutput {
|
||||
if !self.jail_buffer.is_empty() {
|
||||
let output = self.jail_buffer.clone();
|
||||
self.jail_buffer.clear();
|
||||
SequenceDecoderOutput::Text(output)
|
||||
} else {
|
||||
SequenceDecoderOutput::Text(String::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if decoding has stopped
|
||||
pub fn is_stopped(&self) -> bool {
|
||||
self.stopped
|
||||
}
|
||||
|
||||
/// Reset the decoder state
|
||||
pub fn reset(&mut self) {
|
||||
self.jail_buffer.clear();
|
||||
self.token_buffer.clear();
|
||||
self.prefix_offset = 0;
|
||||
self.read_offset = 0;
|
||||
self.stopped = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Builder for StopSequenceDecoder
|
||||
pub struct StopSequenceDecoderBuilder {
|
||||
tokenizer: Arc<dyn traits::Tokenizer>,
|
||||
config: StopSequenceConfig,
|
||||
skip_special_tokens: bool,
|
||||
}
|
||||
|
||||
impl StopSequenceDecoderBuilder {
|
||||
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
|
||||
StopSequenceDecoderBuilder {
|
||||
tokenizer,
|
||||
config: StopSequenceConfig::default(),
|
||||
skip_special_tokens: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn stop_token(mut self, token_id: u32) -> Self {
|
||||
self.config.stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.config.stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn visible_stop_token(mut self, token_id: u32) -> Self {
|
||||
self.config.visible_stop_tokens.insert(token_id);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
|
||||
self.config.visible_stop_sequences.push(sequence.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn skip_special_tokens(mut self, skip: bool) -> Self {
|
||||
self.skip_special_tokens = skip;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> StopSequenceDecoder {
|
||||
StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    // All tests below drive the decoder with the in-crate MockTokenizer
    // (token 1 -> "Hello", 2 -> "world", 3 -> "test", 999 -> <eos>
    // — assumed from usage here; confirm against mock.rs).

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens before stop
        let result = decoder.process_token(1).unwrap(); // "Hello"
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // Process stop token
        let result = decoder.process_token(999).unwrap(); // <eos>
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Further tokens should also return Stopped
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Visible stop: the stop token's text is included in the output.
        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        // A freshly built decoder has not stopped.
        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        // This test verifies the critical fix: no repeated output
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens one by one and collect outputs
        let mut outputs = Vec::new();

        // Token 1: "Hello"
        let result = decoder.process_token(1).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 2: "world"
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 3: "test"
        let result = decoder.process_token(3).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // CRITICAL: Each output should be unique (no accumulation)
        // The fix ensures we only output NEW text, not accumulated text
        assert_eq!(outputs.len(), 3);

        // Verify no text is repeated
        for i in 0..outputs.len() {
            for j in i + 1..outputs.len() {
                // No output should contain another (no accumulation)
                assert!(!outputs[j].contains(&outputs[i]));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("test");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello world"
        decoder.process_token(1).unwrap(); // "Hello"
        decoder.process_token(2).unwrap(); // "world"

        // Process "test" which should trigger stop
        let result = decoder.process_token(3).unwrap(); // "test"

        // Should stop when we hit "test"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process a token
        decoder.process_token(1).unwrap(); // "Hello"

        // Flush should return any remaining text in jail
        let result = decoder.flush();

        // After processing, flush should work
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process and stop
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        // Reset should clear everything
        decoder.reset();
        assert!(!decoder.is_stopped());

        // Should be able to process again
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello"
        decoder.process_token(1).unwrap();

        // Process "world" - should include it in output
        let result = decoder.process_token(2).unwrap();

        if let SequenceDecoderOutput::StoppedWithText(text) = result {
            // Should include "world" in the output
            assert!(text.contains("world"));
        } else {
            panic!("Expected StoppedWithText with visible stop sequence");
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process multiple tokens at once
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        // Should get results for each token
        assert_eq!(results.len(), 3);

        // Each result should be Text (no stops configured)
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }
}
|
||||
Reference in New Issue
Block a user