diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml
index 2460b635a..e0defacdf 100644
--- a/sgl-router/Cargo.toml
+++ b/sgl-router/Cargo.toml
@@ -6,6 +6,7 @@ edition = "2021"
 [features]
 default = ["huggingface"]
 huggingface = ["tokenizers"]
+tiktoken = ["tiktoken-rs"]

 [lib]
 name = "sglang_router_rs"
@@ -49,6 +50,7 @@ url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
 tokenizers = { version = "0.21.4", optional = true }
+tiktoken-rs = { version = "0.5", optional = true }

 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
diff --git a/sgl-router/src/tokenizer/factory.rs b/sgl-router/src/tokenizer/factory.rs
index 04b950d3c..e339140e7 100644
--- a/sgl-router/src/tokenizer/factory.rs
+++ b/sgl-router/src/tokenizer/factory.rs
@@ -1,4 +1,4 @@
-use super::{traits, TokenizerTrait};
+use super::traits::{self, Tokenizer as TokenizerTrait};
 use crate::metrics::TokenizerMetrics;
 use anyhow::{Error, Result};
 use std::fs::File;
@@ -15,7 +15,9 @@ use super::huggingface::HuggingFaceTokenizer;
 pub enum TokenizerType {
     HuggingFace(String),
     Mock,
-    // Future: SentencePiece, GGUF, Tiktoken
+    #[cfg(feature = "tiktoken")]
+    Tiktoken(String),
+    // Future: SentencePiece, GGUF
 }

 /// Create a tokenizer from a file path to a tokenizer file.
@@ -166,6 +168,23 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Tokenizer> {
+
+    #[cfg(feature = "tiktoken")]
+    #[test]
+    fn test_create_tiktoken_tokenizer() {
+        let tokenizer = create_tokenizer("gpt-4").unwrap();
+        assert!(tokenizer.vocab_size() > 0);
+
+        // Test encoding and decoding
+        let text = "Hello, world!";
+        let encoding = tokenizer.encode(text).unwrap();
+        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+        assert_eq!(decoded, text);
+    }
 }
diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs
index c218dbecc..7d7f87aed 100644
--- a/sgl-router/src/tokenizer/mod.rs
+++ b/sgl-router/src/tokenizer/mod.rs
@@ -4,6 +4,7 @@ use std::sync::Arc;

 pub mod factory;
 pub mod mock;
+pub mod sequence;
 pub mod stop;
 pub mod stream;
 pub mod traits;
@@ -12,11 +13,15 @@ pub mod traits;
 #[cfg(feature = "huggingface")]
 pub mod huggingface;

+#[cfg(feature = "tiktoken")]
+pub mod tiktoken;
+
 #[cfg(test)]
 mod tests;

 // Re-exports
 pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
+pub use sequence::Sequence;
 pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
@@ -24,6 +29,9 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
 #[cfg(feature = "huggingface")]
 pub use huggingface::{ChatMessage, HuggingFaceTokenizer};

+#[cfg(feature = "tiktoken")]
+pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
+
 /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
 #[derive(Clone)]
 pub struct Tokenizer(Arc<dyn TokenizerTrait>);
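Note: with the `tiktoken` feature enabled, the new backend is reachable through the same `Encoder`/`Decoder` traits as the HuggingFace path. A minimal round-trip sketch, assuming a build with `cargo build --features tiktoken` (the crate name `sglang_router_rs` comes from the `[lib]` section above):

```rust
use sglang_router_rs::tokenizer::{Decoder, Encoder, TiktokenModel, TiktokenTokenizer};

fn main() -> anyhow::Result<()> {
    // Construct the cl100k_base encoder directly; create_tokenizer("gpt-4")
    // reaches the same backend through the factory when the feature is on.
    let tok = TiktokenTokenizer::new(TiktokenModel::Cl100kBase)?;
    let encoding = tok.encode("Hello, world!")?;
    let ids = encoding.token_ids(); // Vec<u32> under the reworked Encoding API
    assert_eq!(tok.decode(&ids, false)?, "Hello, world!");
    Ok(())
}
```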
diff --git a/sgl-router/src/tokenizer/sequence.rs b/sgl-router/src/tokenizer/sequence.rs
new file mode 100644
index 000000000..816d3cc59
--- /dev/null
+++ b/sgl-router/src/tokenizer/sequence.rs
@@ -0,0 +1,238 @@
+use super::traits::Tokenizer as TokenizerTrait;
+use anyhow::Result;
+use std::sync::Arc;
+
+/// Maintains state for an ongoing sequence of tokens and their decoded text.
+/// This provides a cleaner abstraction for managing token sequences.
+pub struct Sequence {
+    /// The tokenizer used for encoding/decoding
+    tokenizer: Arc<dyn TokenizerTrait>,
+
+    /// The current sequence of token ids
+    token_ids: Vec<u32>,
+
+    /// Offset into `token_ids` at which the last fully decoded text ended
+    prefix_offset: usize,
+
+    /// Current position in the sequence
+    read_offset: usize,
+}
+
+impl std::fmt::Debug for Sequence {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Sequence")
+            .field("tokenizer", &"Arc<dyn TokenizerTrait>")
+            .field(
+                "token_ids",
+                &format_args!("{}", {
+                    let token_ids = self.token_ids();
+                    if token_ids.len() <= 20 {
+                        format!("{:?}", token_ids)
+                    } else {
+                        let first_ten = &token_ids[..10];
+                        let last_ten = &token_ids[token_ids.len() - 10..];
+                        format!("{:?} ... {:?}", first_ten, last_ten)
+                    }
+                }),
+            )
+            .field("prefix_offset", &self.prefix_offset)
+            .field("read_offset", &self.read_offset)
+            .field("token count", &self.token_ids.len())
+            .finish()
+    }
+}
+
+impl Sequence {
+    /// Create a new empty sequence
+    pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
+        Self {
+            tokenizer,
+            token_ids: Vec::new(),
+            prefix_offset: 0,
+            read_offset: 0,
+        }
+    }
+
+    /// Create a sequence with initial tokens
+    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
+        let len = token_ids.len();
+        Self {
+            tokenizer,
+            token_ids,
+            prefix_offset: 0,
+            read_offset: len,
+        }
+    }
+
+    /// Check if the sequence is empty
+    pub fn is_empty(&self) -> bool {
+        self.token_ids.is_empty()
+    }
+
+    /// Get the length of the sequence
+    pub fn len(&self) -> usize {
+        self.token_ids.len()
+    }
+
+    /// Clear the sequence
+    pub fn clear(&mut self) {
+        self.token_ids.clear();
+        self.prefix_offset = 0;
+        self.read_offset = 0;
+    }
+
+    /// Append text to the sequence by encoding it
+    pub fn append_text(&mut self, input: &str) -> Result<()> {
+        let encoding = self.tokenizer.encode(input)?;
+        self.token_ids.extend(encoding.token_ids());
+        Ok(())
+    }
+
+    /// Append a single token to the sequence and return newly decoded text.
+    /// Based on HuggingFace TGI incremental decoding.
+    pub fn append_token(&mut self, token_id: u32) -> Result<String> {
+        // Store the old read offset before adding the new token
+        let old_read_offset = self.read_offset;
+
+        self.token_ids.push(token_id);
+        self.read_offset = self.token_ids.len();
+
+        // If this is the first token or we're at the beginning, decode everything
+        if self.prefix_offset == 0 && old_read_offset == 0 {
+            let text = self.tokenizer.decode(&self.token_ids, false)?;
+            if text.ends_with("�") {
+                // Incomplete UTF-8 sequence, wait for more tokens
+                return Ok(String::new());
+            }
+            self.prefix_offset = 0;
+            return Ok(text);
+        }
+
+        // Decode the text up to the previous position
+        let prefix_text = self
+            .tokenizer
+            .decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
+
+        // Decode the text including the new token
+        let new_text = self
+            .tokenizer
+            .decode(&self.token_ids[self.prefix_offset..], false)?;
+
+        // Handle multi-byte character boundaries
+        let mut prefix_text_len = prefix_text.len();
+        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
+            prefix_text_len -= 1;
+        }
+
+        if new_text.len() > prefix_text.len() {
+            if new_text.ends_with("�") {
+                // Incomplete UTF-8 sequence, wait for more tokens
+                return Ok(String::new());
+            } else {
+                // Return the new text portion
+                let incremental_text = new_text[prefix_text_len..].to_string().replace("�", "");
+                self.prefix_offset = old_read_offset;
+                return Ok(incremental_text);
+            }
+        }
+
+        Ok(String::new())
+    }
+
+    /// Get a reference to the tokenizer
+    pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
+        &self.tokenizer
+    }
+
+    /// Get the current token ids
+    pub fn token_ids(&self) -> &[u32] {
+        &self.token_ids
+    }
+
+    /// Decode the entire sequence to text
+    pub fn text(&self) -> Result<String> {
+        self.tokenizer.decode(&self.token_ids, false)
+    }
+
+    /// Get the prefix offset
+    pub fn prefix_offset(&self) -> usize {
+        self.prefix_offset
+    }
+
+    /// Get the read offset
+    pub fn read_offset(&self) -> usize {
+        self.read_offset
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tokenizer::mock::MockTokenizer;
+
+    #[test]
+    fn test_sequence_new() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let seq = Sequence::new(tokenizer);
+        assert!(seq.is_empty());
+        assert_eq!(seq.len(), 0);
+    }
+
+    #[test]
+    fn test_sequence_append_text() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let mut seq = Sequence::new(tokenizer);
+
+        seq.append_text("Hello").unwrap();
+        assert!(!seq.is_empty());
+
+        let text = seq.text().unwrap();
+        assert_eq!(text, "Hello");
+    }
+
+    #[test]
+    fn test_sequence_append_token() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let mut seq = Sequence::new(tokenizer.clone());
+
+        // Start with an empty sequence and append token 1 ("Hello")
+        let text1 = seq.append_token(1).unwrap();
+        assert_eq!(text1, "Hello");
+
+        // Now append token 2 ("world")
+        // The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
+        let text2 = seq.append_token(2).unwrap();
+        // The incremental text should be " world" (with the space that the mock tokenizer adds)
+        assert_eq!(text2, " world");
+
+        // Verify the full text
+        assert_eq!(seq.text().unwrap(), "Hello world");
+    }
+
+    #[test]
+    fn test_sequence_clear() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let mut seq = Sequence::new(tokenizer);
+
+        seq.append_text("Hello world").unwrap();
+        assert!(!seq.is_empty());
+
+        seq.clear();
+        assert!(seq.is_empty());
+        assert_eq!(seq.len(), 0);
+        assert_eq!(seq.prefix_offset(), 0);
+        assert_eq!(seq.read_offset(), 0);
+    }
+
+    #[test]
+    fn test_sequence_debug() {
+        let tokenizer = Arc::new(MockTokenizer::new());
+        let mut seq = Sequence::new(tokenizer);
+
+        seq.append_text("Test").unwrap();
+        let debug_str = format!("{:?}", seq);
+        assert!(debug_str.contains("Sequence"));
+        assert!(debug_str.contains("token count"));
+    }
+}
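Note: the incremental decoding above is what makes `Sequence` suitable for streaming — each `append_token` call returns only the text completed by that token, and holds output back while a multi-byte UTF-8 character is still split across tokens. A usage sketch (the token stream itself is hypothetical; in the router it would come from an engine response):

```rust
use std::sync::Arc;
use sglang_router_rs::tokenizer::{Sequence, TokenizerTrait};

fn stream_decode(tokenizer: Arc<dyn TokenizerTrait>, tokens: &[u32]) -> anyhow::Result<String> {
    let mut seq = Sequence::new(tokenizer);
    let mut out = String::new();
    for &id in tokens {
        // An empty string means "wait for more tokens", not an error.
        out.push_str(&seq.append_token(id)?);
    }
    Ok(out)
}
```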
diff --git a/sgl-router/src/tokenizer/tests.rs b/sgl-router/src/tokenizer/tests.rs
index 2c4d4b108..93c8f1621 100644
--- a/sgl-router/src/tokenizer/tests.rs
+++ b/sgl-router/src/tokenizer/tests.rs
@@ -129,7 +129,9 @@ fn test_thread_safety() {
         thread::spawn(move || {
             let text = "Hello test".to_string();
             let encoding = tokenizer_clone.encode(&text).unwrap();
-            let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
+            let decoded = tokenizer_clone
+                .decode(&encoding.token_ids(), false)
+                .unwrap();
             assert!(decoded.contains("Hello") || decoded.contains("test"));
             i
         })
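Note: this call-site change follows from `Encoding::token_ids()` now returning an owned `Vec<u32>` instead of `&[u32]` (see the traits.rs hunk at the end of this diff). A sketch of the two accessors side by side:

```rust
use sglang_router_rs::tokenizer::Encoding;

fn inspect(encoding: &Encoding) {
    let owned: Vec<u32> = encoding.token_ids();      // allocates, works for every variant
    let borrowed: &[u32] = encoding.token_ids_ref(); // zero-copy, but empty for Tiktoken
    assert!(borrowed.len() <= owned.len());
}
```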
diff --git a/sgl-router/src/tokenizer/tiktoken.rs b/sgl-router/src/tokenizer/tiktoken.rs
new file mode 100644
index 000000000..4cf0ea9f1
--- /dev/null
+++ b/sgl-router/src/tokenizer/tiktoken.rs
@@ -0,0 +1,276 @@
+use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
+use anyhow::{Error, Result};
+use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
+
+/// Tiktoken tokenizer wrapper for OpenAI GPT models
+pub struct TiktokenTokenizer {
+    tokenizer: CoreBPE,
+    #[allow(dead_code)]
+    model: TiktokenModel,
+    special_tokens: SpecialTokens,
+    vocab_size: usize,
+}
+
+/// Supported Tiktoken models
+#[derive(Debug, Clone, Copy)]
+pub enum TiktokenModel {
+    /// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
+    Cl100kBase,
+    /// Codex models, text-davinci-002, text-davinci-003
+    P50kBase,
+    /// Edit models like text-davinci-edit-001, code-davinci-edit-001
+    P50kEdit,
+    /// GPT-3 models like davinci
+    R50kBase,
+}
+
+impl TiktokenTokenizer {
+    /// Create a new Tiktoken tokenizer for the specified model
+    pub fn new(model: TiktokenModel) -> Result<Self> {
+        let tokenizer = match model {
+            TiktokenModel::Cl100kBase => cl100k_base()
+                .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
+            TiktokenModel::P50kBase => p50k_base()
+                .map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
+            TiktokenModel::P50kEdit => p50k_edit()
+                .map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
+            TiktokenModel::R50kBase => r50k_base()
+                .map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
+        };
+
+        // Extract special tokens (tiktoken-rs doesn't expose them directly)
+        // We'll use common ones for GPT models
+        let special_tokens = Self::get_special_tokens_for_model(model);
+
+        // Get vocabulary size (this is an approximation)
+        let vocab_size = match model {
+            TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
+            TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
+            TiktokenModel::R50kBase => 50257, // r50k has ~50k tokens
+        };
+
+        Ok(TiktokenTokenizer {
+            tokenizer,
+            model,
+            special_tokens,
+            vocab_size,
+        })
+    }
+
+    /// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
+    pub fn from_model_name(model_name: &str) -> Result<Self> {
+        let model = Self::model_from_name(model_name)?;
+        Self::new(model)
+    }
+
+    /// Determine the appropriate model from a model name
+    fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
+        // Based on OpenAI's model-to-encoding mapping
+        if model_name.contains("gpt-4")
+            || model_name.contains("gpt-3.5")
+            || model_name.contains("turbo")
+        {
+            Ok(TiktokenModel::Cl100kBase)
+        } else if model_name.contains("davinci-002")
+            || model_name.contains("davinci-003")
+            || model_name.contains("codex")
+        {
+            Ok(TiktokenModel::P50kBase)
+        } else if model_name.contains("edit") {
+            Ok(TiktokenModel::P50kEdit)
+        } else if model_name.contains("davinci")
+            || model_name.contains("curie")
+            || model_name.contains("babbage")
+            || model_name.contains("ada")
+        {
+            Ok(TiktokenModel::R50kBase)
+        } else {
+            // Return an error for unrecognized model names to prevent silent failures
+            Err(anyhow::anyhow!(
+                "Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
+                model_name
+            ))
+        }
+    }
+
+    /// Get special tokens for a specific model
+    fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
+        // These are common special tokens for GPT models
+        // The actual token IDs might vary by model
+        match model {
+            TiktokenModel::Cl100kBase => SpecialTokens {
+                bos_token: Some("<|endoftext|>".to_string()),
+                eos_token: Some("<|endoftext|>".to_string()),
+                unk_token: None,
+                sep_token: None,
+                pad_token: Some("<|endoftext|>".to_string()),
+                cls_token: None,
+                mask_token: None,
+                additional_special_tokens: vec![
+                    "<|fim_prefix|>".to_string(),
+                    "<|fim_middle|>".to_string(),
+                    "<|fim_suffix|>".to_string(),
+                    "<|endofprompt|>".to_string(),
+                ],
+            },
+            _ => SpecialTokens {
+                bos_token: Some("<|endoftext|>".to_string()),
+                eos_token: Some("<|endoftext|>".to_string()),
+                unk_token: None,
+                sep_token: None,
+                pad_token: Some("<|endoftext|>".to_string()),
+                cls_token: None,
+                mask_token: None,
+                additional_special_tokens: vec![],
+            },
+        }
+    }
+}
+
+impl Encoder for TiktokenTokenizer {
+    fn encode(&self, input: &str) -> Result<Encoding> {
+        let tokens = self.tokenizer.encode_ordinary(input);
+        Ok(Encoding::Tiktoken(tokens))
+    }
+
+    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
+        inputs.iter().map(|input| self.encode(input)).collect()
+    }
+}
+
+impl Decoder for TiktokenTokenizer {
+    fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
+        // Convert u32 to usize for tiktoken-rs
+        let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
+
+        self.tokenizer
+            .decode(tokens)
+            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
+    }
+}
+
+impl TokenizerTrait for TiktokenTokenizer {
+    fn vocab_size(&self) -> usize {
+        self.vocab_size
+    }
+
+    fn get_special_tokens(&self) -> &SpecialTokens {
+        &self.special_tokens
+    }
+
+    fn token_to_id(&self, _token: &str) -> Option<u32> {
+        // Tiktoken doesn't provide direct token-to-id mapping
+        // We'd need to encode the token and check if it produces a single ID
+        None
+    }
+
+    fn id_to_token(&self, _id: u32) -> Option<String> {
+        // Tiktoken doesn't provide direct id-to-token mapping
+        // We can only decode IDs to text
+        None
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tiktoken_creation() {
+        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
+        assert_eq!(tokenizer.vocab_size(), 100256);
+    }
+
+    #[test]
+    fn test_model_from_name() {
+        assert!(matches!(
+            TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
+            TiktokenModel::Cl100kBase
+        ));
+        assert!(matches!(
+            TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
+            TiktokenModel::Cl100kBase
+        ));
+        assert!(matches!(
+            TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
+            TiktokenModel::P50kBase
+        ));
+        assert!(matches!(
+            TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
+            TiktokenModel::P50kEdit
+        ));
+        assert!(matches!(
+            TiktokenTokenizer::model_from_name("davinci").unwrap(),
+            TiktokenModel::R50kBase
+        ));
+    }
+
+    #[test]
+    fn test_encode_decode() {
+        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
+
+        let text = "Hello, world!";
+        let encoding = tokenizer.encode(text).unwrap();
+
+        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+        assert_eq!(decoded, text);
+    }
+
+    #[test]
+    fn test_batch_encode() {
+        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
+
+        let texts = vec!["Hello", "World", "Test"];
+        let encodings = tokenizer.encode_batch(&texts).unwrap();
+
+        assert_eq!(encodings.len(), 3);
+        for (i, encoding) in encodings.iter().enumerate() {
+            let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+            assert_eq!(decoded, texts[i]);
+        }
+    }
+
+    #[test]
+    fn test_special_tokens() {
+        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
+        let special_tokens = tokenizer.get_special_tokens();
+
+        assert!(special_tokens.eos_token.is_some());
+        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
+    }
+
+    #[test]
+    fn test_unrecognized_model_name_returns_error() {
+        // Test that unrecognized model names return an error
+        let result = TiktokenTokenizer::from_model_name("distilgpt-2");
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
+        }
+
+        let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
+        }
+
+        let result = TiktokenTokenizer::from_model_name("llama-7b");
+        assert!(result.is_err());
+        if let Err(e) = result {
+            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
+        }
+    }
+
+    #[test]
+    fn test_recognized_model_names() {
+        // Test that recognized model names work correctly
+        assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
+        assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
+        assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
+        assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
+        assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
+        assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
+        assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
+    }
+}
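Note: because `from_model_name` fails loudly on unknown names, callers can use it as a cheap guard before dispatching a request to an OpenAI-style backend. A hedged sketch (the 8192-token limit is illustrative, not something this diff defines):

```rust
use sglang_router_rs::tokenizer::{Encoder, TiktokenTokenizer};

fn fits_context(model: &str, prompt: &str) -> anyhow::Result<bool> {
    let tok = TiktokenTokenizer::from_model_name(model)?; // Err for non-OpenAI names
    Ok(tok.encode(prompt)?.token_ids().len() <= 8192)
}
```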
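Note: `token_to_id` returns `None` because tiktoken exposes no direct vocab lookup, but the encode-and-check fallback mentioned in its comment is straightforward to sketch (not part of this diff; a string that encodes to exactly one id effectively has that id):

```rust
use sglang_router_rs::tokenizer::{Encoder, TiktokenTokenizer};

fn token_to_id_fallback(tok: &TiktokenTokenizer, token: &str) -> Option<u32> {
    let ids = tok.encode(token).ok()?.token_ids();
    if ids.len() == 1 {
        Some(ids[0])
    } else {
        None
    }
}
```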
diff --git a/sgl-router/src/tokenizer/traits.rs b/sgl-router/src/tokenizer/traits.rs
index 54e683497..e0153704a 100644
--- a/sgl-router/src/tokenizer/traits.rs
+++ b/sgl-router/src/tokenizer/traits.rs
@@ -26,13 +26,28 @@ pub enum Encoding {
     Hf(Box<tokenizers::Encoding>),
     /// Sentence Piece
     Sp(Vec<u32>),
+    /// Tiktoken (for GPT models)
+    Tiktoken(Vec<usize>),
 }

 impl Encoding {
-    pub fn token_ids(&self) -> &[u32] {
+    pub fn token_ids(&self) -> Vec<u32> {
+        match self {
+            Encoding::Hf(inner) => inner.get_ids().to_vec(),
+            Encoding::Sp(inner) => inner.clone(),
+            Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
+        }
+    }
+
+    pub fn token_ids_ref(&self) -> &[u32] {
         match self {
             Encoding::Hf(inner) => inner.get_ids(),
             Encoding::Sp(inner) => inner,
+            Encoding::Tiktoken(_) => {
+                // Tiktoken uses usize, so we can't return a reference to u32.
+                // This is a limitation - callers should use token_ids() for Tiktoken
+                &[]
+            }
         }
     }
 }
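Note: returning an owned `Vec<u32>` from `token_ids()` costs an allocation per call even for the borrowable `Hf` and `Sp` variants. One alternative, should that ever show up in profiles, is a copy-on-write view that borrows where possible and allocates only for `Tiktoken`; a sketch, not part of this diff:

```rust
use std::borrow::Cow;
use sglang_router_rs::tokenizer::Encoding;

fn token_ids_cow(enc: &Encoding) -> Cow<'_, [u32]> {
    match enc {
        Encoding::Hf(inner) => Cow::Borrowed(inner.get_ids()),
        Encoding::Sp(inner) => Cow::Borrowed(inner.as_slice()),
        // Only the usize-backed variant pays for a conversion.
        Encoding::Tiktoken(inner) => Cow::Owned(inner.iter().map(|&id| id as u32).collect()),
    }
}
```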