[router] introducing tokenizer trait (#9287)
@@ -43,6 +43,8 @@ uuid = { version = "1.10", features = ["v4", "serde"] }
thiserror = "2.0.12"
url = "2.5.4"
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
tokenizers = "0.21.4"

[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
@@ -10,6 +10,7 @@ pub mod policies;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tokenizer;
pub mod tree;
use crate::metrics::PrometheusConfig;
112 sgl-router/src/tokenizer/mock.rs Normal file
@@ -0,0 +1,112 @@
//! Mock tokenizer implementation for testing

use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::collections::HashMap;

/// Mock tokenizer for testing purposes
pub struct MockTokenizer {
    vocab: HashMap<String, u32>,
    reverse_vocab: HashMap<u32, String>,
    special_tokens: SpecialTokens,
}

impl Default for MockTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl MockTokenizer {
    pub fn new() -> Self {
        let mut vocab = HashMap::new();
        let mut reverse_vocab = HashMap::new();

        // Add some basic tokens
        let tokens = vec![
            ("Hello", 1),
            ("world", 2),
            ("test", 3),
            ("token", 4),
            (" ", 5),
            (".", 6),
            ("<eos>", 999),
            ("<bos>", 1000),
        ];

        for (token, id) in tokens {
            vocab.insert(token.to_string(), id);
            reverse_vocab.insert(id, token.to_string());
        }

        let special_tokens = SpecialTokens {
            bos_token: Some("<bos>".to_string()),
            eos_token: Some("<eos>".to_string()),
            unk_token: Some("<unk>".to_string()),
            sep_token: None,
            pad_token: None,
            cls_token: None,
            mask_token: None,
            additional_special_tokens: vec![],
        };

        Self {
            vocab,
            reverse_vocab,
            special_tokens,
        }
    }
}

impl Encoder for MockTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        // Simple word-based tokenization for testing
        let tokens: Vec<u32> = input
            .split_whitespace()
            .filter_map(|word| self.vocab.get(word).copied())
            .collect();

        Ok(Encoding::Sp(tokens))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.iter().map(|input| self.encode(input)).collect()
    }
}

impl Decoder for MockTokenizer {
    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        let tokens: Vec<String> = token_ids
            .iter()
            .filter_map(|id| {
                self.reverse_vocab.get(id).and_then(|token| {
                    if skip_special_tokens && (token == "<eos>" || token == "<bos>") {
                        None
                    } else {
                        Some(token.clone())
                    }
                })
            })
            .collect();

        Ok(tokens.join(" "))
    }
}

impl TokenizerTrait for MockTokenizer {
    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }
}
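For orientation (not part of the commit): a minimal sketch of driving MockTokenizer through the Encoder and Decoder traits it implements, assuming the crate is importable as sgl_router:

use sgl_router::tokenizer::mock::MockTokenizer;
use sgl_router::tokenizer::{Decoder, Encoder};

fn mock_round_trip() -> anyhow::Result<()> {
    let tok = MockTokenizer::new();
    // Per the vocabulary above: "Hello" -> 1, "world" -> 2
    let ids = tok.encode("Hello world")?.token_ids().to_vec();
    assert_eq!(ids, vec![1, 2]);
    // decode joins the recovered tokens with spaces
    assert_eq!(tok.decode(&ids, false)?, "Hello world");
    Ok(())
}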
89 sgl-router/src/tokenizer/mod.rs Normal file
@@ -0,0 +1,89 @@
use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;

pub mod mock;
pub mod stream;
pub mod traits;

#[cfg(test)]
mod tests;

pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};

/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);

impl Tokenizer {
    /// Create a tokenizer from a file path
    /// Will be implemented in Phase 3 with factory pattern
    pub fn from_file(_file_path: &str) -> Result<Tokenizer> {
        // TODO: Implement factory pattern in Phase 3
        unimplemented!("Factory pattern will be implemented in Phase 3")
    }

    /// Create a tokenizer from an Arc<dyn Tokenizer>
    pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }

    /// Create a stateful sequence object for decoding token_ids into text
    pub fn decode_stream(
        &self,
        prompt_token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> DecodeStream {
        DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
    }

    /// Direct encode method
    pub fn encode(&self, input: &str) -> Result<Encoding> {
        self.0.encode(input)
    }

    /// Direct batch encode method
    pub fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        self.0.encode_batch(inputs)
    }

    /// Direct decode method
    pub fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.0.decode(token_ids, skip_special_tokens)
    }

    /// Get vocabulary size
    pub fn vocab_size(&self) -> usize {
        self.0.vocab_size()
    }

    /// Get special tokens
    pub fn get_special_tokens(&self) -> &SpecialTokens {
        self.0.get_special_tokens()
    }

    /// Convert token string to ID
    pub fn token_to_id(&self, token: &str) -> Option<u32> {
        self.0.token_to_id(token)
    }

    /// Convert ID to token string
    pub fn id_to_token(&self, id: u32) -> Option<String> {
        self.0.id_to_token(id)
    }
}

impl Deref for Tokenizer {
    type Target = Arc<dyn traits::Tokenizer>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
    fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        Tokenizer(tokenizer)
    }
}
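A sketch of how a caller holds the wrapper (again assuming the crate name sgl_router); from_file is still unimplemented at this phase, so the mock stands in for a real backend:

use std::sync::Arc;
use sgl_router::tokenizer::{mock::MockTokenizer, Tokenizer};

fn wrapper_demo() -> anyhow::Result<()> {
    // Arc<MockTokenizer> coerces to Arc<dyn traits::Tokenizer> at the call site
    let tokenizer = Tokenizer::from_arc(Arc::new(MockTokenizer::new()));

    let ids = tokenizer.encode("Hello world")?.token_ids().to_vec();
    assert_eq!(tokenizer.decode(&ids, false)?, "Hello world");

    // Clone is cheap: clones share the same backend through the Arc
    let shared = tokenizer.clone();
    assert_eq!(shared.vocab_size(), tokenizer.vocab_size());
    Ok(())
}

The newtype-around-Arc design gives cheap Clone and a home for convenience methods, while the From impl lets an existing trait object convert with .into().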
105 sgl-router/src/tokenizer/stream.rs Normal file
@@ -0,0 +1,105 @@
// src/tokenizer/stream.rs

use super::traits;
use anyhow::Result;
use std::sync::Arc;

const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;

/// DecodeStream will keep the state necessary to produce individual chunks of
/// strings given an input stream of token_ids
pub struct DecodeStream {
    /// The tokenizer used to decode token_ids
    tokenizer: Arc<dyn traits::Tokenizer>,

    skip_special_tokens: bool,

    /// A temporary buffer of the necessary token_ids needed
    /// to produce valid string chunks
    all_token_ids: Vec<u32>,

    prefix_offset: usize,
    read_offset: usize,
}

impl DecodeStream {
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        prompt_token_ids: &[u32],
        skip_special_tokens: bool,
    ) -> Self {
        let num_input_tokens = prompt_token_ids.len();
        let prompt_token_ids = prompt_token_ids.to_vec();
        Self {
            tokenizer,
            skip_special_tokens,
            all_token_ids: prompt_token_ids,
            prefix_offset: num_input_tokens
                .saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
            read_offset: num_input_tokens,
        }
    }

    /// Step appends a token_id to the internal state and tries to produce a text chunk.
    /// Returning `None` means the given id is not enough to produce a chunk.
    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
        self.all_token_ids.push(id);

        let prefix_text = self.tokenizer.decode(
            &self.all_token_ids[self.prefix_offset..self.read_offset],
            self.skip_special_tokens,
        )?;

        let new_text = self.tokenizer.decode(
            &self.all_token_ids[self.prefix_offset..],
            self.skip_special_tokens,
        )?;

        // A trailing U+FFFD replacement character means the newest token decodes
        // to an incomplete UTF-8 sequence; hold the chunk back until it resolves.
        if new_text.len() > prefix_text.len() && !new_text.ends_with('\u{FFFD}') {
            let new_text = new_text[prefix_text.len()..].to_string();

            self.prefix_offset = self.read_offset;
            self.read_offset = self.all_token_ids.len();

            Ok(Some(new_text))
        } else {
            Ok(None)
        }
    }

    /// Process multiple tokens at once
    pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
        let mut chunks = Vec::new();

        for &token_id in token_ids {
            if let Some(text) = self.step(token_id)? {
                chunks.push(text);
            }
        }

        Ok(chunks)
    }

    /// Force flush any remaining text
    pub fn flush(&mut self) -> Result<Option<String>> {
        if self.read_offset < self.all_token_ids.len() {
            let remaining = self.tokenizer.decode(
                &self.all_token_ids[self.read_offset..],
                self.skip_special_tokens,
            )?;

            self.read_offset = self.all_token_ids.len();

            if !remaining.is_empty() {
                return Ok(Some(remaining));
            }
        }

        Ok(None)
    }

    /// Get all tokens processed so far
    pub fn tokens(&self) -> &[u32] {
        &self.all_token_ids
    }
}
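The two offsets implement standard incremental detokenization (the same scheme as the Hugging Face tokenizers crate's DecodeStream): prefix_offset..read_offset is the already-emitted context, re-decoded each step to establish a stable prefix, and a chunk is released only once the full decode grows past that prefix without ending in U+FFFD. A consumer sketch, assuming the crate name sgl_router:

use std::sync::Arc;
use sgl_router::tokenizer::{mock::MockTokenizer, Tokenizer};

fn streaming_demo() -> anyhow::Result<()> {
    let tokenizer = Tokenizer::from_arc(Arc::new(MockTokenizer::new()));

    // Seed with prompt tokens, then feed generated ids one at a time
    let mut stream = tokenizer.decode_stream(&[1, 2], false);
    for id in [3u32, 4] {
        if let Some(chunk) = stream.step(id)? {
            print!("{chunk}"); // emitted as soon as it is unambiguous
        }
    }
    // Force out anything still buffered at the end of generation
    if let Some(rest) = stream.flush()? {
        print!("{rest}");
    }
    Ok(())
}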
143 sgl-router/src/tokenizer/tests.rs Normal file
@@ -0,0 +1,143 @@
#[cfg(test)]
use super::*;
#[cfg(test)]
use std::sync::Arc;

#[test]
fn test_mock_tokenizer_encode() {
    let tokenizer = mock::MockTokenizer::new();
    let encoding = tokenizer.encode("Hello world").unwrap();
    let token_ids = encoding.token_ids();
    assert_eq!(token_ids, &[1, 2]); // "Hello" -> 1, "world" -> 2
}

#[test]
fn test_mock_tokenizer_decode() {
    let tokenizer = mock::MockTokenizer::new();
    let text = tokenizer.decode(&[1, 2], false).unwrap();
    assert_eq!(text, "Hello world");
}

#[test]
fn test_mock_tokenizer_decode_skip_special() {
    let tokenizer = mock::MockTokenizer::new();

    // With special tokens
    let text = tokenizer.decode(&[1000, 1, 2, 999], false).unwrap();
    assert_eq!(text, "<bos> Hello world <eos>");

    // Without special tokens
    let text = tokenizer.decode(&[1000, 1, 2, 999], true).unwrap();
    assert_eq!(text, "Hello world");
}

#[test]
fn test_tokenizer_wrapper() {
    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
    let tokenizer = Tokenizer::from_arc(mock_tokenizer);

    // Test encoding
    let encoding = tokenizer.encode("Hello world").unwrap();
    assert_eq!(encoding.token_ids(), &[1, 2]);

    // Test decoding
    let text = tokenizer.decode(&[1, 2], false).unwrap();
    assert_eq!(text, "Hello world");

    // Test vocab size
    assert_eq!(tokenizer.vocab_size(), 8);

    // Test token to ID
    assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
    assert_eq!(tokenizer.token_to_id("unknown"), None);

    // Test ID to token
    assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
    assert_eq!(tokenizer.id_to_token(9999), None);
}

#[test]
fn test_decode_stream_basic() {
    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
    let tokenizer = Tokenizer::from_arc(mock_tokenizer);

    // Create a decode stream with initial tokens
    let initial_tokens = vec![1, 2]; // "Hello world"
    let mut stream = tokenizer.decode_stream(&initial_tokens, false);

    // Add a new token
    let result = stream.step(3).unwrap(); // "test"
    // Since we're using a mock, the actual incremental behavior depends on implementation
    // For now, we just verify it doesn't crash
    assert!(result.is_some() || result.is_none());
}

#[test]
fn test_decode_stream_flush() {
    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
    let tokenizer = Tokenizer::from_arc(mock_tokenizer);

    let initial_tokens = vec![1];
    let mut stream = tokenizer.decode_stream(&initial_tokens, false);

    // Add tokens
    stream.step(2).unwrap();
    stream.step(3).unwrap();

    // Flush remaining
    let flushed = stream.flush().unwrap();
    // The flush behavior depends on the implementation
    assert!(flushed.is_some() || flushed.is_none());
}

#[test]
fn test_special_tokens() {
    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
    let tokenizer = Tokenizer::from_arc(mock_tokenizer);

    let special_tokens = tokenizer.get_special_tokens();
    assert_eq!(special_tokens.bos_token, Some("<bos>".to_string()));
    assert_eq!(special_tokens.eos_token, Some("<eos>".to_string()));
    assert_eq!(special_tokens.unk_token, Some("<unk>".to_string()));
    assert!(special_tokens.sep_token.is_none());
    assert!(special_tokens.pad_token.is_none());
}

#[test]
fn test_batch_encode() {
    let tokenizer = mock::MockTokenizer::new();
    let inputs = vec!["Hello", "world", "test"];
    let encodings = tokenizer.encode_batch(&inputs).unwrap();

    assert_eq!(encodings.len(), 3);
    assert_eq!(encodings[0].token_ids(), &[1]); // "Hello" -> 1
    assert_eq!(encodings[1].token_ids(), &[2]); // "world" -> 2
    assert_eq!(encodings[2].token_ids(), &[3]); // "test" -> 3
}

#[test]
fn test_thread_safety() {
    use std::thread;

    let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
    let tokenizer = Tokenizer::from_arc(mock_tokenizer);

    // Spawn multiple threads that use the same tokenizer
    let handles: Vec<_> = (0..10)
        .map(|i| {
            let tokenizer_clone = tokenizer.clone();
            thread::spawn(move || {
                let text = "Hello test".to_string();
                let encoding = tokenizer_clone.encode(&text).unwrap();
                let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
                assert!(decoded.contains("Hello") || decoded.contains("test"));
                i
            })
        })
        .collect();

    // Wait for all threads to complete
    for handle in handles {
        handle.join().unwrap();
    }
}
50 sgl-router/src/tokenizer/traits.rs Normal file
@@ -0,0 +1,50 @@
use anyhow::Result;

/// Core encoding trait - separate from decoding for modularity
pub trait Encoder: Send + Sync {
    fn encode(&self, input: &str) -> Result<Encoding>;
    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
}

/// Core decoding trait - can be implemented independently
pub trait Decoder: Send + Sync {
    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
}

/// Combined tokenizer trait
pub trait Tokenizer: Encoder + Decoder {
    fn vocab_size(&self) -> usize;
    fn get_special_tokens(&self) -> &SpecialTokens;
    fn token_to_id(&self, token: &str) -> Option<u32>;
    fn id_to_token(&self, id: u32) -> Option<String>;
}

/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
#[derive(Debug, Clone)]
pub enum Encoding {
    /// Hugging Face
    Hf(Box<tokenizers::tokenizer::Encoding>),
    /// Sentence Piece
    Sp(Vec<u32>),
}

impl Encoding {
    pub fn token_ids(&self) -> &[u32] {
        match self {
            Encoding::Hf(inner) => inner.get_ids(),
            Encoding::Sp(inner) => inner,
        }
    }
}

#[derive(Debug, Clone)]
pub struct SpecialTokens {
    pub bos_token: Option<String>,
    pub eos_token: Option<String>,
    pub unk_token: Option<String>,
    pub sep_token: Option<String>,
    pub pad_token: Option<String>,
    pub cls_token: Option<String>,
    pub mask_token: Option<String>,
    pub additional_special_tokens: Vec<String>,
}
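Because Encoder and Decoder are independent, object-safe traits, a new backend only needs the three impls below to plug into the wrapper. A hypothetical byte-level backend, for illustration only (ByteTokenizer is not part of the commit; crate name sgl_router assumed):

use anyhow::Result;
use sgl_router::tokenizer::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer};

/// Hypothetical backend: every byte is its own token id (vocab of 256).
struct ByteTokenizer {
    special_tokens: SpecialTokens,
}

impl Encoder for ByteTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        // Reuse the Sp variant, which is just a Vec<u32> of ids
        Ok(Encoding::Sp(input.bytes().map(u32::from).collect()))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.iter().map(|s| self.encode(s)).collect()
    }
}

impl Decoder for ByteTokenizer {
    fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
        // Ids above 255 are truncated; fine for a sketch, not for production
        let bytes: Vec<u8> = token_ids.iter().map(|&id| id as u8).collect();
        Ok(String::from_utf8_lossy(&bytes).into_owned())
    }
}

impl Tokenizer for ByteTokenizer {
    fn vocab_size(&self) -> usize {
        256
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        match token.as_bytes() {
            [b] => Some(u32::from(*b)), // single-byte tokens only
            _ => None,
        }
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        u8::try_from(id).ok().map(|b| (b as char).to_string())
    }
}

An Arc::new(ByteTokenizer { .. }) could then go straight into Tokenizer::from_arc or DecodeStream::new, exactly as the mock does in the tests above.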