[router]: Add Embedding routing logic (#10129)
Signed-off-by: Jintao Zhang <zhangjintao9020@gmail.com> Co-authored-by: Waël Boukhobza <wawa_wael@live.fr>
This commit is contained in:
@@ -41,7 +41,10 @@ use std::collections::HashMap;
|
||||
// 6. **SGLANG SPEC - RERANK API**
|
||||
// - Request/Response structures
|
||||
//
|
||||
// 7. **COMMON**
|
||||
// 7. **OPENAI SPEC - Embeddings API**
|
||||
// - Request structures
|
||||
//
|
||||
// 8. **COMMON**
|
||||
// - GenerationRequest trait
|
||||
// - StringOrArray & LoRAPath types
|
||||
// - Helper functions
|
||||
@@ -2013,6 +2016,61 @@ impl RerankResponse {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = OPENAI SPEC - Embeddings API =
|
||||
// ==================================================================
|
||||
|
||||
/// Embeddings request compatible with OpenAI API
|
||||
/// We intentionally keep fields flexible to pass through to workers.
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
/// ID of the model to use
|
||||
pub model: String,
|
||||
|
||||
/// Input can be a string, array of strings, tokens, or batch inputs
|
||||
pub input: serde_json::Value,
|
||||
|
||||
/// Optional encoding format (e.g., "float", "base64")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<String>,
|
||||
|
||||
/// Optional user identifier
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// Optional number of dimensions for the embedding
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<u32>,
|
||||
|
||||
/// SGLang extension: request id for tracking
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub rid: Option<String>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for EmbeddingRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
// Embeddings are non-streaming
|
||||
false
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Best effort: extract text content for routing decisions
|
||||
match &self.input {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
serde_json::Value::Array(arr) => arr
|
||||
.iter()
|
||||
.filter_map(|v| v.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" "),
|
||||
_ => String::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = COMMON =
|
||||
// ==================================================================
|
||||
@@ -2715,4 +2773,102 @@ mod tests {
|
||||
assert_eq!(deserialized.results.len(), 2);
|
||||
assert_eq!(deserialized.model, response.model);
|
||||
}
|
||||
|
||||
// ==================================================================
|
||||
// = EMBEDDINGS REQUEST TESTS =
|
||||
// ==================================================================
|
||||
|
||||
/// A fully populated request with a plain string input must survive a JSON
/// round trip with every field intact.
#[test]
fn test_embedding_request_serialization_string_input() {
    let original = EmbeddingRequest {
        model: String::from("test-emb"),
        input: serde_json::Value::String(String::from("hello")),
        encoding_format: Some(String::from("float")),
        user: Some(String::from("user-1")),
        dimensions: Some(128),
        rid: Some(String::from("rid-123")),
    };

    let json = serde_json::to_string(&original).unwrap();
    let round_tripped: EmbeddingRequest = serde_json::from_str(&json).unwrap();

    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);
    assert_eq!(round_tripped.encoding_format, original.encoding_format);
    assert_eq!(round_tripped.user, original.user);
    assert_eq!(round_tripped.dimensions, original.dimensions);
    assert_eq!(round_tripped.rid, original.rid);
}
|
||||
|
||||
/// A request whose input is an array of strings, with all optional fields
/// unset, must survive a JSON round trip.
#[test]
fn test_embedding_request_serialization_array_input() {
    let original = EmbeddingRequest {
        model: String::from("test-emb"),
        input: serde_json::json!(["a", "b", "c"]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let json = serde_json::to_string(&original).unwrap();
    let round_tripped: EmbeddingRequest = serde_json::from_str(&json).unwrap();

    assert_eq!(round_tripped.model, original.model);
    assert_eq!(round_tripped.input, original.input);
}
|
||||
|
||||
/// Checks all three `GenerationRequest` methods for a plain string input.
#[test]
fn test_embedding_generation_request_trait_string() {
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::Value::String(String::from("hello")),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    assert!(!request.is_stream());
    assert_eq!(request.get_model(), Some("emb-model"));
    assert_eq!(request.extract_text_for_routing(), "hello");
}
|
||||
|
||||
/// An array of strings is flattened into one space-separated routing string.
#[test]
fn test_embedding_generation_request_trait_array() {
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!(["hello", "world"]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let routing_text = request.extract_text_for_routing();
    assert_eq!(routing_text, "hello world");
}
|
||||
|
||||
/// An input that is neither a string nor an array (here, a JSON object)
/// produces an empty routing string.
#[test]
fn test_embedding_generation_request_trait_non_text() {
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!({"tokens": [1, 2, 3]}),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    let routing_text = request.extract_text_for_routing();
    assert_eq!(routing_text, "");
}
|
||||
|
||||
/// In a heterogeneous array, only the top-level string elements contribute
/// to the routing text; nested arrays, numbers and objects are skipped.
#[test]
fn test_embedding_generation_request_trait_mixed_array_ignores_nested() {
    let request = EmbeddingRequest {
        model: String::from("emb-model"),
        input: serde_json::json!(["a", ["b", "c"], 123, {"k": "v"}]),
        encoding_format: None,
        user: None,
        dimensions: None,
        rid: None,
    };

    // "a" is the sole top-level string, so it is the entire result.
    let routing_text = request.extract_text_for_routing();
    assert_eq!(routing_text, "a");
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user