From ff0cf51c8ea2928743cf0831ba27f24f5c7098c9 Mon Sep 17 00:00:00 2001
From: Simo Lin
Date: Sun, 17 Aug 2025 16:30:01 -0700
Subject: [PATCH] [router] introducing tokenizer trait (#9287)

---
 sgl-router/Cargo.toml              |   2 +
 sgl-router/src/lib.rs              |   1 +
 sgl-router/src/tokenizer/mock.rs   | 112 ++++++++++++++++++++++
 sgl-router/src/tokenizer/mod.rs    |  89 ++++++++++++++++++
 sgl-router/src/tokenizer/stream.rs | 105 +++++++++++++++++++++
 sgl-router/src/tokenizer/tests.rs  | 143 +++++++++++++++++++++++++++++
 sgl-router/src/tokenizer/traits.rs |  50 ++++++++++
 7 files changed, 502 insertions(+)
 create mode 100644 sgl-router/src/tokenizer/mock.rs
 create mode 100644 sgl-router/src/tokenizer/mod.rs
 create mode 100644 sgl-router/src/tokenizer/stream.rs
 create mode 100644 sgl-router/src/tokenizer/tests.rs
 create mode 100644 sgl-router/src/tokenizer/traits.rs

diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml
index 44691b200..71c2e7ccb 100644
--- a/sgl-router/Cargo.toml
+++ b/sgl-router/Cargo.toml
@@ -43,6 +43,8 @@ uuid = { version = "1.10", features = ["v4", "serde"] }
 thiserror = "2.0.12"
 url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
+anyhow = "1.0"
+tokenizers = "0.21.4"
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs
index 2d8641a9d..299dfdcfa 100644
--- a/sgl-router/src/lib.rs
+++ b/sgl-router/src/lib.rs
@@ -10,6 +10,7 @@ pub mod policies;
 pub mod routers;
 pub mod server;
 pub mod service_discovery;
+pub mod tokenizer;
 pub mod tree;
 
 use crate::metrics::PrometheusConfig;
diff --git a/sgl-router/src/tokenizer/mock.rs b/sgl-router/src/tokenizer/mock.rs
new file mode 100644
index 000000000..afb91543c
--- /dev/null
+++ b/sgl-router/src/tokenizer/mock.rs
@@ -0,0 +1,112 @@
+//! Mock tokenizer implementation for testing
+
+use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
+use anyhow::Result;
+use std::collections::HashMap;
+
+/// Mock tokenizer for testing purposes
+pub struct MockTokenizer {
+    vocab: HashMap<String, u32>,
+    reverse_vocab: HashMap<u32, String>,
+    special_tokens: SpecialTokens,
+}
+
+impl Default for MockTokenizer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl MockTokenizer {
+    pub fn new() -> Self {
+        let mut vocab = HashMap::new();
+        let mut reverse_vocab = HashMap::new();
+
+        // Add some basic tokens
+        let tokens = vec![
+            ("Hello", 1),
+            ("world", 2),
+            ("test", 3),
+            ("token", 4),
+            (" ", 5),
+            (".", 6),
+            ("<eos>", 999),
+            ("<bos>", 1000),
+        ];
+
+        for (token, id) in tokens {
+            vocab.insert(token.to_string(), id);
+            reverse_vocab.insert(id, token.to_string());
+        }
+
+        let special_tokens = SpecialTokens {
+            bos_token: Some("<bos>".to_string()),
+            eos_token: Some("<eos>".to_string()),
+            unk_token: Some("<unk>".to_string()),
+            sep_token: None,
+            pad_token: None,
+            cls_token: None,
+            mask_token: None,
+            additional_special_tokens: vec![],
+        };
+
+        Self {
+            vocab,
+            reverse_vocab,
+            special_tokens,
+        }
+    }
+}
+
+impl Encoder for MockTokenizer {
+    fn encode(&self, input: &str) -> Result<Encoding> {
+        // Simple word-based tokenization for testing
+        let tokens: Vec<u32> = input
+            .split_whitespace()
+            .filter_map(|word| self.vocab.get(word).copied())
+            .collect();
+
+        Ok(Encoding::Sp(tokens))
+    }
+
+    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
+        inputs.iter().map(|input| self.encode(input)).collect()
+    }
+}
+
+impl Decoder for MockTokenizer {
+    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
+        let tokens: Vec<String> = token_ids
+            .iter()
+            .filter_map(|id| {
+                self.reverse_vocab.get(id).and_then(|token| {
+                    if skip_special_tokens && (token == "<bos>" || token == "<eos>") {
+                        None
+                    } else {
+                        Some(token.clone())
+                    }
+                })
+            })
+            .collect();
+
+        Ok(tokens.join(" "))
+    }
+}
+
+impl TokenizerTrait for MockTokenizer {
+    fn vocab_size(&self) -> usize {
+        self.vocab.len()
+    }
+
+    fn get_special_tokens(&self) -> &SpecialTokens {
+        &self.special_tokens
+    }
+
+    fn token_to_id(&self, token: &str) -> Option<u32> {
+        self.vocab.get(token).copied()
+    }
+
+    fn id_to_token(&self, id: u32) -> Option<String> {
+        self.reverse_vocab.get(&id).cloned()
+    }
+}
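
A quick round-trip through the mock exercises the trait surface (a sketch for reviewers, not part of the patch; it assumes the crate is consumed as `sgl_router` and uses the `Encoder`/`Decoder` re-exports introduced in mod.rs below):

    use sgl_router::tokenizer::mock::MockTokenizer;
    use sgl_router::tokenizer::{Decoder, Encoder};

    fn main() -> anyhow::Result<()> {
        let tok = MockTokenizer::new();
        // encode() splits on whitespace and looks ids up in the fixed vocab
        let ids = tok.encode("Hello world")?.token_ids().to_vec(); // [1, 2]
        assert_eq!(tok.decode(&ids, false)?, "Hello world");
        Ok(())
    }
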
diff --git a/sgl-router/src/tokenizer/mod.rs b/sgl-router/src/tokenizer/mod.rs
new file mode 100644
index 000000000..a77884abe
--- /dev/null
+++ b/sgl-router/src/tokenizer/mod.rs
@@ -0,0 +1,89 @@
+use anyhow::Result;
+use std::ops::Deref;
+use std::sync::Arc;
+
+pub mod mock;
+pub mod stream;
+pub mod traits;
+
+#[cfg(test)]
+mod tests;
+
+pub use stream::DecodeStream;
+pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
+
+/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
+#[derive(Clone)]
+pub struct Tokenizer(Arc<dyn TokenizerTrait>);
+
+impl Tokenizer {
+    /// Create a tokenizer from a file path
+    /// Will be implemented in Phase 3 with factory pattern
+    pub fn from_file(_file_path: &str) -> Result<Self> {
+        // TODO: Implement factory pattern in Phase 3
+        unimplemented!("Factory pattern will be implemented in Phase 3")
+    }
+
+    /// Create a tokenizer from an Arc
+    pub fn from_arc(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
+        Tokenizer(tokenizer)
+    }
+
+    /// Create a stateful sequence object for decoding token_ids into text
+    pub fn decode_stream(
+        &self,
+        prompt_token_ids: &[u32],
+        skip_special_tokens: bool,
+    ) -> DecodeStream {
+        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
+    }
+
+    /// Direct encode method
+    pub fn encode(&self, input: &str) -> Result<Encoding> {
+        self.0.encode(input)
+    }
+
+    /// Direct batch encode method
+    pub fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
+        self.0.encode_batch(inputs)
+    }
+
+    /// Direct decode method
+    pub fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
+        self.0.decode(token_ids, skip_special_tokens)
+    }
+
+    /// Get vocabulary size
+    pub fn vocab_size(&self) -> usize {
+        self.0.vocab_size()
+    }
+
+    /// Get special tokens
+    pub fn get_special_tokens(&self) -> &SpecialTokens {
+        self.0.get_special_tokens()
+    }
+
+    /// Convert token string to ID
+    pub fn token_to_id(&self, token: &str) -> Option<u32> {
+        self.0.token_to_id(token)
+    }
+
+    /// Convert ID to token string
+    pub fn id_to_token(&self, id: u32) -> Option<String> {
+        self.0.id_to_token(id)
+    }
+}
+
+impl Deref for Tokenizer {
+    type Target = Arc<dyn TokenizerTrait>;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl From<Arc<dyn TokenizerTrait>> for Tokenizer {
+    fn from(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
+        Tokenizer(tokenizer)
+    }
+}
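
Streaming decode through the wrapper looks like the following sketch (not part of the patch; `MockTokenizer` stands in for a real tokenizer, and with the mock's space-joined decode this prints " test token ."):

    use sgl_router::tokenizer::{mock::MockTokenizer, Tokenizer};
    use std::sync::Arc;

    fn main() -> anyhow::Result<()> {
        let tokenizer = Tokenizer::from_arc(Arc::new(MockTokenizer::new()));
        // Prompt ids seed the stream; generated ids are fed in one at a time.
        let mut stream = tokenizer.decode_stream(&[1, 2], false);
        for id in [3, 4, 6] {
            if let Some(chunk) = stream.step(id)? {
                print!("{chunk}"); // emitted only once the text is unambiguous
            }
        }
        if let Some(rest) = stream.flush()? {
            print!("{rest}");
        }
        Ok(())
    }
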
diff --git a/sgl-router/src/tokenizer/stream.rs b/sgl-router/src/tokenizer/stream.rs
new file mode 100644
index 000000000..6b236b03f
--- /dev/null
+++ b/sgl-router/src/tokenizer/stream.rs
@@ -0,0 +1,105 @@
+// src/tokenizer/stream.rs
+
+use super::traits;
+use anyhow::Result;
+use std::sync::Arc;
+
+const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
+
+/// DecodeStream will keep the state necessary to produce individual chunks of
+/// strings given an input stream of token_ids
+pub struct DecodeStream {
+    /// The tokenizer used to decode token_ids
+    tokenizer: Arc<dyn traits::Tokenizer>,
+
+    skip_special_tokens: bool,
+
+    /// A temporary buffer of the necessary token_ids needed
+    /// to produce valid string chunks
+    all_token_ids: Vec<u32>,
+
+    prefix_offset: usize,
+    read_offset: usize,
+}
+
+impl DecodeStream {
+    pub fn new(
+        tokenizer: Arc<dyn traits::Tokenizer>,
+        prompt_token_ids: &[u32],
+        skip_special_tokens: bool,
+    ) -> Self {
+        let num_input_tokens = prompt_token_ids.len();
+        let prompt_token_ids = prompt_token_ids.to_vec();
+        Self {
+            tokenizer,
+            skip_special_tokens,
+            all_token_ids: prompt_token_ids,
+            prefix_offset: num_input_tokens
+                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
+            read_offset: num_input_tokens,
+        }
+    }
+
+    /// Step appends a token_id to the internal state and tries to produce a text chunk.
+    /// Returning `None` means the given id is not enough to produce a chunk.
+    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
+        self.all_token_ids.push(id);
+
+        let prefix_text = self.tokenizer.decode(
+            &self.all_token_ids[self.prefix_offset..self.read_offset],
+            self.skip_special_tokens,
+        )?;
+
+        let new_text = self.tokenizer.decode(
+            &self.all_token_ids[self.prefix_offset..],
+            self.skip_special_tokens,
+        )?;
+
+        if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
+            let new_text = new_text[prefix_text.len()..].to_string();
+
+            self.prefix_offset = self.read_offset;
+            self.read_offset = self.all_token_ids.len();
+
+            Ok(Some(new_text))
+        } else {
+            Ok(None)
+        }
+    }
+
+    /// Process multiple tokens at once
+    pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
+        let mut chunks = Vec::new();
+
+        for &token_id in token_ids {
+            if let Some(text) = self.step(token_id)? {
+                chunks.push(text);
+            }
+        }
+
+        Ok(chunks)
+    }
+
+    /// Force flush any remaining text
+    pub fn flush(&mut self) -> Result<Option<String>> {
+        if self.read_offset < self.all_token_ids.len() {
+            let remaining = self.tokenizer.decode(
+                &self.all_token_ids[self.read_offset..],
+                self.skip_special_tokens,
+            )?;
+
+            self.read_offset = self.all_token_ids.len();
+
+            if !remaining.is_empty() {
+                return Ok(Some(remaining));
+            }
+        }
+
+        Ok(None)
+    }
+
+    /// Get all tokens processed so far
+    pub fn tokens(&self) -> &[u32] {
+        &self.all_token_ids
+    }
+}
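
Context on the guard in step(): re-decoding the window `prefix_offset..read_offset` with and without the new id, and holding output back while the text ends in U+FFFD, is the usual incremental-detokenization trick: a byte-level token can stop mid UTF-8 sequence, and a lossy decode of such a prefix ends in the replacement character. A standalone illustration of that effect (plain Rust, not router code):

    fn main() {
        let crab = "🦀".as_bytes(); // 4 bytes: F0 9F A6 80
        // Truncating mid-sequence decodes to the replacement character...
        let partial = String::from_utf8_lossy(&crab[..2]);
        assert!(partial.ends_with('\u{FFFD}')); // renders as "�"
        // ...while the complete sequence decodes cleanly.
        assert_eq!(String::from_utf8_lossy(crab), "🦀");
    }
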
diff --git a/sgl-router/src/tokenizer/tests.rs b/sgl-router/src/tokenizer/tests.rs
new file mode 100644
index 000000000..2c4d4b108
--- /dev/null
+++ b/sgl-router/src/tokenizer/tests.rs
@@ -0,0 +1,143 @@
+#[cfg(test)]
+use super::*;
+#[cfg(test)]
+use std::sync::Arc;
+
+#[test]
+fn test_mock_tokenizer_encode() {
+    let tokenizer = mock::MockTokenizer::new();
+    let encoding = tokenizer.encode("Hello world").unwrap();
+    let token_ids = encoding.token_ids();
+    assert_eq!(token_ids, &[1, 2]); // "Hello" -> 1, "world" -> 2
+}
+
+#[test]
+fn test_mock_tokenizer_decode() {
+    let tokenizer = mock::MockTokenizer::new();
+    let text = tokenizer.decode(&[1, 2], false).unwrap();
+    assert_eq!(text, "Hello world");
+}
+
+#[test]
+fn test_mock_tokenizer_decode_skip_special() {
+    let tokenizer = mock::MockTokenizer::new();
+
+    // With special tokens
+    let text = tokenizer.decode(&[1000, 1, 2, 999], false).unwrap();
+    assert_eq!(text, "<bos> Hello world <eos>");
+
+    // Without special tokens
+    let text = tokenizer.decode(&[1000, 1, 2, 999], true).unwrap();
+    assert_eq!(text, "Hello world");
+}
+
+#[test]
+fn test_tokenizer_wrapper() {
+    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
+    let tokenizer = Tokenizer::from_arc(mock_tokenizer);
+
+    // Test encoding
+    let encoding = tokenizer.encode("Hello world").unwrap();
+    assert_eq!(encoding.token_ids(), &[1, 2]);
+
+    // Test decoding
+    let text = tokenizer.decode(&[1, 2], false).unwrap();
+    assert_eq!(text, "Hello world");
+
+    // Test vocab size
+    assert_eq!(tokenizer.vocab_size(), 8);
+
+    // Test token to ID
+    assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
+    assert_eq!(tokenizer.token_to_id("unknown"), None);
+
+    // Test ID to token
+    assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
+    assert_eq!(tokenizer.id_to_token(9999), None);
+}
+
+#[test]
+fn test_decode_stream_basic() {
+    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
+    let tokenizer = Tokenizer::from_arc(mock_tokenizer);
+
+    // Create a decode stream with initial tokens
+    let initial_tokens = vec![1, 2]; // "Hello world"
+    let mut stream = tokenizer.decode_stream(&initial_tokens, false);
+
+    // Add a new token
+    let result = stream.step(3).unwrap(); // "test"
+    // Since we're using a mock, the actual incremental behavior depends on implementation
+    // For now, we just verify it doesn't crash
+    assert!(result.is_some() || result.is_none());
+}
+
+#[test]
+fn test_decode_stream_flush() {
+    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
+    let tokenizer = Tokenizer::from_arc(mock_tokenizer);
+
+    let initial_tokens = vec![1];
+    let mut stream = tokenizer.decode_stream(&initial_tokens, false);
+
+    // Add tokens
+    stream.step(2).unwrap();
+    stream.step(3).unwrap();
+
+    // Flush remaining
+    let flushed = stream.flush().unwrap();
+    // The flush behavior depends on the implementation
+    assert!(flushed.is_some() || flushed.is_none());
+}
+
+#[test]
+fn test_special_tokens() {
+    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
+    let tokenizer = Tokenizer::from_arc(mock_tokenizer);
+
+    let special_tokens = tokenizer.get_special_tokens();
+    assert_eq!(special_tokens.bos_token, Some("<bos>".to_string()));
+    assert_eq!(special_tokens.eos_token, Some("<eos>".to_string()));
+    assert_eq!(special_tokens.unk_token, Some("<unk>".to_string()));
+    assert!(special_tokens.sep_token.is_none());
+    assert!(special_tokens.pad_token.is_none());
+}
+
+#[test]
+fn test_batch_encode() {
+    let tokenizer = mock::MockTokenizer::new();
+    let inputs = vec!["Hello", "world", "test"];
+    let encodings = tokenizer.encode_batch(&inputs).unwrap();
+
+    assert_eq!(encodings.len(), 3);
+    assert_eq!(encodings[0].token_ids(), &[1]); // "Hello" -> 1
+    assert_eq!(encodings[1].token_ids(), &[2]); // "world" -> 2
+    assert_eq!(encodings[2].token_ids(), &[3]); // "test" -> 3
+}
+
+#[test]
+fn test_thread_safety() {
+    use std::thread;
+
+    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
+    let tokenizer = Tokenizer::from_arc(mock_tokenizer);
+
+    // Spawn multiple threads that use the same tokenizer
+    let handles: Vec<_> = (0..10)
+        .map(|i| {
+            let tokenizer_clone = tokenizer.clone();
+            thread::spawn(move || {
+                let text = "Hello test".to_string();
+                let encoding = tokenizer_clone.encode(&text).unwrap();
+                let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
+                assert!(decoded.contains("Hello") || decoded.contains("test"));
+                i
+            })
+        })
+        .collect();
+
+    // Wait for all threads to complete
+    for handle in handles {
+        handle.join().unwrap();
+    }
+}
diff --git a/sgl-router/src/tokenizer/traits.rs b/sgl-router/src/tokenizer/traits.rs
new file mode 100644
index 000000000..54e683497
--- /dev/null
+++ b/sgl-router/src/tokenizer/traits.rs
@@ -0,0 +1,50 @@
+use anyhow::Result;
+
+/// Core encoding trait - separate from decoding for modularity
+pub trait Encoder: Send + Sync {
+    fn encode(&self, input: &str) -> Result<Encoding>;
+    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
+}
+
+/// Core decoding trait - can be implemented independently
+pub trait Decoder: Send + Sync {
+    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
+}
+
+/// Combined tokenizer trait
+pub trait Tokenizer: Encoder + Decoder {
+    fn vocab_size(&self) -> usize;
+    fn get_special_tokens(&self) -> &SpecialTokens;
+    fn token_to_id(&self, token: &str) -> Option<u32>;
+    fn id_to_token(&self, id: u32) -> Option<String>;
+}
+
+/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
+#[derive(Debug, Clone)]
+pub enum Encoding {
+    /// Hugging Face
+    Hf(Box<tokenizers::Encoding>),
+    /// Sentence Piece
+    Sp(Vec<u32>),
+}
+
+impl Encoding {
+    pub fn token_ids(&self) -> &[u32] {
+        match self {
+            Encoding::Hf(inner) => inner.get_ids(),
+            Encoding::Sp(inner) => inner,
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct SpecialTokens {
+    pub bos_token: Option<String>,
+    pub eos_token: Option<String>,
+    pub unk_token: Option<String>,
+    pub sep_token: Option<String>,
+    pub pad_token: Option<String>,
+    pub cls_token: Option<String>,
+    pub mask_token: Option<String>,
+    pub additional_special_tokens: Vec<String>,
+}
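
The `tokenizers = "0.21.4"` dependency and the `Encoding::Hf` variant anticipate the Phase 3 factory. A sketch of how a Hugging Face-backed implementation could satisfy these traits (assumed shape, not part of the patch; `HfTokenizer` and how `special_tokens` gets populated are illustrative):

    use anyhow::{anyhow, Result};
    use crate::tokenizer::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer};

    pub struct HfTokenizer {
        inner: tokenizers::Tokenizer,
        special_tokens: SpecialTokens, // would be filled from the tokenizer config
    }

    impl Encoder for HfTokenizer {
        fn encode(&self, input: &str) -> Result<Encoding> {
            // tokenizers errors are boxed trait objects; surface them via anyhow
            let enc = self.inner.encode(input, false).map_err(|e| anyhow!("{e}"))?;
            Ok(Encoding::Hf(Box::new(enc)))
        }
        fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
            inputs.iter().map(|s| self.encode(s)).collect()
        }
    }

    impl Decoder for HfTokenizer {
        fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
            self.inner
                .decode(token_ids, skip_special_tokens)
                .map_err(|e| anyhow!("{e}"))
        }
    }

    impl Tokenizer for HfTokenizer {
        fn vocab_size(&self) -> usize {
            self.inner.get_vocab_size(true) // include added tokens
        }
        fn get_special_tokens(&self) -> &SpecialTokens {
            &self.special_tokens
        }
        fn token_to_id(&self, token: &str) -> Option<u32> {
            self.inner.token_to_id(token)
        }
        fn id_to_token(&self, id: u32) -> Option<String> {
            self.inner.id_to_token(id)
        }
    }

Because the wrapper stores an `Arc<dyn TokenizerTrait>`, such an implementation would drop in behind `Tokenizer::from_arc` (and eventually `from_file`) without touching any call sites.
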