[router] fix chat template loading and tokenizer path (#10999)

This commit is contained in:
Simo Lin
2025-09-27 23:54:12 -04:00
committed by GitHub
parent 72392f2908
commit 5519766a4d
5 changed files with 418 additions and 9 deletions

View File

@@ -134,6 +134,36 @@ fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
|| buffer.windows(4).any(|w| w == b"</s>"))
}
/// Helper function to discover chat template files in a directory.
///
/// Candidates are checked in priority order:
/// 1. `chat_template.json` (Jinja template stored inside JSON)
/// 2. `chat_template.jinja` (standard Jinja file)
/// 3. any other regular file whose name ends with `.jinja`; when several
///    exist, the lexicographically smallest path is chosen so the result does
///    not depend on the platform's directory iteration order.
///
/// Only regular files (after following symlinks) are considered, so a
/// directory that happens to carry one of these names is never returned.
///
/// Returns the path of the selected template as a `String`, or `None` when
/// the directory cannot be read, holds no template, or the selected path is
/// not valid UTF-8.
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
    use std::fs;

    // Priorities 1 and 2: the well-known file names, in order of preference.
    for name in &["chat_template.json", "chat_template.jinja"] {
        let candidate = dir.join(name);
        if candidate.is_file() {
            return candidate.to_str().map(str::to_owned);
        }
    }

    // Priority 3: any `.jinja` file (for models with non-standard naming).
    // `read_dir` yields entries in an unspecified order, so take the minimum
    // path to keep the choice deterministic.
    let fallback = fs::read_dir(dir)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.is_file()
                && path
                    .file_name()
                    .and_then(|name| name.to_str())
                    .map_or(false, |name| name.ends_with(".jinja"))
        })
        .min()?;

    fallback.to_str().map(str::to_owned)
}
/// Factory function to create tokenizer from a model name or path (async version)
pub async fn create_tokenizer_async(
model_name_or_path: &str,
@@ -161,14 +191,32 @@ pub async fn create_tokenizer_async(
// Look for tokenizer.json in the cache directory
let tokenizer_path = cache_dir.join("tokenizer.json");
if tokenizer_path.exists() {
create_tokenizer_from_file(tokenizer_path.to_str().unwrap())
// Try to find a chat template file in the cache directory
let chat_template_path = discover_chat_template_in_dir(&cache_dir);
let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
Error::msg(format!(
"Tokenizer path is not valid UTF-8: {:?}",
tokenizer_path
))
})?;
create_tokenizer_with_chat_template(
tokenizer_path_str,
chat_template_path.as_deref(),
)
} else {
// Try other common tokenizer file names
let possible_files = ["tokenizer_config.json", "vocab.json"];
for file_name in &possible_files {
let file_path = cache_dir.join(file_name);
if file_path.exists() {
return create_tokenizer_from_file(file_path.to_str().unwrap());
let chat_template_path = discover_chat_template_in_dir(&cache_dir);
let file_path_str = file_path.to_str().ok_or_else(|| {
Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
})?;
return create_tokenizer_with_chat_template(
file_path_str,
chat_template_path.as_deref(),
);
}
}
Err(Error::msg(format!(

View File

@@ -40,6 +40,13 @@ fn is_tokenizer_file(filename: &str) -> bool {
|| filename.ends_with("merges.txt")
|| filename.ends_with(".model") // SentencePiece models
|| filename.ends_with(".tiktoken")
|| is_chat_template_file(filename) // Include chat template files
}
/// Reports whether `filename` names a chat template file.
///
/// Two forms are recognised: the exact name `chat_template.json` (a JSON file
/// wrapping a Jinja template) and any name ending in `.jinja`.
fn is_chat_template_file(filename: &str) -> bool {
    match filename {
        // JSON file containing a Jinja template
        "chat_template.json" => true,
        // Direct Jinja files
        other => other.ends_with(".jinja"),
    }
}
/// Attempt to download tokenizer files from Hugging Face
@@ -123,7 +130,13 @@ pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::R
}
match cache_dir {
Some(dir) => Ok(dir),
Some(dir) => {
// Ensure we return the correct model directory, not a subfolder
// Some models have an "original" subfolder for PyTorch weights
// We want the main model directory that contains tokenizer files
let final_dir = resolve_model_cache_dir(&dir, &model_name);
Ok(final_dir)
}
None => Err(anyhow::anyhow!(
"Invalid HF cache path for model '{}'",
model_name
@@ -206,11 +219,76 @@ pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Re
}
match p.parent() {
Some(p) => Ok(p.to_path_buf()),
Some(p) => {
let final_dir = resolve_model_cache_dir(p, &model_name);
Ok(final_dir)
}
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
}
}
/// Resolve the correct model cache directory.
///
/// Some repositories keep extra files in an `original` subfolder. When `path`
/// ends in a component named `original` and has a parent, that parent is
/// returned; every other path is returned unchanged.
///
/// `model_name` is accepted for signature compatibility but is not consulted.
/// The previous implementation tried to match it against the cache folder
/// name and then walk upwards to a `snapshots` directory, but:
/// * it built the folder name by doubling every hyphen
///   (`replace("-", "--")`), whereas the Hugging Face cache only replaces the
///   `/` separator (`org/name` -> `models--org--name`), so hyphenated model
///   ids could never match; and
/// * the upward walk only ran when `path` did not contain the pattern, in
///   which case no ancestor (a textual prefix of `path`) could contain it
///   either, so the search was unreachable and `path` was always returned.
///
/// Removing that code keeps the returned value identical while dropping the
/// misleading pattern and the pointless filesystem probing.
fn resolve_model_cache_dir(path: &Path, model_name: &str) -> PathBuf {
    // Kept only so existing call sites continue to compile unchanged.
    let _ = model_name;

    match (path.parent(), path.file_name()) {
        // We're in the "original" subfolder, go up one level.
        (Some(parent), Some(folder)) if folder == "original" => parent.to_path_buf(),
        _ => path.to_path_buf(),
    }
}
#[cfg(test)]
mod tests {
use super::*;
@@ -223,10 +301,24 @@ mod tests {
assert!(is_tokenizer_file("vocab.json"));
assert!(is_tokenizer_file("merges.txt"));
assert!(is_tokenizer_file("spiece.model"));
assert!(is_tokenizer_file("chat_template.jinja"));
assert!(is_tokenizer_file("template.jinja"));
assert!(!is_tokenizer_file("model.bin"));
assert!(!is_tokenizer_file("README.md"));
}
#[test]
fn test_is_chat_template_file() {
    // Names that must be recognised: any `.jinja` file plus the exact JSON name.
    let accepted = [
        "chat_template.jinja",
        "template.jinja",
        "any_file.jinja",
        "chat_template.json",
    ];
    for name in &accepted {
        assert!(is_chat_template_file(name), "expected a match for {:?}", name);
    }

    // Names that must be rejected: other JSON files, extension-less names, docs.
    let rejected = [
        "tokenizer.json",
        "other_file.json",
        "chat_template",
        "README.md",
    ];
    for name in &rejected {
        assert!(!is_chat_template_file(name), "expected no match for {:?}", name);
    }
}
#[test]
fn test_is_weight_file() {
assert!(is_weight_file("model.bin"));

View File

@@ -25,7 +25,12 @@ pub struct HuggingFaceTokenizer {
impl HuggingFaceTokenizer {
/// Create a tokenizer from a HuggingFace tokenizer JSON file
pub fn from_file(file_path: &str) -> Result<Self> {
Self::from_file_with_chat_template(file_path, None)
// Try to auto-discover chat template if not explicitly provided
let path = std::path::Path::new(file_path);
let chat_template_path = path
.parent()
.and_then(crate::tokenizer::factory::discover_chat_template_in_dir);
Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
}
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
@@ -135,13 +140,35 @@ impl HuggingFaceTokenizer {
None
}
/// Load chat template from a .jinja file
/// Load chat template from a file (.jinja or .json containing Jinja)
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(template_path)
.map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
// Check if it's a JSON file containing a Jinja template
if template_path.ends_with(".json") {
// Parse JSON and extract the template string
let json_value: serde_json::Value = serde_json::from_str(&content)
.map_err(|e| Error::msg(format!("Failed to parse chat_template.json: {}", e)))?;
if let Some(template_str) = json_value.as_str() {
return Ok(Some(template_str.to_string()));
} else if let Some(obj) = json_value.as_object() {
if let Some(template_value) = obj.get("chat_template") {
if let Some(template_str) = template_value.as_str() {
return Ok(Some(template_str.to_string()));
}
}
}
return Err(Error::msg(
"chat_template.json does not contain a valid template",
));
}
// Otherwise it's a plain .jinja file
// Clean up the template (similar to Python implementation)
let template = content.trim().replace("\\n", "\n");