[router] add tokenizer download support from hf hub (#9882)
This commit is contained in:
@@ -5,15 +5,15 @@ use std::io::Read;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "huggingface")]
|
||||
use super::huggingface::HuggingFaceTokenizer;
|
||||
use super::tiktoken::TiktokenTokenizer;
|
||||
use crate::tokenizer::hub::download_tokenizer_from_hf;
|
||||
|
||||
/// Represents the type of tokenizer being used
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TokenizerType {
|
||||
HuggingFace(String),
|
||||
Mock,
|
||||
#[cfg(feature = "tiktoken")]
|
||||
Tiktoken(String),
|
||||
// Future: SentencePiece, GGUF
|
||||
}
|
||||
@@ -52,21 +52,10 @@ pub fn create_tokenizer_with_chat_template(
|
||||
|
||||
let result = match extension.as_deref() {
|
||||
Some("json") => {
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||
file_path,
|
||||
chat_template_path,
|
||||
)?;
|
||||
let tokenizer =
|
||||
HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
|
||||
|
||||
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
|
||||
}
|
||||
#[cfg(not(feature = "huggingface"))]
|
||||
{
|
||||
Err(Error::msg(
|
||||
"HuggingFace support not enabled. Enable the 'huggingface' feature.",
|
||||
))
|
||||
}
|
||||
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
|
||||
}
|
||||
Some("model") => {
|
||||
// SentencePiece model file
|
||||
@@ -94,17 +83,8 @@ fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>>
|
||||
|
||||
// Check for JSON (HuggingFace format)
|
||||
if is_likely_json(&buffer) {
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
#[cfg(not(feature = "huggingface"))]
|
||||
{
|
||||
return Err(Error::msg(
|
||||
"File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
|
||||
));
|
||||
}
|
||||
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
|
||||
// Check for GGUF magic number
|
||||
@@ -154,7 +134,57 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
|
||||
|| buffer.windows(4).any(|w| w == b"</s>"))
|
||||
}
|
||||
|
||||
/// Factory function to create tokenizer from a model name or path
|
||||
/// Factory function to create tokenizer from a model name or path (async version)
|
||||
pub async fn create_tokenizer_async(
|
||||
model_name_or_path: &str,
|
||||
) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
// Check if it's a file path
|
||||
let path = Path::new(model_name_or_path);
|
||||
if path.exists() {
|
||||
return create_tokenizer_from_file(model_name_or_path);
|
||||
}
|
||||
|
||||
// Check if it's a GPT model name that should use Tiktoken
|
||||
if model_name_or_path.contains("gpt-")
|
||||
|| model_name_or_path.contains("davinci")
|
||||
|| model_name_or_path.contains("curie")
|
||||
|| model_name_or_path.contains("babbage")
|
||||
|| model_name_or_path.contains("ada")
|
||||
{
|
||||
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
|
||||
// Try to download tokenizer files from HuggingFace
|
||||
match download_tokenizer_from_hf(model_name_or_path).await {
|
||||
Ok(cache_dir) => {
|
||||
// Look for tokenizer.json in the cache directory
|
||||
let tokenizer_path = cache_dir.join("tokenizer.json");
|
||||
if tokenizer_path.exists() {
|
||||
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
|
||||
} else {
|
||||
// Try other common tokenizer file names
|
||||
let possible_files = ["tokenizer_config.json", "vocab.json"];
|
||||
for file_name in &possible_files {
|
||||
let file_path = cache_dir.join(file_name);
|
||||
if file_path.exists() {
|
||||
return create_tokenizer_from_file(file_path.to_str().unwrap());
|
||||
}
|
||||
}
|
||||
Err(Error::msg(format!(
|
||||
"Downloaded model '{}' but couldn't find a suitable tokenizer file",
|
||||
model_name_or_path
|
||||
)))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(Error::msg(format!(
|
||||
"Failed to download tokenizer from HuggingFace: {}",
|
||||
e
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory function to create tokenizer from a model name or path (blocking version)
|
||||
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
|
||||
// Check if it's a file path
|
||||
let path = Path::new(model_name_or_path);
|
||||
@@ -163,35 +193,25 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
|
||||
}
|
||||
|
||||
// Check if it's a GPT model name that should use Tiktoken
|
||||
#[cfg(feature = "tiktoken")]
|
||||
if model_name_or_path.contains("gpt-")
|
||||
|| model_name_or_path.contains("davinci")
|
||||
|| model_name_or_path.contains("curie")
|
||||
|| model_name_or_path.contains("babbage")
|
||||
|| model_name_or_path.contains("ada")
|
||||
{
|
||||
if model_name_or_path.contains("gpt-")
|
||||
|| model_name_or_path.contains("davinci")
|
||||
|| model_name_or_path.contains("curie")
|
||||
|| model_name_or_path.contains("babbage")
|
||||
|| model_name_or_path.contains("ada")
|
||||
{
|
||||
use super::tiktoken::TiktokenTokenizer;
|
||||
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||
return Ok(Arc::new(tokenizer));
|
||||
}
|
||||
|
||||
// Otherwise, try to load from HuggingFace Hub
|
||||
#[cfg(feature = "huggingface")]
|
||||
{
|
||||
// This would download from HF Hub - not implemented yet
|
||||
Err(Error::msg(
|
||||
"Loading from HuggingFace Hub not yet implemented",
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "huggingface"))]
|
||||
{
|
||||
Err(Error::msg(format!(
|
||||
"Model '{}' not found locally and HuggingFace support is not enabled",
|
||||
model_name_or_path
|
||||
)))
|
||||
// Only use tokio for HuggingFace downloads
|
||||
// Check if we're already in a tokio runtime
|
||||
if let Ok(handle) = tokio::runtime::Handle::try_current() {
|
||||
// We're in a runtime, use block_in_place
|
||||
tokio::task::block_in_place(|| handle.block_on(create_tokenizer_async(model_name_or_path)))
|
||||
} else {
|
||||
// No runtime, create a temporary one
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(create_tokenizer_async(model_name_or_path))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +277,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tiktoken")]
|
||||
#[test]
|
||||
fn test_create_tiktoken_tokenizer() {
|
||||
// Test creating tokenizer for GPT models
|
||||
@@ -270,4 +289,30 @@ mod tests {
|
||||
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
|
||||
assert_eq!(decoded, text);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_download_tokenizer_from_hf() {
|
||||
// Test with a small model that should have tokenizer files
|
||||
// Skip this test if HF_TOKEN is not set and we're in CI
|
||||
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
|
||||
println!("Skipping HF download test in CI without HF_TOKEN");
|
||||
return;
|
||||
}
|
||||
|
||||
// Try to create tokenizer for a known small model
|
||||
let result = create_tokenizer_async("bert-base-uncased").await;
|
||||
|
||||
// The test might fail due to network issues or rate limiting
|
||||
// so we just check that the function executes without panic
|
||||
match result {
|
||||
Ok(tokenizer) => {
|
||||
assert!(tokenizer.vocab_size() > 0);
|
||||
println!("Successfully downloaded and created tokenizer");
|
||||
}
|
||||
Err(e) => {
|
||||
println!("Download failed (this might be expected): {}", e);
|
||||
// Don't fail the test - network issues shouldn't break CI
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user