[router] tokenizer factory, hf tokenizer, and stop sequence detector (#9293)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
sgl-router/Cargo.toml
@@ -3,6 +3,10 @@ name = "sglang_router_rs"
 version = "0.0.0"
 edition = "2021"
 
+[features]
+default = ["huggingface"]
+huggingface = ["tokenizers"]
+
 [lib]
 name = "sglang_router_rs"
 # Pure Rust library: Just omit crate-type (defaults to rlib)
@@ -44,7 +48,7 @@ thiserror = "2.0.12"
 url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
-tokenizers = "0.21.4"
+tokenizers = { version = "0.21.4", optional = true }
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
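Making `tokenizers` an optional dependency behind the new (default-on) `huggingface` feature lets consumers drop the HuggingFace stack entirely, e.g. with `cargo build --no-default-features`; the feature-gated paths in the factory below then surface a clear runtime error for HuggingFace tokenizer files instead of pulling in the dependency.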
sgl-router/src/tokenizer/factory.rs (new file, 228 lines)
@@ -0,0 +1,228 @@
use super::traits;
use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;

#[cfg(feature = "huggingface")]
use super::huggingface::HuggingFaceTokenizer;

/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
    HuggingFace(String),
    Mock,
    // Future: SentencePiece, GGUF, Tiktoken
}

/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
/// - For testing: can return mock tokenizer
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    // Special case for testing
    if file_path == "mock" || file_path == "test" {
        return Ok(Arc::new(super::mock::MockTokenizer::new()));
    }

    let path = Path::new(file_path);

    // Check if file exists
    if !path.exists() {
        return Err(Error::msg(format!("File not found: {}", file_path)));
    }

    // Try to determine tokenizer type from extension
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(|s| s.to_lowercase());

    match extension.as_deref() {
        Some("json") => {
            #[cfg(feature = "huggingface")]
            {
                let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
                Ok(Arc::new(tokenizer))
            }
            #[cfg(not(feature = "huggingface"))]
            {
                Err(Error::msg(
                    "HuggingFace support not enabled. Enable the 'huggingface' feature.",
                ))
            }
        }
        Some("model") => {
            // SentencePiece model file
            Err(Error::msg("SentencePiece models not yet supported"))
        }
        Some("gguf") => {
            // GGUF format
            Err(Error::msg("GGUF format not yet supported"))
        }
        _ => {
            // Try to auto-detect by reading file content
            auto_detect_tokenizer(file_path)
        }
    }
}

/// Auto-detect tokenizer type by examining file content
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let mut file = File::open(file_path)?;
    let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
    let bytes_read = file.read(&mut buffer)?;
    buffer.truncate(bytes_read);

    // Check for JSON (HuggingFace format)
    if is_likely_json(&buffer) {
        #[cfg(feature = "huggingface")]
        {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            return Ok(Arc::new(tokenizer));
        }
        #[cfg(not(feature = "huggingface"))]
        {
            return Err(Error::msg(
                "File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
            ));
        }
    }

    // Check for GGUF magic number
    if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
        return Err(Error::msg("GGUF format detected but not yet supported"));
    }

    // Check for SentencePiece model
    if is_likely_sentencepiece(&buffer) {
        return Err(Error::msg(
            "SentencePiece model detected but not yet supported",
        ));
    }

    Err(Error::msg(format!(
        "Unable to determine tokenizer type for file: {}",
        file_path
    )))
}

/// Check if the buffer likely contains JSON data
fn is_likely_json(buffer: &[u8]) -> bool {
    // Skip UTF-8 BOM if present
    let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
        &buffer[3..]
    } else {
        buffer
    };

    // Find first non-whitespace character without allocation
    if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
        *first_byte == b'{' || *first_byte == b'['
    } else {
        false
    }
}

/// Check if the buffer likely contains a SentencePiece model
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    // SentencePiece models often start with specific patterns
    // This is a simplified check
    buffer.len() >= 12
        && (buffer.starts_with(b"\x0a\x09")
            || buffer.starts_with(b"\x08\x00")
            || buffer.windows(4).any(|w| w == b"<unk")
            || buffer.windows(3).any(|w| w == b"<s>") // 3-byte pattern needs 3-byte windows
            || buffer.windows(4).any(|w| w == b"</s>"))
}

/// Factory function to create tokenizer from a model name or path
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    // Check if it's a file path
    let path = Path::new(model_name_or_path);
    if path.exists() {
        return create_tokenizer_from_file(model_name_or_path);
    }

    // Otherwise, try to load from HuggingFace Hub
    #[cfg(feature = "huggingface")]
    {
        // This would download from HF Hub - not implemented yet
        Err(Error::msg(
            "Loading from HuggingFace Hub not yet implemented",
        ))
    }

    #[cfg(not(feature = "huggingface"))]
    {
        Err(Error::msg(format!(
            "Model '{}' not found locally and HuggingFace support is not enabled",
            model_name_or_path
        )))
    }
}

/// Get information about a tokenizer file
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
    let path = Path::new(file_path);

    if !path.exists() {
        return Err(Error::msg(format!("File not found: {}", file_path)));
    }

    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(|s| s.to_lowercase());

    match extension.as_deref() {
        Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
        _ => {
            // Try auto-detection
            use std::fs::File;
            use std::io::Read;

            let mut file = File::open(file_path)?;
            let mut buffer = vec![0u8; 512];
            let bytes_read = file.read(&mut buffer)?;
            buffer.truncate(bytes_read);

            if is_likely_json(&buffer) {
                Ok(TokenizerType::HuggingFace(file_path.to_string()))
            } else {
                Err(Error::msg("Unknown tokenizer type"))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_detection() {
        assert!(is_likely_json(b"{\"test\": \"value\"}"));
        assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
        assert!(is_likely_json(b"[1, 2, 3]"));
        assert!(!is_likely_json(b"not json"));
        assert!(!is_likely_json(b""));
    }

    #[test]
    fn test_mock_tokenizer_creation() {
        let tokenizer = create_tokenizer_from_file("mock").unwrap();
        assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
    }

    #[test]
    fn test_file_not_found() {
        let result = create_tokenizer_from_file("/nonexistent/file.json");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("File not found"));
        }
    }
}
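A minimal usage sketch of the factory as landed above. Assumptions (not part of the diff): the crate is consumed as `sglang_router_rs` with the `tokenizer` module public at the crate root, the `huggingface` feature is on, and `./tokenizer.json` is a hypothetical local file.

use sglang_router_rs::tokenizer::factory::{
    create_tokenizer_from_file, get_tokenizer_info, TokenizerType,
};
use sglang_router_rs::tokenizer::TokenizerTrait as _; // bring vocab_size() into scope

fn main() -> anyhow::Result<()> {
    // "mock" is the special-cased test path and needs no file on disk.
    let tok = create_tokenizer_from_file("mock")?;
    println!("mock vocab size: {}", tok.vocab_size());

    // A .json extension routes to the HuggingFace loader; anything else
    // falls through to content-based auto-detection.
    match get_tokenizer_info("tokenizer.json")? {
        TokenizerType::HuggingFace(path) => println!("HF tokenizer at {}", path),
        TokenizerType::Mock => println!("mock tokenizer"),
    }
    Ok(())
}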
sgl-router/src/tokenizer/huggingface.rs (new file, 189 lines)
@@ -0,0 +1,189 @@
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
    special_tokens: SpecialTokens,
    vocab: HashMap<String, u32>,
    reverse_vocab: HashMap<u32, String>,
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
    pub fn from_file(file_path: &str) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
        let reverse_vocab: HashMap<u32, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
        let reverse_vocab: HashMap<u32, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Try to get special tokens from the tokenizer
        // This is a simplified version - actual implementation would need to handle various formats
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

    /// Apply chat template if available
    pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
        // This is a placeholder - actual implementation would handle templates
        let mut result = String::new();
        for msg in messages {
            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
        }
        Ok(result)
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        let encoding = self
            .tokenizer
            .encode(input, false)
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))?;

        Ok(Encoding::Hf(Box::new(encoding)))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;

        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }
}

/// Represents a chat message for template application
#[derive(Debug, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

impl ChatMessage {
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        ChatMessage {
            role: role.into(),
            content: content.into(),
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_message_creation() {
        let msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "You are a helpful assistant");

        let user_msg = ChatMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");

        let assistant_msg = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    // Note: Actual tokenizer tests would require a real tokenizer file
    // These would be integration tests rather than unit tests
}
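A sketch of how the wrapper is meant to be driven, assuming the same crate layout as above, the `huggingface` feature, and a hypothetical local `tokenizer.json`:

use sglang_router_rs::tokenizer::huggingface::{ChatMessage, HuggingFaceTokenizer};
use sglang_router_rs::tokenizer::traits::{Encoder, Tokenizer as _};

fn demo() -> anyhow::Result<()> {
    let tok = HuggingFaceTokenizer::from_file("tokenizer.json")?;

    // The placeholder template just renders "role: content" lines for now.
    let prompt = tok.apply_chat_template(&[
        ChatMessage::system("You are a helpful assistant"),
        ChatMessage::user("Hello!"),
    ])?;

    // encode() wraps the HF result in the crate's Encoding enum (Encoding::Hf).
    let _encoding = tok.encode(&prompt)?;

    // Special tokens were matched against the vocab at load time.
    if let Some(eos) = &tok.get_special_tokens().eos_token {
        println!("eos: {}", eos);
    }
    Ok(())
}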
sgl-router/src/tokenizer/mod.rs
@@ -2,26 +2,36 @@ use anyhow::Result;
 use std::ops::Deref;
 use std::sync::Arc;
 
+pub mod factory;
 pub mod mock;
+pub mod stop;
 pub mod stream;
 pub mod traits;
 
+// Feature-gated modules
+#[cfg(feature = "huggingface")]
+pub mod huggingface;
+
 #[cfg(test)]
 mod tests;
 
+// Re-exports
+pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
+pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
 
+#[cfg(feature = "huggingface")]
+pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
+
 /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
 #[derive(Clone)]
 pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
 
 impl Tokenizer {
     /// Create a tokenizer from a file path
-    /// Will be implemented in Phase 3 with factory pattern
-    pub fn from_file(_file_path: &str) -> Result<Tokenizer> {
-        // TODO: Implement factory pattern in Phase 3
-        unimplemented!("Factory pattern will be implemented in Phase 3")
-    }
+    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
+        Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
+    }
 
     /// Create a tokenizer from an Arc<dyn Tokenizer>
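With the factory in place, the top-level wrapper stops panicking. A minimal sketch, assuming the crate layout above and that the wrapper's Deref impl (note the `use std::ops::Deref` import) exposes the inner trait object; the `"mock"` path avoids needing a tokenizer file:

use sglang_router_rs::tokenizer::{Tokenizer, TokenizerTrait as _};

fn demo() -> anyhow::Result<()> {
    // Routes through factory::create_tokenizer_from_file instead of
    // the old unimplemented!() stub.
    let tokenizer = Tokenizer::from_file("mock")?;

    // Trait methods reachable through Deref to Arc<dyn traits::Tokenizer>.
    assert_eq!(tokenizer.vocab_size(), 8); // mock vocab has 8 tokens
    Ok(())
}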
sgl-router/src/tokenizer/stop.rs (new file, 499 lines)
@@ -0,0 +1,499 @@
use super::traits;
use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;

/// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
    /// Normal text output
    Text(String),
    /// Text is being held due to partial stop sequence match
    Held,
    /// Stop sequence matched (hidden - not included in output)
    Stopped,
    /// Stop sequence matched with text (visible - included in output)
    StoppedWithText(String),
}

/// Configuration for stop sequences
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
    /// Token IDs that trigger a stop
    pub stop_tokens: HashSet<u32>,
    /// String sequences that trigger a stop
    pub stop_sequences: Vec<String>,
    /// Token IDs for visible stops (included in output)
    pub visible_stop_tokens: HashSet<u32>,
    /// String sequences for visible stops (included in output)
    pub visible_stop_sequences: Vec<String>,
}

impl StopSequenceConfig {
    /// Builder pattern - add a stop token
    pub fn with_stop_token(mut self, token_id: u32) -> Self {
        self.stop_tokens.insert(token_id);
        self
    }

    /// Builder pattern - add a stop sequence
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }

    /// Builder pattern - add a visible stop token
    pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
        self.visible_stop_tokens.insert(token_id);
        self
    }

    /// Builder pattern - add a visible stop sequence
    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.visible_stop_sequences.push(sequence.into());
        self
    }
}

/// Decoder that handles stop sequences
pub struct StopSequenceDecoder {
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    /// Buffer for partial matches (the "jail")
    jail_buffer: String,
    /// Accumulated tokens
    token_buffer: Vec<u32>,
    /// Offset where the prefix text starts (for context)
    prefix_offset: usize,
    /// Offset marking the end of previously decoded text
    read_offset: usize,
    /// Whether we've stopped
    stopped: bool,
    skip_special_tokens: bool,
}

impl StopSequenceDecoder {
    /// Create a new stop sequence decoder
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        config: StopSequenceConfig,
        skip_special_tokens: bool,
    ) -> Self {
        StopSequenceDecoder {
            tokenizer,
            config,
            jail_buffer: String::new(),
            token_buffer: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
            stopped: false,
            skip_special_tokens,
        }
    }

    /// Process a single token
    pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Check for token-level stops first
        if self.config.stop_tokens.contains(&token_id) {
            self.stopped = true;
            // Flush any jailed text before stopping
            if !self.jail_buffer.is_empty() {
                let output = self.jail_buffer.clone();
                self.jail_buffer.clear();
                return Ok(SequenceDecoderOutput::StoppedWithText(output));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        if self.config.visible_stop_tokens.contains(&token_id) {
            self.stopped = true;
            // Include jailed text plus the stop token
            let stop_text = self
                .tokenizer
                .decode(&[token_id], self.skip_special_tokens)?;
            let output = format!("{}{}", self.jail_buffer, stop_text);
            self.jail_buffer.clear();
            return Ok(SequenceDecoderOutput::StoppedWithText(output));
        }

        // Add token to buffer
        self.token_buffer.push(token_id);

        // Use incremental decoding like DecodeStream
        // First decode the previous context (what we've already output)
        let prefix_text = if self.read_offset > self.prefix_offset {
            self.tokenizer.decode(
                &self.token_buffer[self.prefix_offset..self.read_offset],
                self.skip_special_tokens,
            )?
        } else {
            String::new()
        };

        // Now decode from prefix to current position
        let new_full_text = self.tokenizer.decode(
            &self.token_buffer[self.prefix_offset..],
            self.skip_special_tokens,
        )?;

        // Check for incomplete UTF-8 sequence (decoded as the U+FFFD replacement character)
        if new_full_text.ends_with('\u{FFFD}') {
            // Wait for more tokens to complete the sequence
            return Ok(SequenceDecoderOutput::Held);
        }

        // Calculate only the NEW text since last successful decode
        let new_text = if new_full_text.len() > prefix_text.len() {
            &new_full_text[prefix_text.len()..]
        } else {
            // No new text produced (can happen with special tokens)
            return Ok(SequenceDecoderOutput::Held);
        };

        // Combine jail buffer with new text for checking
        let check_text = format!("{}{}", self.jail_buffer, new_text);

        // Check for complete stop sequences
        for stop_seq in &self.config.stop_sequences {
            if let Some(pos) = check_text.find(stop_seq) {
                self.stopped = true;
                // Output text before the stop sequence
                let output = check_text[..pos].to_string();
                self.jail_buffer.clear();
                return Ok(if output.is_empty() {
                    SequenceDecoderOutput::Stopped
                } else {
                    SequenceDecoderOutput::StoppedWithText(output)
                });
            }
        }

        // Check for visible stop sequences
        for stop_seq in &self.config.visible_stop_sequences {
            if let Some(pos) = check_text.find(stop_seq) {
                self.stopped = true;
                // Include the stop sequence in output
                let end_pos = pos + stop_seq.len();
                let output = check_text[..end_pos].to_string();
                self.jail_buffer.clear();
                return Ok(SequenceDecoderOutput::StoppedWithText(output));
            }
        }

        // Check for partial matches at the end of check_text
        let mut partial_match_len = 0;
        for stop_seq in self
            .config
            .stop_sequences
            .iter()
            .chain(&self.config.visible_stop_sequences)
        {
            // Check all possible suffixes that could be a prefix of stop_seq
            for i in 1..=check_text.len().min(stop_seq.len() - 1) {
                let suffix = &check_text[check_text.len() - i..];
                if stop_seq.starts_with(suffix) {
                    partial_match_len = partial_match_len.max(i);
                }
            }
        }

        if partial_match_len > 0 {
            // Split: output safe text, jail the potential match
            let safe_end = check_text.len() - partial_match_len;
            let safe_text = &check_text[..safe_end];
            self.jail_buffer = check_text[safe_end..].to_string();

            // Update offsets for next iteration
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_buffer.len();

            if safe_text.is_empty() {
                Ok(SequenceDecoderOutput::Held)
            } else {
                Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
            }
        } else {
            // No partial matches - output everything
            self.jail_buffer.clear();

            // Update offsets for next iteration
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_buffer.len();

            Ok(SequenceDecoderOutput::Text(check_text))
        }
    }

    /// Process multiple tokens
    pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
        let mut outputs = Vec::new();
        for &token_id in token_ids {
            outputs.push(self.process_token(token_id)?);
        }
        Ok(outputs)
    }

    /// Flush any held text
    pub fn flush(&mut self) -> SequenceDecoderOutput {
        if !self.jail_buffer.is_empty() {
            let output = self.jail_buffer.clone();
            self.jail_buffer.clear();
            SequenceDecoderOutput::Text(output)
        } else {
            SequenceDecoderOutput::Text(String::new())
        }
    }

    /// Check if decoding has stopped
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// Reset the decoder state
    pub fn reset(&mut self) {
        self.jail_buffer.clear();
        self.token_buffer.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
        self.stopped = false;
    }
}

/// Builder for StopSequenceDecoder
pub struct StopSequenceDecoderBuilder {
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    skip_special_tokens: bool,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        StopSequenceDecoderBuilder {
            tokenizer,
            config: StopSequenceConfig::default(),
            skip_special_tokens: true,
        }
    }

    pub fn stop_token(mut self, token_id: u32) -> Self {
        self.config.stop_tokens.insert(token_id);
        self
    }

    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.stop_sequences.push(sequence.into());
        self
    }

    pub fn visible_stop_token(mut self, token_id: u32) -> Self {
        self.config.visible_stop_tokens.insert(token_id);
        self
    }

    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.visible_stop_sequences.push(sequence.into());
        self
    }

    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
        self.skip_special_tokens = skip;
        self
    }

    pub fn build(self) -> StopSequenceDecoder {
        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens before stop
        let result = decoder.process_token(1).unwrap(); // "Hello"
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // Process stop token
        let result = decoder.process_token(999).unwrap(); // <eos>
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Further tokens should also return Stopped
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);

        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());

        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        // This test verifies the critical fix: no repeated output
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens one by one and collect outputs
        let mut outputs = Vec::new();

        // Token 1: "Hello"
        let result = decoder.process_token(1).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 2: "world"
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 3: "test"
        let result = decoder.process_token(3).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // CRITICAL: Each output should be unique (no accumulation)
        // The fix ensures we only output NEW text, not accumulated text
        assert_eq!(outputs.len(), 3);

        // Verify no text is repeated
        for i in 0..outputs.len() {
            for j in i + 1..outputs.len() {
                // No output should contain another (no accumulation)
                assert!(!outputs[j].contains(&outputs[i]));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("test");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello world"
        decoder.process_token(1).unwrap(); // "Hello"
        decoder.process_token(2).unwrap(); // "world"

        // Process "test" which should trigger stop
        let result = decoder.process_token(3).unwrap(); // "test"

        // Should stop when we hit "test"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process a token
        decoder.process_token(1).unwrap(); // "Hello"

        // Flush should return any remaining text in jail
        let result = decoder.flush();

        // After processing, flush should work
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process and stop
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        // Reset should clear everything
        decoder.reset();
        assert!(!decoder.is_stopped());

        // Should be able to process again
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello"
        decoder.process_token(1).unwrap();

        // Process "world" - should include it in output
        let result = decoder.process_token(2).unwrap();

        if let SequenceDecoderOutput::StoppedWithText(text) = result {
            // Should include "world" in the output
            assert!(text.contains("world"));
        } else {
            panic!("Expected StoppedWithText with visible stop sequence");
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process multiple tokens at once
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        // Should get results for each token
        assert_eq!(results.len(), 3);

        // Each result should be Text (no stops configured)
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }
}
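A sketch of a streaming consumer for the decoder, reusing the builder and the mock tokenizer from the tests above (where ids 1, 2, 3 decode to "Hello", "world", "test"); crate paths are assumed as in the earlier sketches:

use std::sync::Arc;
use sglang_router_rs::tokenizer::mock::MockTokenizer;
use sglang_router_rs::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};

fn stream_demo() -> anyhow::Result<()> {
    let tokenizer = Arc::new(MockTokenizer::new());
    let mut decoder = StopSequenceDecoderBuilder::new(tokenizer)
        .stop_token(999) // hidden stop: <eos>
        .stop_sequence("test") // hidden stop string
        .build();

    for token_id in [1u32, 2, 3] {
        match decoder.process_token(token_id)? {
            SequenceDecoderOutput::Text(t) => print!("{}", t),
            SequenceDecoderOutput::Held => {} // partial match jailed; wait for more tokens
            SequenceDecoderOutput::StoppedWithText(t) => {
                print!("{}", t);
                break;
            }
            SequenceDecoderOutput::Stopped => break,
        }
    }
    // Emit anything still jailed at end of stream.
    if let SequenceDecoderOutput::Text(rest) = decoder.flush() {
        print!("{}", rest);
    }
    Ok(())
}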