[router] add tiktokenizer and sequence in router (#9354)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
@@ -1,4 +1,4 @@
-use super::{traits, TokenizerTrait};
+use super::traits::{self, Tokenizer as TokenizerTrait};
+use crate::metrics::TokenizerMetrics;
 use anyhow::{Error, Result};
 use std::fs::File;
@@ -15,7 +15,9 @@ use super::huggingface::HuggingFaceTokenizer;
 pub enum TokenizerType {
     HuggingFace(String),
     Mock,
-    // Future: SentencePiece, GGUF, Tiktoken
+    #[cfg(feature = "tiktoken")]
+    Tiktoken(String),
+    // Future: SentencePiece, GGUF
 }

 /// Create a tokenizer from a file path to a tokenizer file.
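For orientation, here is a minimal sketch of how downstream code might branch on the new feature-gated variant; the enum shape is taken from the hunk above, while describe() and its return values are hypothetical:

// Hypothetical helper, not part of this commit: the #[cfg] attribute on
// the match arm compiles the arm out together with the gated variant.
fn describe(t: &TokenizerType) -> &'static str {
    match t {
        TokenizerType::HuggingFace(_) => "huggingface",
        TokenizerType::Mock => "mock",
        #[cfg(feature = "tiktoken")]
        TokenizerType::Tiktoken(_) => "tiktoken",
    }
}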
@@ -166,6 +168,23 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
         return create_tokenizer_from_file(model_name_or_path);
     }

+    // Check if it's a GPT model name that should use Tiktoken
+    #[cfg(feature = "tiktoken")]
+    {
+        if model_name_or_path.contains("gpt-")
+            || model_name_or_path.contains("davinci")
+            || model_name_or_path.contains("curie")
+            || model_name_or_path.contains("babbage")
+            || model_name_or_path.contains("ada")
+        {
+            use super::tiktoken::TiktokenTokenizer;
+            let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
+            TokenizerMetrics::record_factory_load("tiktoken");
+            TokenizerMetrics::set_vocab_size("tiktoken", tokenizer.vocab_size());
+            return Ok(Arc::new(tokenizer));
+        }
+    }
+
     // Otherwise, try to load from HuggingFace Hub
     #[cfg(feature = "huggingface")]
     {
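A minimal usage sketch of the new dispatch path, assuming a build with --features tiktoken; create_tokenizer, encode, token_ids, and vocab_size all appear in this diff, while demo() itself is illustrative:

// Illustrative only: "gpt-4" contains "gpt-", so create_tokenizer takes the
// tiktoken branch above instead of falling through to the HuggingFace Hub path.
fn demo() -> anyhow::Result<()> {
    let tokenizer = create_tokenizer("gpt-4")?;
    let encoding = tokenizer.encode("Hello, world!")?;
    println!(
        "{} tokens, vocab size {}",
        encoding.token_ids().len(),
        tokenizer.vocab_size()
    );
    Ok(())
}

Worth noting as a design caveat: the checks are plain substring matches, so any model id containing "ada" or "curie" would be routed to tiktoken once the feature is enabled.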
@@ -245,4 +264,18 @@ mod tests {
             assert!(e.to_string().contains("File not found"));
         }
     }
+
+    #[cfg(feature = "tiktoken")]
+    #[test]
+    fn test_create_tiktoken_tokenizer() {
+        // Test creating tokenizer for GPT models
+        let tokenizer = create_tokenizer("gpt-4").unwrap();
+        assert!(tokenizer.vocab_size() > 0);
+
+        // Test encoding and decoding
+        let text = "Hello, world!";
+        let encoding = tokenizer.encode(text).unwrap();
+        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+        assert_eq!(decoded, text);
+    }
 }
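Because the test carries the same #[cfg(feature = "tiktoken")] gate as the variant and the factory branch, it is skipped entirely on default builds; something like cargo test --features tiktoken from the router crate should exercise it, and the encode/decode roundtrip assertion doubles as a smoke test for the new tokenizer.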