[router][grpc] Simplify model_id determination (#11684)
This commit is contained in:
@@ -22,7 +22,6 @@ use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{watch, Mutex};
|
||||
@@ -1228,8 +1227,7 @@ impl WorkerManager {
|
||||
Ok(server_info) => {
|
||||
if let Some(model_id) = server_info.model_id {
|
||||
if !model_id.is_empty() {
|
||||
let normalized = Self::normalize_model_identifier(&model_id);
|
||||
discovery.labels.insert("model_id".to_string(), normalized);
|
||||
discovery.labels.insert("model_id".to_string(), model_id);
|
||||
}
|
||||
}
|
||||
if let Some(model_path) = server_info.model_path {
|
||||
@@ -1309,9 +1307,9 @@ impl WorkerManager {
|
||||
"served_model_name".to_string(),
|
||||
model_info.served_model_name.clone(),
|
||||
);
|
||||
let normalized =
|
||||
Self::normalize_model_identifier(&model_info.served_model_name);
|
||||
discovery.labels.insert("model_id".to_string(), normalized);
|
||||
discovery
|
||||
.labels
|
||||
.insert("model_id".to_string(), model_info.served_model_name);
|
||||
}
|
||||
if !model_info.weight_version.is_empty() {
|
||||
discovery.labels.insert(
|
||||
@@ -1368,15 +1366,6 @@ impl WorkerManager {
|
||||
discovery
|
||||
}
|
||||
|
||||
fn normalize_model_identifier(value: &str) -> String {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.contains('/') || trimmed.contains('\\') {
|
||||
Self::derive_model_id_from_path(trimmed)
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_model_id(labels: &mut HashMap<String, String>) {
|
||||
let has_model_id = labels
|
||||
.get("model_id")
|
||||
@@ -1386,41 +1375,20 @@ impl WorkerManager {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(served_name) = labels.get("served_model_name").cloned() {
|
||||
if let Some(served_name) = labels.get("served_model_name") {
|
||||
if !served_name.trim().is_empty() {
|
||||
let normalized = Self::normalize_model_identifier(&served_name);
|
||||
labels.insert("model_id".to_string(), normalized);
|
||||
labels.insert("model_id".to_string(), served_name.clone());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(model_path) = labels.get("model_path").cloned() {
|
||||
if let Some(model_path) = labels.get("model_path") {
|
||||
if !model_path.trim().is_empty() {
|
||||
let derived = Self::derive_model_id_from_path(&model_path);
|
||||
if !derived.is_empty() {
|
||||
labels.insert("model_id".to_string(), derived);
|
||||
}
|
||||
labels.insert("model_id".to_string(), model_path.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_model_id_from_path(path: &str) -> String {
|
||||
let trimmed = path.trim_end_matches(['/', '\\']);
|
||||
if trimmed.is_empty() {
|
||||
return path.to_string();
|
||||
}
|
||||
|
||||
let candidate = Path::new(trimmed)
|
||||
.file_name()
|
||||
.and_then(|p| p.to_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
match candidate {
|
||||
Some(name) if !name.is_empty() => name,
|
||||
_ => trimmed.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse server info from JSON response
|
||||
fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
|
||||
Ok(ServerInfo {
|
||||
@@ -1872,20 +1840,6 @@ mod tests {
|
||||
assert_eq!(info.dp_size, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_model_id_from_path() {
|
||||
let path = "/raid/models/meta-llama/Llama-3.1-8B-Instruct";
|
||||
let derived = WorkerManager::derive_model_id_from_path(path);
|
||||
assert_eq!(derived, "Llama-3.1-8B-Instruct");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_model_id_trailing_slash() {
|
||||
let path = "/models/foo/bar/";
|
||||
let derived = WorkerManager::derive_model_id_from_path(path);
|
||||
assert_eq!(derived, "bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_finalize_model_id_prefers_existing() {
|
||||
let mut labels = HashMap::new();
|
||||
@@ -1908,12 +1862,6 @@ mod tests {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("model_path".to_string(), "/models/alpha".to_string());
|
||||
WorkerManager::finalize_model_id(&mut labels);
|
||||
assert_eq!(labels.get("model_id").unwrap(), "alpha");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_model_identifier_from_path() {
|
||||
let normalized = WorkerManager::normalize_model_identifier("/raid/models/foo/bar-model");
|
||||
assert_eq!(normalized, "bar-model");
|
||||
assert_eq!(labels.get("model_id").unwrap(), "/models/alpha");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user