[router] add tokenizer download support from hf hub (#9882)
This commit is contained in:
@@ -4,9 +4,7 @@ version = "0.0.0"
|
|||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["huggingface", "grpc-client"]
|
default = ["grpc-client"]
|
||||||
huggingface = ["tokenizers", "minijinja"]
|
|
||||||
tiktoken = ["tiktoken-rs"]
|
|
||||||
grpc-client = []
|
grpc-client = []
|
||||||
grpc-server = []
|
grpc-server = []
|
||||||
|
|
||||||
@@ -52,10 +50,11 @@ regex = "1.10"
|
|||||||
url = "2.5.4"
|
url = "2.5.4"
|
||||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
tokenizers = { version = "0.21.4", optional = true }
|
tokenizers = { version = "0.22.0" }
|
||||||
tiktoken-rs = { version = "0.7.0", optional = true }
|
tiktoken-rs = { version = "0.7.0" }
|
||||||
minijinja = { version = "2.0", optional = true }
|
minijinja = { version = "2.0" }
|
||||||
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
|
||||||
|
hf-hub = { version = "0.4.3", features = ["tokio"] }
|
||||||
|
|
||||||
# gRPC and Protobuf dependencies
|
# gRPC and Protobuf dependencies
|
||||||
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ The SGL Router tokenizer layer provides a unified interface for text tokenizatio
|
|||||||
|
|
||||||
**Key Components:**
|
**Key Components:**
|
||||||
- **Factory Pattern**: Auto-detection and creation of appropriate tokenizer types from files or model names
|
- **Factory Pattern**: Auto-detection and creation of appropriate tokenizer types from files or model names
|
||||||
|
- **HuggingFace Hub Integration**: Automatic downloading of tokenizer files from HuggingFace Hub for model IDs
|
||||||
- **Trait System**: `Encoder`, `Decoder`, and `Tokenizer` traits for implementation flexibility
|
- **Trait System**: `Encoder`, `Decoder`, and `Tokenizer` traits for implementation flexibility
|
||||||
- **Streaming**: Incremental decoding with UTF-8 boundary handling and buffering
|
- **Streaming**: Incremental decoding with UTF-8 boundary handling and buffering
|
||||||
- **Stop Sequences**: Complex pattern matching for stop tokens and sequences with "jail" buffering
|
- **Stop Sequences**: Complex pattern matching for stop tokens and sequences with "jail" buffering
|
||||||
@@ -16,7 +17,7 @@ The SGL Router tokenizer layer provides a unified interface for text tokenizatio
|
|||||||
- **Metrics Integration**: Comprehensive performance and error tracking across all operations
|
- **Metrics Integration**: Comprehensive performance and error tracking across all operations
|
||||||
|
|
||||||
**Data Flow:**
|
**Data Flow:**
|
||||||
1. Request → Factory (type detection) → Concrete Tokenizer Creation
|
1. Request → Factory (type detection/HF download) → Concrete Tokenizer Creation
|
||||||
2. Encode: Text → Tokenizer → Encoding (token IDs)
|
2. Encode: Text → Tokenizer → Encoding (token IDs)
|
||||||
3. Stream: Token IDs → DecodeStream → Incremental Text Chunks
|
3. Stream: Token IDs → DecodeStream → Incremental Text Chunks
|
||||||
4. Stop Detection: Tokens → StopSequenceDecoder → Text/Held/Stopped
|
4. Stop Detection: Tokens → StopSequenceDecoder → Text/Held/Stopped
|
||||||
@@ -25,8 +26,9 @@ The SGL Router tokenizer layer provides a unified interface for text tokenizatio
|
|||||||
### Architecture Highlights
|
### Architecture Highlights
|
||||||
|
|
||||||
- **Extended Backend Support**: HuggingFace, Tiktoken (GPT models), and Mock for testing
|
- **Extended Backend Support**: HuggingFace, Tiktoken (GPT models), and Mock for testing
|
||||||
|
- **HuggingFace Hub Integration**: Automatic tokenizer downloads with caching
|
||||||
- **Comprehensive Metrics**: Full TokenizerMetrics integration for observability
|
- **Comprehensive Metrics**: Full TokenizerMetrics integration for observability
|
||||||
- **Feature Gating**: Conditional compilation for tokenizer backends
|
- **Unified Dependencies**: All tokenizer backends included by default (no feature gates)
|
||||||
- **Stop Sequence Detection**: Sophisticated partial matching with jail buffer
|
- **Stop Sequence Detection**: Sophisticated partial matching with jail buffer
|
||||||
- **Chat Template Support**: Full Jinja2 rendering with HuggingFace compatibility
|
- **Chat Template Support**: Full Jinja2 rendering with HuggingFace compatibility
|
||||||
- **Thread Safety**: Arc-based sharing with Send + Sync guarantees
|
- **Thread Safety**: Arc-based sharing with Send + Sync guarantees
|
||||||
@@ -92,9 +94,14 @@ sequenceDiagram
|
|||||||
participant SD as StopDecoder
|
participant SD as StopDecoder
|
||||||
participant M as Metrics
|
participant M as Metrics
|
||||||
|
|
||||||
C->>F: create_tokenizer(path)
|
C->>F: create_tokenizer(path_or_model_id)
|
||||||
F->>F: detect_type()
|
F->>F: detect_type()
|
||||||
F->>T: new HF/Tiktoken/Mock
|
alt local file
|
||||||
|
F->>T: new HF/Tiktoken/Mock
|
||||||
|
else HuggingFace model ID
|
||||||
|
F->>F: download_tokenizer_from_hf()
|
||||||
|
F->>T: new from downloaded files
|
||||||
|
end
|
||||||
F->>M: record_factory_load()
|
F->>M: record_factory_load()
|
||||||
F-->>C: Arc<dyn Tokenizer>
|
F-->>C: Arc<dyn Tokenizer>
|
||||||
|
|
||||||
@@ -287,11 +294,11 @@ impl Tokenizer {
|
|||||||
- Single field: `Arc<dyn traits::Tokenizer>` for polymorphic dispatch
|
- Single field: `Arc<dyn traits::Tokenizer>` for polymorphic dispatch
|
||||||
- Immutable after creation, Clone via Arc
|
- Immutable after creation, Clone via Arc
|
||||||
|
|
||||||
**Re-exports** (mod.rs:25-39):
|
**Re-exports** (mod.rs:26-43):
|
||||||
- Factory functions: `create_tokenizer`, `create_tokenizer_from_file`, `create_tokenizer_with_chat_template`
|
- Factory functions: `create_tokenizer`, `create_tokenizer_async`, `create_tokenizer_from_file`, `create_tokenizer_with_chat_template`
|
||||||
- Types: `Sequence`, `StopSequenceConfig`, `DecodeStream`, `Encoding`
|
- Types: `Sequence`, `StopSequenceConfig`, `DecodeStream`, `Encoding`, `TokenizerType`
|
||||||
- Chat template: `ChatMessage` (when huggingface feature enabled)
|
- Chat template: `ChatMessage`
|
||||||
- Conditional: `HuggingFaceTokenizer`, `TiktokenTokenizer` based on features
|
- Tokenizer implementations: `HuggingFaceTokenizer`, `TiktokenTokenizer`
|
||||||
|
|
||||||
### 3.2 traits.rs (Trait Definitions)
|
### 3.2 traits.rs (Trait Definitions)
|
||||||
|
|
||||||
@@ -350,6 +357,7 @@ pub fn create_tokenizer_with_chat_template(
|
|||||||
chat_template_path: Option<&str>
|
chat_template_path: Option<&str>
|
||||||
) -> Result<Arc<dyn traits::Tokenizer>>
|
) -> Result<Arc<dyn traits::Tokenizer>>
|
||||||
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
|
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
|
||||||
|
pub async fn create_tokenizer_async(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
|
||||||
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType>
|
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType>
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -364,10 +372,16 @@ pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType>
|
|||||||
- SentencePiece: Check for specific byte patterns
|
- SentencePiece: Check for specific byte patterns
|
||||||
- GGUF: Check magic number "GGUF"
|
- GGUF: Check magic number "GGUF"
|
||||||
|
|
||||||
**Model Name Routing** (factory.rs:163-203):
|
**Model Name Routing** (factory.rs:145-193):
|
||||||
- GPT models → Tiktoken (gpt-4, gpt-3.5, davinci, curie, etc.)
|
- GPT models → Tiktoken (gpt-4, gpt-3.5, davinci, curie, etc.)
|
||||||
- File paths → file-based creation
|
- File paths → file-based creation
|
||||||
- HuggingFace Hub → Not implemented (returns error)
|
- HuggingFace model IDs → Automatic download from Hub
|
||||||
|
|
||||||
|
**HuggingFace Hub Integration**:
|
||||||
|
- Downloads tokenizer files (tokenizer.json, tokenizer_config.json, etc.)
|
||||||
|
- Respects HF_TOKEN environment variable for private models
|
||||||
|
- Caches downloaded files using hf-hub crate
|
||||||
|
- Async and blocking versions available
|
||||||
|
|
||||||
**Metrics Integration:**
|
**Metrics Integration:**
|
||||||
- Records factory load/error events (factory.rs:56-57, 82-83)
|
- Records factory load/error events (factory.rs:56-57, 82-83)
|
||||||
@@ -613,7 +627,32 @@ pub enum TiktokenModel {
|
|||||||
- Decode: Join tokens with spaces
|
- Decode: Join tokens with spaces
|
||||||
- Skips special tokens when requested
|
- Skips special tokens when requested
|
||||||
|
|
||||||
### 3.10 chat_template.rs (Chat Template Support)
|
### 3.10 hub.rs (HuggingFace Hub Download)
|
||||||
|
|
||||||
|
**Location**: `src/tokenizer/hub.rs`
|
||||||
|
|
||||||
|
**Purpose:** Download tokenizer files from HuggingFace Hub when given a model ID.
|
||||||
|
|
||||||
|
**Key Functions:**
|
||||||
|
|
||||||
|
```rust
|
||||||
|
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> Result<PathBuf>
|
||||||
|
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> Result<PathBuf>
|
||||||
|
```
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Downloads only tokenizer-related files by default
|
||||||
|
- Filters out model weights, images, and documentation
|
||||||
|
- Uses HF_TOKEN environment variable for authentication
|
||||||
|
- Returns cached directory path for subsequent use
|
||||||
|
- Progress indication during download
|
||||||
|
|
||||||
|
**File Detection:**
|
||||||
|
- Tokenizer files: tokenizer.json, tokenizer_config.json, special_tokens_map.json
|
||||||
|
- Vocabulary files: vocab.json, merges.txt
|
||||||
|
- SentencePiece models: *.model files
|
||||||
|
|
||||||
|
### 3.11 chat_template.rs (Chat Template Support)
|
||||||
|
|
||||||
**Location**: `src/tokenizer/chat_template.rs`
|
**Location**: `src/tokenizer/chat_template.rs`
|
||||||
|
|
||||||
@@ -894,11 +933,11 @@ The `Encoding` enum must:
|
|||||||
### Configuration
|
### Configuration
|
||||||
|
|
||||||
**Environment Variables:**
|
**Environment Variables:**
|
||||||
- None currently defined
|
- `HF_TOKEN`: HuggingFace authentication token for private models
|
||||||
|
|
||||||
**Feature Flags:**
|
**Dependencies:**
|
||||||
- `huggingface`: Enable HF tokenizer
|
- All tokenizer backends included by default
|
||||||
- `tiktoken`: Enable Tiktoken support
|
- No feature flags required
|
||||||
|
|
||||||
**Model Mapping:**
|
**Model Mapping:**
|
||||||
- Hardcoded in factory.rs
|
- Hardcoded in factory.rs
|
||||||
@@ -961,26 +1000,22 @@ The `Encoding` enum must:
|
|||||||
- File: `src/tokenizer/traits.rs`
|
- File: `src/tokenizer/traits.rs`
|
||||||
- Symbol: `pub type Offsets = (usize, usize)`
|
- Symbol: `pub type Offsets = (usize, usize)`
|
||||||
|
|
||||||
3. **TODO:** Implement HuggingFace Hub downloading
|
3. **TODO:** Support SentencePiece models
|
||||||
- File: `src/tokenizer/factory.rs:191`
|
|
||||||
- Symbol: `create_tokenizer()` function
|
|
||||||
|
|
||||||
4. **TODO:** Support SentencePiece models
|
|
||||||
- File: `src/tokenizer/factory.rs:69-72`
|
- File: `src/tokenizer/factory.rs:69-72`
|
||||||
- Symbol: Extension match arm for "model"
|
- Symbol: Extension match arm for "model"
|
||||||
|
|
||||||
5. **TODO:** Support GGUF format
|
4. **TODO:** Support GGUF format
|
||||||
- File: `src/tokenizer/factory.rs:74-78`
|
- File: `src/tokenizer/factory.rs:74-78`
|
||||||
- Symbol: Extension match arm for "gguf"
|
- Symbol: Extension match arm for "gguf"
|
||||||
|
|
||||||
6. **TODO:** Add token↔ID mapping for Tiktoken
|
5. **TODO:** Add token↔ID mapping for Tiktoken
|
||||||
- File: `src/tokenizer/tiktoken.rs:151-161`
|
- File: `src/tokenizer/tiktoken.rs:151-161`
|
||||||
- Symbol: `token_to_id()` and `id_to_token()` methods
|
- Symbol: `token_to_id()` and `id_to_token()` methods
|
||||||
|
|
||||||
7. **TODO:** Fix `token_ids_ref()` for Tiktoken
|
6. **TODO:** Fix `token_ids_ref()` for Tiktoken
|
||||||
- File: `src/tokenizer/traits.rs:46-50`
|
- File: `src/tokenizer/traits.rs:46-50`
|
||||||
- Symbol: `Encoding::Tiktoken` match arm
|
- Symbol: `Encoding::Tiktoken` match arm
|
||||||
|
|
||||||
8. **TODO:** Make model→tokenizer mapping configurable
|
7. **TODO:** Make model→tokenizer mapping configurable
|
||||||
- File: `src/tokenizer/factory.rs:174-184`
|
- File: `src/tokenizer/factory.rs:174-184`
|
||||||
- Symbol: GPT model detection logic
|
- Symbol: GPT model detection logic
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
//! similar to HuggingFace transformers' apply_chat_template method.
|
//! similar to HuggingFace transformers' apply_chat_template method.
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
use minijinja::{context, Environment, Value};
|
use minijinja::{context, Environment, Value};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json;
|
use serde_json;
|
||||||
@@ -38,14 +37,12 @@ impl ChatMessage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Chat template processor using Jinja2
|
/// Chat template processor using Jinja2
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
pub struct ChatTemplateProcessor {
|
pub struct ChatTemplateProcessor {
|
||||||
template: String,
|
template: String,
|
||||||
bos_token: Option<String>,
|
bos_token: Option<String>,
|
||||||
eos_token: Option<String>,
|
eos_token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
impl ChatTemplateProcessor {
|
impl ChatTemplateProcessor {
|
||||||
/// Create a new chat template processor
|
/// Create a new chat template processor
|
||||||
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
|
||||||
@@ -102,7 +99,6 @@ impl ChatTemplateProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Load chat template from tokenizer config JSON
|
/// Load chat template from tokenizer config JSON
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
|
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
||||||
@@ -136,7 +132,6 @@ mod tests {
|
|||||||
assert_eq!(assistant_msg.role, "assistant");
|
assert_eq!(assistant_msg.role, "assistant");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_simple_chat_template() {
|
fn test_simple_chat_template() {
|
||||||
// Simple template that formats messages
|
// Simple template that formats messages
|
||||||
@@ -162,7 +157,6 @@ assistant:
|
|||||||
assert!(result.contains("assistant:"));
|
assert!(result.contains("assistant:"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_template_with_tokens() {
|
fn test_chat_template_with_tokens() {
|
||||||
// Template that uses special tokens
|
// Template that uses special tokens
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ use std::io::Read;
|
|||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
use super::huggingface::HuggingFaceTokenizer;
|
use super::huggingface::HuggingFaceTokenizer;
|
||||||
|
use super::tiktoken::TiktokenTokenizer;
|
||||||
|
use crate::tokenizer::hub::download_tokenizer_from_hf;
|
||||||
|
|
||||||
/// Represents the type of tokenizer being used
|
/// Represents the type of tokenizer being used
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum TokenizerType {
|
pub enum TokenizerType {
|
||||||
HuggingFace(String),
|
HuggingFace(String),
|
||||||
Mock,
|
Mock,
|
||||||
#[cfg(feature = "tiktoken")]
|
|
||||||
Tiktoken(String),
|
Tiktoken(String),
|
||||||
// Future: SentencePiece, GGUF
|
// Future: SentencePiece, GGUF
|
||||||
}
|
}
|
||||||
@@ -52,21 +52,10 @@ pub fn create_tokenizer_with_chat_template(
|
|||||||
|
|
||||||
let result = match extension.as_deref() {
|
let result = match extension.as_deref() {
|
||||||
Some("json") => {
|
Some("json") => {
|
||||||
#[cfg(feature = "huggingface")]
|
let tokenizer =
|
||||||
{
|
HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
|
||||||
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
|
||||||
file_path,
|
|
||||||
chat_template_path,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
|
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
|
||||||
}
|
|
||||||
#[cfg(not(feature = "huggingface"))]
|
|
||||||
{
|
|
||||||
Err(Error::msg(
|
|
||||||
"HuggingFace support not enabled. Enable the 'huggingface' feature.",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Some("model") => {
|
Some("model") => {
|
||||||
// SentencePiece model file
|
// SentencePiece model file
|
||||||
@@ -94,17 +83,8 @@ fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
|
|||||||
|
|
||||||
// Check for JSON (HuggingFace format)
|
// Check for JSON (HuggingFace format)
|
||||||
if is_likely_json(&buffer) {
|
if is_likely_json(&buffer) {
|
||||||
#[cfg(feature = "huggingface")]
|
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||||
{
|
return Ok(Arc::new(tokenizer));
|
||||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
|
||||||
return Ok(Arc::new(tokenizer));
|
|
||||||
}
|
|
||||||
#[cfg(not(feature = "huggingface"))]
|
|
||||||
{
|
|
||||||
return Err(Error::msg(
|
|
||||||
"File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check for GGUF magic number
|
// Check for GGUF magic number
|
||||||
@@ -154,7 +134,57 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
|
|||||||
|| buffer.windows(4).any(|w| w == b"</s>"))
|
|| buffer.windows(4).any(|w| w == b"</s>"))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Factory function to create tokenizer from a model name or path
|
/// Factory function to create tokenizer from a model name or path (async version)
|
||||||
|
pub async fn create_tokenizer_async(
|
||||||
|
model_name_or_path: &str,
|
||||||
|
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||||
|
// Check if it's a file path
|
||||||
|
let path = Path::new(model_name_or_path);
|
||||||
|
if path.exists() {
|
||||||
|
return create_tokenizer_from_file(model_name_or_path);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if it's a GPT model name that should use Tiktoken
|
||||||
|
if model_name_or_path.contains("gpt-")
|
||||||
|
|| model_name_or_path.contains("davinci")
|
||||||
|
|| model_name_or_path.contains("curie")
|
||||||
|
|| model_name_or_path.contains("babbage")
|
||||||
|
|| model_name_or_path.contains("ada")
|
||||||
|
{
|
||||||
|
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||||
|
return Ok(Arc::new(tokenizer));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to download tokenizer files from HuggingFace
|
||||||
|
match download_tokenizer_from_hf(model_name_or_path).await {
|
||||||
|
Ok(cache_dir) => {
|
||||||
|
// Look for tokenizer.json in the cache directory
|
||||||
|
let tokenizer_path = cache_dir.join("tokenizer.json");
|
||||||
|
if tokenizer_path.exists() {
|
||||||
|
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
|
||||||
|
} else {
|
||||||
|
// Try other common tokenizer file names
|
||||||
|
let possible_files = ["tokenizer_config.json", "vocab.json"];
|
||||||
|
for file_name in &possible_files {
|
||||||
|
let file_path = cache_dir.join(file_name);
|
||||||
|
if file_path.exists() {
|
||||||
|
return create_tokenizer_from_file(file_path.to_str().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(Error::msg(format!(
|
||||||
|
"Downloaded model '{}' but couldn't find a suitable tokenizer file",
|
||||||
|
model_name_or_path
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Err(Error::msg(format!(
|
||||||
|
"Failed to download tokenizer from HuggingFace: {}",
|
||||||
|
e
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Factory function to create tokenizer from a model name or path (blocking version)
|
||||||
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||||
// Check if it's a file path
|
// Check if it's a file path
|
||||||
let path = Path::new(model_name_or_path);
|
let path = Path::new(model_name_or_path);
|
||||||
@@ -163,35 +193,25 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if it's a GPT model name that should use Tiktoken
|
// Check if it's a GPT model name that should use Tiktoken
|
||||||
#[cfg(feature = "tiktoken")]
|
if model_name_or_path.contains("gpt-")
|
||||||
|
|| model_name_or_path.contains("davinci")
|
||||||
|
|| model_name_or_path.contains("curie")
|
||||||
|
|| model_name_or_path.contains("babbage")
|
||||||
|
|| model_name_or_path.contains("ada")
|
||||||
{
|
{
|
||||||
if model_name_or_path.contains("gpt-")
|
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||||
|| model_name_or_path.contains("davinci")
|
return Ok(Arc::new(tokenizer));
|
||||||
|| model_name_or_path.contains("curie")
|
|
||||||
|| model_name_or_path.contains("babbage")
|
|
||||||
|| model_name_or_path.contains("ada")
|
|
||||||
{
|
|
||||||
use super::tiktoken::TiktokenTokenizer;
|
|
||||||
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
|
||||||
return Ok(Arc::new(tokenizer));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Otherwise, try to load from HuggingFace Hub
|
// Only use tokio for HuggingFace downloads
|
||||||
#[cfg(feature = "huggingface")]
|
// Check if we're already in a tokio runtime
|
||||||
{
|
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||||
// This would download from HF Hub - not implemented yet
|
// We're in a runtime, use block_in_place
|
||||||
Err(Error::msg(
|
tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
|
||||||
"Loading from HuggingFace Hub not yet implemented",
|
} else {
|
||||||
))
|
// No runtime, create a temporary one
|
||||||
}
|
let rt = tokio::runtime::Runtime::new()?;
|
||||||
|
rt.block_on(create_tokenizer_async(model_name_or_path))
|
||||||
#[cfg(not(feature = "huggingface"))]
|
|
||||||
{
|
|
||||||
Err(Error::msg(format!(
|
|
||||||
"Model '{}' not found locally and HuggingFace support is not enabled",
|
|
||||||
model_name_or_path
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -257,7 +277,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "tiktoken")]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_create_tiktoken_tokenizer() {
|
fn test_create_tiktoken_tokenizer() {
|
||||||
// Test creating tokenizer for GPT models
|
// Test creating tokenizer for GPT models
|
||||||
@@ -270,4 +289,30 @@ mod tests {
|
|||||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||||
assert_eq!(decoded, text);
|
assert_eq!(decoded, text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_download_tokenizer_from_hf() {
|
||||||
|
// Test with a small model that should have tokenizer files
|
||||||
|
// Skip this test if HF_TOKEN is not set and we're in CI
|
||||||
|
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
|
||||||
|
println!("Skipping HF download test in CI without HF_TOKEN");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to create tokenizer for a known small model
|
||||||
|
let result = create_tokenizer_async("bert-base-uncased").await;
|
||||||
|
|
||||||
|
// The test might fail due to network issues or rate limiting
|
||||||
|
// so we just check that the function executes without panic
|
||||||
|
match result {
|
||||||
|
Ok(tokenizer) => {
|
||||||
|
assert!(tokenizer.vocab_size() > 0);
|
||||||
|
println!("Successfully downloaded and created tokenizer");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
println!("Download failed (this might be expected): {}", e);
|
||||||
|
// Don't fail the test - network issues shouldn't break CI
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
238
sgl-router/src/tokenizer/hub.rs
Normal file
238
sgl-router/src/tokenizer/hub.rs
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
use hf_hub::api::tokio::ApiBuilder;
|
||||||
|
use std::env;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
const IGNORED: [&str; 5] = [
|
||||||
|
".gitattributes",
|
||||||
|
"LICENSE",
|
||||||
|
"LICENSE.txt",
|
||||||
|
"README.md",
|
||||||
|
"USE_POLICY.md",
|
||||||
|
];
|
||||||
|
|
||||||
|
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
|
||||||
|
|
||||||
|
/// Checks if a file is a model weight file
///
/// The check is a case-sensitive suffix match on the repository-relative file
/// name. Besides the `.bin`, `.safetensors`, `.h5`, `.msgpack` and
/// `.ckpt.index` suffixes, PyTorch checkpoints (`.pt`, `.pth`) and GGUF files
/// (`.gguf`) are recognized so that callers skipping weights do not download
/// them.
fn is_weight_file(filename: &str) -> bool {
    const WEIGHT_SUFFIXES: [&str; 8] = [
        ".bin",
        ".safetensors",
        ".h5",
        ".msgpack",
        ".ckpt.index",
        ".pt",
        ".pth",
        ".gguf",
    ];

    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix))
}
|
||||||
|
|
||||||
|
/// Checks if a file is an image file
///
/// A file counts as an image when the text after its last `.` is `png`,
/// `jpg` or `jpeg`, compared case-insensitively. Names without a `.` are
/// never images, so a name such as `fooPNG` is rejected while `logo.Png`
/// is accepted.
fn is_image(filename: &str) -> bool {
    const IMAGE_EXTENSIONS: [&str; 3] = ["png", "jpg", "jpeg"];

    // `rsplit_once` yields the part after the last dot; a missing dot means
    // the name has no extension at all.
    filename.rsplit_once('.').map_or(false, |(_, extension)| {
        IMAGE_EXTENSIONS
            .iter()
            .any(|known| extension.eq_ignore_ascii_case(known))
    })
}
|
||||||
|
|
||||||
|
/// Checks if a file is a tokenizer file
///
/// The name is matched by suffix against the known tokenizer artifacts:
/// the HuggingFace JSON files, BPE vocabulary/merges files, `.model` files
/// (SentencePiece models) and `.tiktoken` files.
fn is_tokenizer_file(filename: &str) -> bool {
    const TOKENIZER_SUFFIXES: [&str; 7] = [
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        ".model", // SentencePiece models
        ".tiktoken",
    ];

    TOKENIZER_SUFFIXES
        .iter()
        .any(|suffix| filename.ends_with(*suffix))
}
|
||||||
|
|
||||||
|
/// Attempt to download tokenizer files from Hugging Face
|
||||||
|
/// Returns the directory containing the downloaded tokenizer files
|
||||||
|
pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::Result<PathBuf> {
|
||||||
|
let model_id = model_id.as_ref();
|
||||||
|
let token = env::var(HF_TOKEN_ENV_VAR).ok();
|
||||||
|
let api = ApiBuilder::new()
|
||||||
|
.with_progress(true)
|
||||||
|
.with_token(token)
|
||||||
|
.build()?;
|
||||||
|
let model_name = model_id.display().to_string();
|
||||||
|
|
||||||
|
let repo = api.model(model_name.clone());
|
||||||
|
|
||||||
|
let info = match repo.info().await {
|
||||||
|
Ok(info) => info,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
|
||||||
|
model_name,
|
||||||
|
e
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if info.siblings.is_empty() {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Model '{}' exists but contains no downloadable files.",
|
||||||
|
model_name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut cache_dir = None;
|
||||||
|
let mut tokenizer_files_found = false;
|
||||||
|
|
||||||
|
// First, identify all tokenizer files to download
|
||||||
|
let tokenizer_files: Vec<_> = info
|
||||||
|
.siblings
|
||||||
|
.iter()
|
||||||
|
.filter(|sib| {
|
||||||
|
!IGNORED.contains(&sib.rfilename.as_str())
|
||||||
|
&& !is_image(&sib.rfilename)
|
||||||
|
&& !is_weight_file(&sib.rfilename)
|
||||||
|
&& is_tokenizer_file(&sib.rfilename)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if tokenizer_files.is_empty() {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"No tokenizer files found for model '{}'.",
|
||||||
|
model_name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download all tokenizer files
|
||||||
|
for sib in tokenizer_files {
|
||||||
|
match repo.get(&sib.rfilename).await {
|
||||||
|
Ok(path) => {
|
||||||
|
if cache_dir.is_none() {
|
||||||
|
cache_dir = path.parent().map(|p| p.to_path_buf());
|
||||||
|
}
|
||||||
|
tokenizer_files_found = true;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Failed to download tokenizer file '{}' from model '{}': {}",
|
||||||
|
sib.rfilename,
|
||||||
|
model_name,
|
||||||
|
e
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tokenizer_files_found {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"No tokenizer files could be downloaded for model '{}'.",
|
||||||
|
model_name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
match cache_dir {
|
||||||
|
Some(dir) => Ok(dir),
|
||||||
|
None => Err(anyhow::anyhow!(
|
||||||
|
"Invalid HF cache path for model '{}'",
|
||||||
|
model_name
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Attempt to download a model from Hugging Face (including weights)
|
||||||
|
/// Returns the directory it is in
|
||||||
|
/// If ignore_weights is true, model weight files will be skipped
|
||||||
|
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
|
||||||
|
let name = name.as_ref();
|
||||||
|
let token = env::var(HF_TOKEN_ENV_VAR).ok();
|
||||||
|
let api = ApiBuilder::new()
|
||||||
|
.with_progress(true)
|
||||||
|
.with_token(token)
|
||||||
|
.build()?;
|
||||||
|
let model_name = name.display().to_string();
|
||||||
|
|
||||||
|
let repo = api.model(model_name.clone());
|
||||||
|
|
||||||
|
let info = match repo.info().await {
|
||||||
|
Ok(info) => info,
|
||||||
|
Err(e) => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Failed to fetch model '{}' from HuggingFace: {}. Is this a valid HuggingFace ID?",
|
||||||
|
model_name,
|
||||||
|
e
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if info.siblings.is_empty() {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Model '{}' exists but contains no downloadable files.",
|
||||||
|
model_name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut p = PathBuf::new();
|
||||||
|
let mut files_downloaded = false;
|
||||||
|
|
||||||
|
for sib in info.siblings {
|
||||||
|
if IGNORED.contains(&sib.rfilename.as_str()) || is_image(&sib.rfilename) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If ignore_weights is true, skip weight files
|
||||||
|
if ignore_weights && is_weight_file(&sib.rfilename) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match repo.get(&sib.rfilename).await {
|
||||||
|
Ok(path) => {
|
||||||
|
p = path;
|
||||||
|
files_downloaded = true;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"Failed to download file '{}' from model '{}': {}",
|
||||||
|
sib.rfilename,
|
||||||
|
model_name,
|
||||||
|
e
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !files_downloaded {
|
||||||
|
let file_type = if ignore_weights {
|
||||||
|
"non-weight"
|
||||||
|
} else {
|
||||||
|
"valid"
|
||||||
|
};
|
||||||
|
return Err(anyhow::anyhow!(
|
||||||
|
"No {} files found for model '{}'.",
|
||||||
|
file_type,
|
||||||
|
model_name
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
match p.parent() {
|
||||||
|
Some(p) => Ok(p.to_path_buf()),
|
||||||
|
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Known tokenizer artifacts are accepted; weights and docs are rejected.
    #[test]
    fn test_is_tokenizer_file() {
        let accepted = [
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
            "vocab.json",
            "merges.txt",
            "spiece.model",
        ];
        for name in &accepted {
            assert!(is_tokenizer_file(name));
        }

        let rejected = ["model.bin", "README.md"];
        for name in &rejected {
            assert!(!is_tokenizer_file(name));
        }
    }

    /// Weight files are detected by suffix; JSON config files are not weights.
    #[test]
    fn test_is_weight_file() {
        let weights = ["model.bin", "model.safetensors", "pytorch_model.bin"];
        for name in &weights {
            assert!(is_weight_file(name));
        }

        let non_weights = ["tokenizer.json", "config.json"];
        for name in &non_weights {
            assert!(!is_weight_file(name));
        }
    }
}
|
||||||
@@ -5,7 +5,6 @@ use anyhow::{Error, Result};
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
|
||||||
|
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||||
|
|
||||||
/// HuggingFace tokenizer wrapper
|
/// HuggingFace tokenizer wrapper
|
||||||
@@ -14,7 +13,6 @@ pub struct HuggingFaceTokenizer {
|
|||||||
special_tokens: SpecialTokens,
|
special_tokens: SpecialTokens,
|
||||||
vocab: HashMap<String, TokenIdType>,
|
vocab: HashMap<String, TokenIdType>,
|
||||||
reverse_vocab: HashMap<TokenIdType, String>,
|
reverse_vocab: HashMap<TokenIdType, String>,
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
chat_template: Option<String>,
|
chat_template: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -43,7 +41,6 @@ impl HuggingFaceTokenizer {
|
|||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Load chat template
|
// Load chat template
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
let chat_template = if let Some(template_path) = chat_template_path {
|
let chat_template = if let Some(template_path) = chat_template_path {
|
||||||
// Load from specified .jinja file
|
// Load from specified .jinja file
|
||||||
Self::load_chat_template_from_file(template_path)?
|
Self::load_chat_template_from_file(template_path)?
|
||||||
@@ -57,7 +54,6 @@ impl HuggingFaceTokenizer {
|
|||||||
special_tokens,
|
special_tokens,
|
||||||
vocab,
|
vocab,
|
||||||
reverse_vocab,
|
reverse_vocab,
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
chat_template,
|
chat_template,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -76,7 +72,6 @@ impl HuggingFaceTokenizer {
|
|||||||
special_tokens,
|
special_tokens,
|
||||||
vocab,
|
vocab,
|
||||||
reverse_vocab,
|
reverse_vocab,
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
chat_template: None,
|
chat_template: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -109,7 +104,6 @@ impl HuggingFaceTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Try to load chat template from tokenizer_config.json
|
/// Try to load chat template from tokenizer_config.json
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
|
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
|
||||||
// Try to find tokenizer_config.json in the same directory
|
// Try to find tokenizer_config.json in the same directory
|
||||||
let path = std::path::Path::new(tokenizer_path);
|
let path = std::path::Path::new(tokenizer_path);
|
||||||
@@ -127,7 +121,6 @@ impl HuggingFaceTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Load chat template from a .jinja file
|
/// Load chat template from a .jinja file
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
|
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
||||||
@@ -141,13 +134,11 @@ impl HuggingFaceTokenizer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Set or override the chat template
|
/// Set or override the chat template
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
pub fn set_chat_template(&mut self, template: String) {
|
pub fn set_chat_template(&mut self, template: String) {
|
||||||
self.chat_template = Some(template);
|
self.chat_template = Some(template);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply chat template if available
|
/// Apply chat template if available
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
pub fn apply_chat_template(
|
pub fn apply_chat_template(
|
||||||
&self,
|
&self,
|
||||||
messages: &[ChatMessage],
|
messages: &[ChatMessage],
|
||||||
@@ -172,24 +163,6 @@ impl HuggingFaceTokenizer {
|
|||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Apply chat template if available (without minijinja feature)
|
|
||||||
#[cfg(not(feature = "minijinja"))]
|
|
||||||
pub fn apply_chat_template(
|
|
||||||
&self,
|
|
||||||
messages: &[ChatMessage],
|
|
||||||
add_generation_prompt: bool,
|
|
||||||
) -> Result<String> {
|
|
||||||
// Fallback to simple formatting
|
|
||||||
let mut result = String::new();
|
|
||||||
for msg in messages {
|
|
||||||
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
|
|
||||||
}
|
|
||||||
if add_generation_prompt {
|
|
||||||
result.push_str("assistant: ");
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encoder for HuggingFaceTokenizer {
|
impl Encoder for HuggingFaceTokenizer {
|
||||||
@@ -241,10 +214,8 @@ impl TokenizerTrait for HuggingFaceTokenizer {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
use super::ChatMessage;
|
use super::ChatMessage;
|
||||||
|
|
||||||
#[cfg(feature = "minijinja")]
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_chat_message_creation() {
|
fn test_chat_message_creation() {
|
||||||
let msg = ChatMessage::system("You are a helpful assistant");
|
let msg = ChatMessage::system("You are a helpful assistant");
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use std::ops::Deref;
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
|
pub mod hub;
|
||||||
pub mod mock;
|
pub mod mock;
|
||||||
pub mod sequence;
|
pub mod sequence;
|
||||||
pub mod stop;
|
pub mod stop;
|
||||||
@@ -10,13 +11,11 @@ pub mod stream;
|
|||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
// Feature-gated modules
|
// Feature-gated modules
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
pub mod chat_template;
|
pub mod chat_template;
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
pub mod huggingface;
|
pub mod huggingface;
|
||||||
|
|
||||||
#[cfg(feature = "tiktoken")]
|
|
||||||
pub mod tiktoken;
|
pub mod tiktoken;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -24,21 +23,18 @@ mod tests;
|
|||||||
|
|
||||||
// Re-exports
|
// Re-exports
|
||||||
pub use factory::{
|
pub use factory::{
|
||||||
create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
|
create_tokenizer, create_tokenizer_async, create_tokenizer_from_file,
|
||||||
TokenizerType,
|
create_tokenizer_with_chat_template, TokenizerType,
|
||||||
};
|
};
|
||||||
pub use sequence::Sequence;
|
pub use sequence::Sequence;
|
||||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||||
pub use stream::DecodeStream;
|
pub use stream::DecodeStream;
|
||||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
pub use huggingface::HuggingFaceTokenizer;
|
pub use huggingface::HuggingFaceTokenizer;
|
||||||
|
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
pub use chat_template::ChatMessage;
|
pub use chat_template::ChatMessage;
|
||||||
|
|
||||||
#[cfg(feature = "tiktoken")]
|
|
||||||
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||||
|
|
||||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ mod tests {
|
|||||||
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_chat_message_helpers() {
|
fn test_chat_message_helpers() {
|
||||||
let system_msg = ChatMessage::system("You are a helpful assistant");
|
let system_msg = ChatMessage::system("You are a helpful assistant");
|
||||||
assert_eq!(system_msg.role, "system");
|
assert_eq!(system_msg.role, "system");
|
||||||
@@ -19,7 +18,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_llama_style_template() {
|
fn test_llama_style_template() {
|
||||||
// Test a Llama-style chat template
|
// Test a Llama-style chat template
|
||||||
let template = r#"
|
let template = r#"
|
||||||
@@ -67,7 +65,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_chatml_template() {
|
fn test_chatml_template() {
|
||||||
// Test a ChatML-style template
|
// Test a ChatML-style template
|
||||||
let template = r#"
|
let template = r#"
|
||||||
@@ -97,7 +94,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_template_without_generation_prompt() {
|
fn test_template_without_generation_prompt() {
|
||||||
let template = r#"
|
let template = r#"
|
||||||
{%- for message in messages -%}
|
{%- for message in messages -%}
|
||||||
@@ -122,7 +118,6 @@ assistant:
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_template_with_special_tokens() {
|
fn test_template_with_special_tokens() {
|
||||||
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
|
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
|
||||||
|
|
||||||
@@ -139,7 +134,6 @@ assistant:
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_empty_messages() {
|
fn test_empty_messages() {
|
||||||
let template =
|
let template =
|
||||||
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ mod tests {
|
|||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_load_chat_template_from_file() {
|
fn test_load_chat_template_from_file() {
|
||||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
@@ -73,7 +72,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_override_existing_template() {
|
fn test_override_existing_template() {
|
||||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
@@ -136,7 +134,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "huggingface")]
|
|
||||||
fn test_set_chat_template_after_creation() {
|
fn test_set_chat_template_after_creation() {
|
||||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||||
|
|||||||
Reference in New Issue
Block a user