[router] add tiktokenizer and sequence in router (#9354)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -6,6 +6,7 @@ edition = "2021"
|
|||||||
[features]
|
[features]
|
||||||
default = ["huggingface"]
|
default = ["huggingface"]
|
||||||
huggingface = ["tokenizers"]
|
huggingface = ["tokenizers"]
|
||||||
|
tiktoken = ["tiktoken-rs"]
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
name = "sglang_router_rs"
|
name = "sglang_router_rs"
|
||||||
@@ -49,6 +50,7 @@ url = "2.5.4"
|
|||||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
tokenizers = { version = "0.21.4", optional = true }
|
tokenizers = { version = "0.21.4", optional = true }
|
||||||
|
tiktoken-rs = { version = "0.5", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = { version = "0.5", features = ["html_reports"] }
|
criterion = { version = "0.5", features = ["html_reports"] }
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
use super::{traits, TokenizerTrait};
|
use super::traits::{self, Tokenizer as TokenizerTrait};
|
||||||
use crate::metrics::TokenizerMetrics;
|
use crate::metrics::TokenizerMetrics;
|
||||||
use anyhow::{Error, Result};
|
use anyhow::{Error, Result};
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
@@ -15,7 +15,9 @@ use super::huggingface::HuggingFaceTokenizer;
|
|||||||
pub enum TokenizerType {
|
pub enum TokenizerType {
|
||||||
HuggingFace(String),
|
HuggingFace(String),
|
||||||
Mock,
|
Mock,
|
||||||
// Future: SentencePiece, GGUF, Tiktoken
|
#[cfg(feature = "tiktoken")]
|
||||||
|
Tiktoken(String),
|
||||||
|
// Future: SentencePiece, GGUF
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a tokenizer from a file path to a tokenizer file.
|
/// Create a tokenizer from a file path to a tokenizer file.
|
||||||
@@ -166,6 +168,23 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
|
|||||||
return create_tokenizer_from_file(model_name_or_path);
|
return create_tokenizer_from_file(model_name_or_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if it's a GPT model name that should use Tiktoken
|
||||||
|
#[cfg(feature = "tiktoken")]
|
||||||
|
{
|
||||||
|
if model_name_or_path.contains("gpt-")
|
||||||
|
|| model_name_or_path.contains("davinci")
|
||||||
|
|| model_name_or_path.contains("curie")
|
||||||
|
|| model_name_or_path.contains("babbage")
|
||||||
|
|| model_name_or_path.contains("ada")
|
||||||
|
{
|
||||||
|
use super::tiktoken::TiktokenTokenizer;
|
||||||
|
let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
|
||||||
|
TokenizerMetrics::record_factory_load("tiktoken");
|
||||||
|
TokenizerMetrics::set_vocab_size("tiktoken", tokenizer.vocab_size());
|
||||||
|
return Ok(Arc::new(tokenizer));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Otherwise, try to load from HuggingFace Hub
|
// Otherwise, try to load from HuggingFace Hub
|
||||||
#[cfg(feature = "huggingface")]
|
#[cfg(feature = "huggingface")]
|
||||||
{
|
{
|
||||||
@@ -245,4 +264,18 @@ mod tests {
|
|||||||
assert!(e.to_string().contains("File not found"));
|
assert!(e.to_string().contains("File not found"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "tiktoken")]
|
||||||
|
#[test]
|
||||||
|
fn test_create_tiktoken_tokenizer() {
|
||||||
|
// Test creating tokenizer for GPT models
|
||||||
|
let tokenizer = create_tokenizer("gpt-4").unwrap();
|
||||||
|
assert!(tokenizer.vocab_size() > 0);
|
||||||
|
|
||||||
|
// Test encoding and decoding
|
||||||
|
let text = "Hello, world!";
|
||||||
|
let encoding = tokenizer.encode(text).unwrap();
|
||||||
|
let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
|
||||||
|
assert_eq!(decoded, text);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ use std::sync::Arc;
|
|||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod mock;
|
pub mod mock;
|
||||||
|
pub mod sequence;
|
||||||
pub mod stop;
|
pub mod stop;
|
||||||
pub mod stream;
|
pub mod stream;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
@@ -12,11 +13,15 @@ pub mod traits;
|
|||||||
#[cfg(feature = "huggingface")]
|
#[cfg(feature = "huggingface")]
|
||||||
pub mod huggingface;
|
pub mod huggingface;
|
||||||
|
|
||||||
|
#[cfg(feature = "tiktoken")]
|
||||||
|
pub mod tiktoken;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
// Re-exports
|
// Re-exports
|
||||||
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
|
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
|
||||||
|
pub use sequence::Sequence;
|
||||||
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
|
||||||
pub use stream::DecodeStream;
|
pub use stream::DecodeStream;
|
||||||
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||||
@@ -24,6 +29,9 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
|
|||||||
#[cfg(feature = "huggingface")]
|
#[cfg(feature = "huggingface")]
|
||||||
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
|
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
|
||||||
|
|
||||||
|
#[cfg(feature = "tiktoken")]
|
||||||
|
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
|
||||||
|
|
||||||
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
|
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
|
||||||
|
|||||||
238
sgl-router/src/tokenizer/sequence.rs
Normal file
238
sgl-router/src/tokenizer/sequence.rs
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
use super::traits::Tokenizer as TokenizerTrait;
|
||||||
|
use anyhow::Result;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Maintains state for an ongoing sequence of tokens and their decoded text
|
||||||
|
/// This provides a cleaner abstraction for managing token sequences
|
||||||
|
pub struct Sequence {
|
||||||
|
/// The tokenizer used for encoding/decoding
|
||||||
|
tokenizer: Arc<dyn TokenizerTrait>,
|
||||||
|
|
||||||
|
/// The current sequence of token ids
|
||||||
|
token_ids: Vec<u32>,
|
||||||
|
|
||||||
|
/// The position in the current sequence the last decoded token completed
|
||||||
|
prefix_offset: usize,
|
||||||
|
|
||||||
|
/// Current position in the sequence
|
||||||
|
read_offset: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Sequence {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("Sequence")
|
||||||
|
.field("tokenizer", &"Arc<dyn Tokenizer>")
|
||||||
|
.field(
|
||||||
|
"token_ids",
|
||||||
|
&format_args!("{}", {
|
||||||
|
let token_ids = self.token_ids();
|
||||||
|
if token_ids.len() <= 20 {
|
||||||
|
format!("{:?}", token_ids)
|
||||||
|
} else {
|
||||||
|
let first_ten = &token_ids[..10];
|
||||||
|
let last_ten = &token_ids[token_ids.len() - 10..];
|
||||||
|
format!("{:?} ... {:?}", first_ten, last_ten)
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.field("prefix_offset", &self.prefix_offset)
|
||||||
|
.field("read_offset", &self.read_offset)
|
||||||
|
.field("token count", &self.token_ids.len())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sequence {
|
||||||
|
/// Create a new empty sequence
|
||||||
|
pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
|
||||||
|
Self {
|
||||||
|
tokenizer,
|
||||||
|
token_ids: Vec::new(),
|
||||||
|
prefix_offset: 0,
|
||||||
|
read_offset: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a sequence with initial tokens
|
||||||
|
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
|
||||||
|
let len = token_ids.len();
|
||||||
|
Self {
|
||||||
|
tokenizer,
|
||||||
|
token_ids,
|
||||||
|
prefix_offset: 0,
|
||||||
|
read_offset: len,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the sequence is empty
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.token_ids.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the length of the sequence
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.token_ids.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clear the sequence
|
||||||
|
pub fn clear(&mut self) {
|
||||||
|
self.token_ids.clear();
|
||||||
|
self.prefix_offset = 0;
|
||||||
|
self.read_offset = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append text to the sequence by encoding it
|
||||||
|
pub fn append_text(&mut self, input: &str) -> Result<()> {
|
||||||
|
let encoding = self.tokenizer.encode(input)?;
|
||||||
|
self.token_ids.extend(encoding.token_ids());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append a single token to the sequence and return newly decoded text
|
||||||
|
/// Based on HuggingFace TGI incremental decoding
|
||||||
|
pub fn append_token(&mut self, token_id: u32) -> Result<String> {
|
||||||
|
// Store the old read offset before adding the new token
|
||||||
|
let old_read_offset = self.read_offset;
|
||||||
|
|
||||||
|
self.token_ids.push(token_id);
|
||||||
|
self.read_offset = self.token_ids.len();
|
||||||
|
|
||||||
|
// If this is the first token or we're at the beginning, decode everything
|
||||||
|
if self.prefix_offset == 0 && old_read_offset == 0 {
|
||||||
|
let text = self.tokenizer.decode(&self.token_ids, false)?;
|
||||||
|
if text.ends_with("<EFBFBD>") {
|
||||||
|
// Incomplete UTF-8 sequence, wait for more tokens
|
||||||
|
return Ok(String::new());
|
||||||
|
}
|
||||||
|
self.prefix_offset = 0;
|
||||||
|
return Ok(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the text up to the previous position
|
||||||
|
let prefix_text = self
|
||||||
|
.tokenizer
|
||||||
|
.decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
|
||||||
|
|
||||||
|
// Decode the text including the new token
|
||||||
|
let new_text = self
|
||||||
|
.tokenizer
|
||||||
|
.decode(&self.token_ids[self.prefix_offset..], false)?;
|
||||||
|
|
||||||
|
// Handle multi-byte character boundaries
|
||||||
|
let mut prefix_text_len = prefix_text.len();
|
||||||
|
while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
|
||||||
|
prefix_text_len -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if new_text.len() > prefix_text.len() {
|
||||||
|
if new_text.ends_with("<EFBFBD>") {
|
||||||
|
// Incomplete UTF-8 sequence, wait for more tokens
|
||||||
|
return Ok(String::new());
|
||||||
|
} else {
|
||||||
|
// Return the new text portion
|
||||||
|
let incremental_text = new_text[prefix_text_len..].to_string().replace("<EFBFBD>", "");
|
||||||
|
self.prefix_offset = old_read_offset;
|
||||||
|
return Ok(incremental_text);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(String::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a reference to the tokenizer
|
||||||
|
pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
|
||||||
|
&self.tokenizer
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the current token ids
|
||||||
|
pub fn token_ids(&self) -> &[u32] {
|
||||||
|
&self.token_ids
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode the entire sequence to text
|
||||||
|
pub fn text(&self) -> Result<String> {
|
||||||
|
self.tokenizer.decode(&self.token_ids, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the prefix offset
|
||||||
|
pub fn prefix_offset(&self) -> usize {
|
||||||
|
self.prefix_offset
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the read offset
|
||||||
|
pub fn read_offset(&self) -> usize {
|
||||||
|
self.read_offset
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::tokenizer::mock::MockTokenizer;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sequence_new() {
|
||||||
|
let tokenizer = Arc::new(MockTokenizer::new());
|
||||||
|
let seq = Sequence::new(tokenizer);
|
||||||
|
assert!(seq.is_empty());
|
||||||
|
assert_eq!(seq.len(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sequence_append_text() {
|
||||||
|
let tokenizer = Arc::new(MockTokenizer::new());
|
||||||
|
let mut seq = Sequence::new(tokenizer);
|
||||||
|
|
||||||
|
seq.append_text("Hello").unwrap();
|
||||||
|
assert!(!seq.is_empty());
|
||||||
|
assert!(!seq.is_empty());
|
||||||
|
|
||||||
|
let text = seq.text().unwrap();
|
||||||
|
assert_eq!(text, "Hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sequence_append_token() {
|
||||||
|
let tokenizer = Arc::new(MockTokenizer::new());
|
||||||
|
let mut seq = Sequence::new(tokenizer.clone());
|
||||||
|
|
||||||
|
// Start with an empty sequence and append token 1 ("Hello")
|
||||||
|
let text1 = seq.append_token(1).unwrap();
|
||||||
|
assert_eq!(text1, "Hello");
|
||||||
|
|
||||||
|
// Now append token 2 ("world")
|
||||||
|
// The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
|
||||||
|
let text2 = seq.append_token(2).unwrap();
|
||||||
|
// The incremental text should be " world" (with the space that the mock tokenizer adds)
|
||||||
|
assert_eq!(text2, " world");
|
||||||
|
|
||||||
|
// Verify the full text
|
||||||
|
assert_eq!(seq.text().unwrap(), "Hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sequence_clear() {
|
||||||
|
let tokenizer = Arc::new(MockTokenizer::new());
|
||||||
|
let mut seq = Sequence::new(tokenizer);
|
||||||
|
|
||||||
|
seq.append_text("Hello world").unwrap();
|
||||||
|
assert!(!seq.is_empty());
|
||||||
|
|
||||||
|
seq.clear();
|
||||||
|
assert!(seq.is_empty());
|
||||||
|
assert_eq!(seq.len(), 0);
|
||||||
|
assert_eq!(seq.prefix_offset(), 0);
|
||||||
|
assert_eq!(seq.read_offset(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sequence_debug() {
|
||||||
|
let tokenizer = Arc::new(MockTokenizer::new());
|
||||||
|
let mut seq = Sequence::new(tokenizer);
|
||||||
|
|
||||||
|
seq.append_text("Test").unwrap();
|
||||||
|
let debug_str = format!("{:?}", seq);
|
||||||
|
assert!(debug_str.contains("Sequence"));
|
||||||
|
assert!(debug_str.contains("token count"));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -129,7 +129,9 @@ fn test_thread_safety() {
|
|||||||
thread::spawn(move || {
|
thread::spawn(move || {
|
||||||
let text = "Hello test".to_string();
|
let text = "Hello test".to_string();
|
||||||
let encoding = tokenizer_clone.encode(&text).unwrap();
|
let encoding = tokenizer_clone.encode(&text).unwrap();
|
||||||
let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
|
let decoded = tokenizer_clone
|
||||||
|
.decode(&encoding.token_ids(), false)
|
||||||
|
.unwrap();
|
||||||
assert!(decoded.contains("Hello") || decoded.contains("test"));
|
assert!(decoded.contains("Hello") || decoded.contains("test"));
|
||||||
i
|
i
|
||||||
})
|
})
|
||||||
|
|||||||
276
sgl-router/src/tokenizer/tiktoken.rs
Normal file
276
sgl-router/src/tokenizer/tiktoken.rs
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
|
||||||
|
use anyhow::{Error, Result};
|
||||||
|
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
|
||||||
|
|
||||||
|
/// Tiktoken tokenizer wrapper for OpenAI GPT models
|
||||||
|
pub struct TiktokenTokenizer {
|
||||||
|
tokenizer: CoreBPE,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
model: TiktokenModel,
|
||||||
|
special_tokens: SpecialTokens,
|
||||||
|
vocab_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Supported Tiktoken models
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub enum TiktokenModel {
|
||||||
|
/// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
|
||||||
|
Cl100kBase,
|
||||||
|
/// Codex models, text-davinci-002, text-davinci-003
|
||||||
|
P50kBase,
|
||||||
|
/// Use for edit models like text-davinci-edit-001, code-davinci-edit-001
|
||||||
|
P50kEdit,
|
||||||
|
/// GPT-3 models like davinci
|
||||||
|
R50kBase,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TiktokenTokenizer {
|
||||||
|
/// Create a new Tiktoken tokenizer for the specified model
|
||||||
|
pub fn new(model: TiktokenModel) -> Result<Self> {
|
||||||
|
let tokenizer =
|
||||||
|
match model {
|
||||||
|
TiktokenModel::Cl100kBase => cl100k_base()
|
||||||
|
.map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
|
||||||
|
TiktokenModel::P50kBase => p50k_base()
|
||||||
|
.map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
|
||||||
|
TiktokenModel::P50kEdit => p50k_edit()
|
||||||
|
.map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
|
||||||
|
TiktokenModel::R50kBase => r50k_base()
|
||||||
|
.map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract special tokens (tiktoken-rs doesn't expose them directly)
|
||||||
|
// We'll use common ones for GPT models
|
||||||
|
let special_tokens = Self::get_special_tokens_for_model(model);
|
||||||
|
|
||||||
|
// Get vocabulary size (this is an approximation)
|
||||||
|
let vocab_size = match model {
|
||||||
|
TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
|
||||||
|
TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
|
||||||
|
TiktokenModel::R50kBase => 50257, // r50k has ~50k tokens
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(TiktokenTokenizer {
|
||||||
|
tokenizer,
|
||||||
|
model,
|
||||||
|
special_tokens,
|
||||||
|
vocab_size,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
|
||||||
|
pub fn from_model_name(model_name: &str) -> Result<Self> {
|
||||||
|
let model = Self::model_from_name(model_name)?;
|
||||||
|
Self::new(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determine the appropriate model from a model name
|
||||||
|
fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
|
||||||
|
// Based on OpenAI's model-to-encoding mapping
|
||||||
|
if model_name.contains("gpt-4")
|
||||||
|
|| model_name.contains("gpt-3.5")
|
||||||
|
|| model_name.contains("turbo")
|
||||||
|
{
|
||||||
|
Ok(TiktokenModel::Cl100kBase)
|
||||||
|
} else if model_name.contains("davinci-002")
|
||||||
|
|| model_name.contains("davinci-003")
|
||||||
|
|| model_name.contains("codex")
|
||||||
|
{
|
||||||
|
Ok(TiktokenModel::P50kBase)
|
||||||
|
} else if model_name.contains("edit") {
|
||||||
|
Ok(TiktokenModel::P50kEdit)
|
||||||
|
} else if model_name.contains("davinci")
|
||||||
|
|| model_name.contains("curie")
|
||||||
|
|| model_name.contains("babbage")
|
||||||
|
|| model_name.contains("ada")
|
||||||
|
{
|
||||||
|
Ok(TiktokenModel::R50kBase)
|
||||||
|
} else {
|
||||||
|
// Return an error for unrecognized model names to prevent silent failures
|
||||||
|
Err(anyhow::anyhow!(
|
||||||
|
"Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
|
||||||
|
model_name
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get special tokens for a specific model
|
||||||
|
fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
|
||||||
|
// These are common special tokens for GPT models
|
||||||
|
// The actual token IDs might vary by model
|
||||||
|
match model {
|
||||||
|
TiktokenModel::Cl100kBase => SpecialTokens {
|
||||||
|
bos_token: Some("<|endoftext|>".to_string()),
|
||||||
|
eos_token: Some("<|endoftext|>".to_string()),
|
||||||
|
unk_token: None,
|
||||||
|
sep_token: None,
|
||||||
|
pad_token: Some("<|endoftext|>".to_string()),
|
||||||
|
cls_token: None,
|
||||||
|
mask_token: None,
|
||||||
|
additional_special_tokens: vec![
|
||||||
|
"<|fim_prefix|>".to_string(),
|
||||||
|
"<|fim_middle|>".to_string(),
|
||||||
|
"<|fim_suffix|>".to_string(),
|
||||||
|
"<|endofprompt|>".to_string(),
|
||||||
|
],
|
||||||
|
},
|
||||||
|
_ => SpecialTokens {
|
||||||
|
bos_token: Some("<|endoftext|>".to_string()),
|
||||||
|
eos_token: Some("<|endoftext|>".to_string()),
|
||||||
|
unk_token: None,
|
||||||
|
sep_token: None,
|
||||||
|
pad_token: Some("<|endoftext|>".to_string()),
|
||||||
|
cls_token: None,
|
||||||
|
mask_token: None,
|
||||||
|
additional_special_tokens: vec![],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder for TiktokenTokenizer {
|
||||||
|
fn encode(&self, input: &str) -> Result<Encoding> {
|
||||||
|
let tokens = self.tokenizer.encode_ordinary(input);
|
||||||
|
Ok(Encoding::Tiktoken(tokens))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
|
||||||
|
inputs.iter().map(|input| self.encode(input)).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Decoder for TiktokenTokenizer {
|
||||||
|
fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
|
||||||
|
// Convert u32 to usize for tiktoken-rs
|
||||||
|
let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
|
||||||
|
|
||||||
|
self.tokenizer
|
||||||
|
.decode(tokens)
|
||||||
|
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TokenizerTrait for TiktokenTokenizer {
|
||||||
|
fn vocab_size(&self) -> usize {
|
||||||
|
self.vocab_size
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_special_tokens(&self) -> &SpecialTokens {
|
||||||
|
&self.special_tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
fn token_to_id(&self, _token: &str) -> Option<u32> {
|
||||||
|
// Tiktoken doesn't provide direct token-to-id mapping
|
||||||
|
// We'd need to encode the token and check if it produces a single ID
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn id_to_token(&self, _id: u32) -> Option<String> {
|
||||||
|
// Tiktoken doesn't provide direct id-to-token mapping
|
||||||
|
// We can only decode IDs to text
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tiktoken_creation() {
|
||||||
|
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||||
|
assert_eq!(tokenizer.vocab_size(), 100256);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_model_from_name() {
|
||||||
|
assert!(matches!(
|
||||||
|
TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
|
||||||
|
TiktokenModel::Cl100kBase
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
|
||||||
|
TiktokenModel::Cl100kBase
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
|
||||||
|
TiktokenModel::P50kBase
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
|
||||||
|
TiktokenModel::P50kEdit
|
||||||
|
));
|
||||||
|
assert!(matches!(
|
||||||
|
TiktokenTokenizer::model_from_name("davinci").unwrap(),
|
||||||
|
TiktokenModel::R50kBase
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_encode_decode() {
|
||||||
|
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||||
|
|
||||||
|
let text = "Hello, world!";
|
||||||
|
let encoding = tokenizer.encode(text).unwrap();
|
||||||
|
|
||||||
|
let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
|
||||||
|
assert_eq!(decoded, text);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_batch_encode() {
|
||||||
|
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||||
|
|
||||||
|
let texts = vec!["Hello", "World", "Test"];
|
||||||
|
let encodings = tokenizer.encode_batch(&texts).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(encodings.len(), 3);
|
||||||
|
for (i, encoding) in encodings.iter().enumerate() {
|
||||||
|
let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
|
||||||
|
assert_eq!(decoded, texts[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_special_tokens() {
|
||||||
|
let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
|
||||||
|
let special_tokens = tokenizer.get_special_tokens();
|
||||||
|
|
||||||
|
assert!(special_tokens.eos_token.is_some());
|
||||||
|
assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unrecognized_model_name_returns_error() {
|
||||||
|
// Test that unrecognized model names return an error
|
||||||
|
let result = TiktokenTokenizer::from_model_name("distilgpt-2");
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(e) = result {
|
||||||
|
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(e) = result {
|
||||||
|
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = TiktokenTokenizer::from_model_name("llama-7b");
|
||||||
|
assert!(result.is_err());
|
||||||
|
if let Err(e) = result {
|
||||||
|
assert!(e.to_string().contains("Unrecognized OpenAI model name"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_recognized_model_names() {
|
||||||
|
// Test that recognized model names work correctly
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
|
||||||
|
assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,13 +26,28 @@ pub enum Encoding {
|
|||||||
Hf(Box<tokenizers::tokenizer::Encoding>),
|
Hf(Box<tokenizers::tokenizer::Encoding>),
|
||||||
/// Sentence Piece
|
/// Sentence Piece
|
||||||
Sp(Vec<u32>),
|
Sp(Vec<u32>),
|
||||||
|
/// Tiktoken (for GPT models)
|
||||||
|
Tiktoken(Vec<usize>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Encoding {
|
impl Encoding {
|
||||||
pub fn token_ids(&self) -> &[u32] {
|
pub fn token_ids(&self) -> Vec<u32> {
|
||||||
|
match self {
|
||||||
|
Encoding::Hf(inner) => inner.get_ids().to_vec(),
|
||||||
|
Encoding::Sp(inner) => inner.clone(),
|
||||||
|
Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn token_ids_ref(&self) -> &[u32] {
|
||||||
match self {
|
match self {
|
||||||
Encoding::Hf(inner) => inner.get_ids(),
|
Encoding::Hf(inner) => inner.get_ids(),
|
||||||
Encoding::Sp(inner) => inner,
|
Encoding::Sp(inner) => inner,
|
||||||
|
Encoding::Tiktoken(_) => {
|
||||||
|
// Tiktoken uses usize, we can't return a reference to u32
|
||||||
|
// This is a limitation - callers should use token_ids() for Tiktoken
|
||||||
|
&[]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user