[router] fix chat template loading and tokenizer path (#10999)
This commit is contained in:
@@ -40,6 +40,13 @@ fn is_tokenizer_file(filename: &str) -> bool {
|
||||
|| filename.ends_with("merges.txt")
|
||||
|| filename.ends_with(".model") // SentencePiece models
|
||||
|| filename.ends_with(".tiktoken")
|
||||
|| is_chat_template_file(filename) // Include chat template files
|
||||
}
|
||||
|
||||
/// Returns `true` when `filename` names a chat template file.
///
/// Two forms are recognized:
/// - any name ending in `.jinja` (a standalone Jinja template), and
/// - the exact name `chat_template.json` (a JSON file containing a Jinja
///   template).
fn is_chat_template_file(filename: &str) -> bool {
    match filename {
        // The JSON wrapper must match by its full name, not just by suffix.
        "chat_template.json" => true,
        other => other.ends_with(".jinja"),
    }
}
|
||||
|
||||
/// Attempt to download tokenizer files from Hugging Face
|
||||
@@ -123,7 +130,13 @@ pub async fn download_tokenizer_from_hf(model_id: impl AsRef<Path>) -> anyhow::R
|
||||
}
|
||||
|
||||
match cache_dir {
|
||||
Some(dir) => Ok(dir),
|
||||
Some(dir) => {
|
||||
// Ensure we return the correct model directory, not a subfolder
|
||||
// Some models have an "original" subfolder for PyTorch weights
|
||||
// We want the main model directory that contains tokenizer files
|
||||
let final_dir = resolve_model_cache_dir(&dir, &model_name);
|
||||
Ok(final_dir)
|
||||
}
|
||||
None => Err(anyhow::anyhow!(
|
||||
"Invalid HF cache path for model '{}'",
|
||||
model_name
|
||||
@@ -206,11 +219,76 @@ pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Re
|
||||
}
|
||||
|
||||
match p.parent() {
|
||||
Some(p) => Ok(p.to_path_buf()),
|
||||
Some(p) => {
|
||||
let final_dir = resolve_model_cache_dir(p, &model_name);
|
||||
Ok(final_dir)
|
||||
}
|
||||
None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the model directory that should contain the tokenizer files.
///
/// Some models keep an `original` subfolder (for PyTorch weights) next to the
/// tokenizer files. When `path` ends in a component named exactly `original`,
/// the parent directory is returned; every other path is returned unchanged.
///
/// No upward search for the repository folder is attempted. Every ancestor of
/// `path` is a textual prefix of `path`, so an ancestor can never contain a
/// folder name that `path` itself does not contain; such a search could only
/// ever hand back `path` again.
///
/// # Arguments
/// * `path` - directory reported by the download/cache lookup.
/// * `model_name` - repository id of the model. It does not influence the
///   result; it is accepted so existing call sites keep working.
///
/// # Returns
/// The owned path of the directory to use. The file system is not touched.
fn resolve_model_cache_dir(path: &Path, model_name: &str) -> PathBuf {
    // Deliberately unused: the result depends on `path` alone.
    let _ = model_name;

    match (path.file_name(), path.parent()) {
        // Inside the "original" subfolder: step up one level.
        (Some(folder), Some(parent)) if folder == "original" => parent.to_path_buf(),
        _ => path.to_path_buf(),
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -223,10 +301,24 @@ mod tests {
|
||||
assert!(is_tokenizer_file("vocab.json"));
|
||||
assert!(is_tokenizer_file("merges.txt"));
|
||||
assert!(is_tokenizer_file("spiece.model"));
|
||||
assert!(is_tokenizer_file("chat_template.jinja"));
|
||||
assert!(is_tokenizer_file("template.jinja"));
|
||||
assert!(!is_tokenizer_file("model.bin"));
|
||||
assert!(!is_tokenizer_file("README.md"));
|
||||
}
|
||||
|
||||
#[test]
fn test_is_chat_template_file() {
    // Names that must be recognized as chat template files.
    for accepted in [
        "chat_template.jinja",
        "template.jinja",
        "any_file.jinja",
        "chat_template.json",
    ] {
        assert!(is_chat_template_file(accepted));
    }

    // Names that must not be treated as chat template files.
    for rejected in [
        "tokenizer.json",
        "other_file.json",
        "chat_template",
        "README.md",
    ] {
        assert!(!is_chat_template_file(rejected));
    }
}
|
||||
|
||||
#[test]
|
||||
fn test_is_weight_file() {
|
||||
assert!(is_weight_file("model.bin"));
|
||||
|
||||
Reference in New Issue
Block a user