[router] add tokenizer download support from hf hub (#9882)

This commit is contained in:
Chang Su
2025-09-01 10:40:37 -07:00
committed by GitHub
parent b361750a4a
commit 598c0bc19d
9 changed files with 407 additions and 138 deletions

View File

@@ -5,15 +5,15 @@ use std::io::Read;
use std::path::Path;
use std::sync::Arc;
#[cfg(feature = "huggingface")]
use super::huggingface::HuggingFaceTokenizer;
use super::tiktoken::TiktokenTokenizer;
use crate::tokenizer::hub::download_tokenizer_from_hf;
/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
    /// HuggingFace tokenizer; the `String` payload is presumably a file path
    /// or model id — confirm against construction sites.
    HuggingFace(String),
    /// Mock tokenizer with no payload — NOTE(review): presumably used for
    /// tests; confirm usage.
    Mock,
    /// Tiktoken tokenizer; the `String` payload is presumably an OpenAI model
    /// name. Only available when the `tiktoken` feature is enabled.
    #[cfg(feature = "tiktoken")]
    Tiktoken(String),
    // Future: SentencePiece, GGUF
}
@@ -52,21 +52,10 @@ pub fn create_tokenizer_with_chat_template(
let result = match extension.as_deref() {
Some("json") => {
#[cfg(feature = "huggingface")]
{
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
file_path,
chat_template_path,
)?;
let tokenizer =
HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
}
#[cfg(not(feature = "huggingface"))]
{
Err(Error::msg(
"HuggingFace support not enabled. Enable the 'huggingface' feature.",
))
}
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
}
Some("model") => {
// SentencePiece model file
@@ -94,17 +83,8 @@ fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
// Check for JSON (HuggingFace format)
if is_likely_json(&buffer) {
#[cfg(feature = "huggingface")]
{
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
return Ok(Arc::new(tokenizer));
}
#[cfg(not(feature = "huggingface"))]
{
return Err(Error::msg(
"File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
));
}
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
return Ok(Arc::new(tokenizer));
}
// Check for GGUF magic number
@@ -154,7 +134,57 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
|| buffer.windows(4).any(|w| w == b"</s>"))
}
/// Factory function to create tokenizer from a model name or path (async version)
pub async fn create_tokenizer_async(
model_name_or_path: &str,
) -> Result<Arc<dyn traits::Tokenizer>> {
// Check if it's a file path
let path = Path::new(model_name_or_path);
if path.exists() {
return create_tokenizer_from_file(model_name_or_path);
}
// Check if it's a GPT model name that should use Tiktoken
if model_name_or_path.contains("gpt-")
|| model_name_or_path.contains("davinci")
|| model_name_or_path.contains("curie")
|| model_name_or_path.contains("babbage")
|| model_name_or_path.contains("ada")
{
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
return Ok(Arc::new(tokenizer));
}
// Try to download tokenizer files from HuggingFace
match download_tokenizer_from_hf(model_name_or_path).await {
Ok(cache_dir) => {
// Look for tokenizer.json in the cache directory
let tokenizer_path = cache_dir.join("tokenizer.json");
if tokenizer_path.exists() {
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
} else {
// Try other common tokenizer file names
let possible_files = ["tokenizer_config.json", "vocab.json"];
for file_name in &possible_files {
let file_path = cache_dir.join(file_name);
if file_path.exists() {
return create_tokenizer_from_file(file_path.to_str().unwrap());
}
}
Err(Error::msg(format!(
"Downloaded model '{}' but couldn't find a suitable tokenizer file",
model_name_or_path
)))
}
}
Err(e) => Err(Error::msg(format!(
"Failed to download tokenizer from HuggingFace: {}",
e
))),
}
}
/// Factory function to create tokenizer from a model name or path (blocking version)
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
// Check if it's a file path
let path = Path::new(model_name_or_path);
@@ -163,35 +193,25 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
}
// Check if it's a GPT model name that should use Tiktoken
#[cfg(feature = "tiktoken")]
if model_name_or_path.contains("gpt-")
|| model_name_or_path.contains("davinci")
|| model_name_or_path.contains("curie")
|| model_name_or_path.contains("babbage")
|| model_name_or_path.contains("ada")
{
if model_name_or_path.contains("gpt-")
|| model_name_or_path.contains("davinci")
|| model_name_or_path.contains("curie")
|| model_name_or_path.contains("babbage")
|| model_name_or_path.contains("ada")
{
use super::tiktoken::TiktokenTokenizer;
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
return Ok(Arc::new(tokenizer));
}
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
return Ok(Arc::new(tokenizer));
}
// Otherwise, try to load from HuggingFace Hub
#[cfg(feature = "huggingface")]
{
// This would download from HF Hub - not implemented yet
Err(Error::msg(
"Loading from HuggingFace Hub not yet implemented",
))
}
#[cfg(not(feature = "huggingface"))]
{
Err(Error::msg(format!(
"Model '{}' not found locally and HuggingFace support is not enabled",
model_name_or_path
)))
// Only use tokio for HuggingFace downloads
// Check if we're already in a tokio runtime
if let Ok(handle) = tokio::runtime::Handle::try_current() {
// We're in a runtime, use block_in_place
tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
} else {
// No runtime, create a temporary one
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(create_tokenizer_async(model_name_or_path))
}
}
@@ -257,7 +277,6 @@ mod tests {
}
}
#[cfg(feature = "tiktoken")]
#[test]
fn test_create_tiktoken_tokenizer() {
// Test creating tokenizer for GPT models
@@ -270,4 +289,30 @@ mod tests {
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
assert_eq!(decoded, text);
}
#[tokio::test]
async fn test_download_tokenizer_from_hf() {
    // Skip when running in CI without an HF token: unauthenticated downloads
    // are rate-limited and would make the test flaky.
    let in_ci = std::env::var("CI").is_ok();
    let has_token = std::env::var("HF_TOKEN").is_ok();
    if in_ci && !has_token {
        println!("Skipping HF download test in CI without HF_TOKEN");
        return;
    }

    // Attempt to build a tokenizer for a small, well-known model. A failure
    // is tolerated (network issues, rate limiting); we only require that the
    // call completes without panicking.
    match create_tokenizer_async("bert-base-uncased").await {
        Ok(tokenizer) => {
            assert!(tokenizer.vocab_size() > 0);
            println!("Successfully downloaded and created tokenizer");
        }
        Err(e) => {
            println!("Download failed (this might be expected): {}", e);
            // Don't fail the test - network issues shouldn't break CI
        }
    }
}
}