[router] Add Rerank Routing Logic in Regular Router (#10219)

This commit is contained in:
Frank Fang
2025-09-13 00:10:18 +08:00
committed by GitHub
parent efedbe6ca9
commit 4634fd5953
10 changed files with 545 additions and 40 deletions

View File

@@ -1891,7 +1891,7 @@ pub struct RerankResponse {
pub object: String,
/// Response ID
pub id: String,
pub id: Option<StringOrArray>,
/// Creation timestamp
pub created: i64,
@@ -1976,7 +1976,11 @@ impl RerankRequest {
}
impl RerankResponse {
pub fn new(results: Vec<RerankResult>, model: String, request_id: String) -> Self {
pub fn new(
results: Vec<RerankResult>,
model: String,
request_id: Option<StringOrArray>,
) -> Self {
RerankResponse {
results,
model,
@@ -2000,6 +2004,13 @@ impl RerankResponse {
pub fn apply_top_k(&mut self, k: usize) {
self.results.truncate(k);
}
/// Drop documents from results
pub fn drop_documents(&mut self) {
self.results.iter_mut().for_each(|result| {
result.document = None;
});
}
}
// ==================================================================
@@ -2268,12 +2279,15 @@ mod tests {
let response = RerankResponse::new(
results.clone(),
"test-model".to_string(),
"req-123".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
assert_eq!(response.results.len(), 2);
assert_eq!(response.model, "test-model");
assert_eq!(response.id, "req-123");
assert_eq!(
response.id,
Some(StringOrArray::String("req-123".to_string()))
);
assert_eq!(response.object, "rerank");
assert!(response.created > 0);
}
@@ -2287,8 +2301,11 @@ mod tests {
meta_info: None,
}];
let response =
RerankResponse::new(results, "test-model".to_string(), "req-123".to_string());
let response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
let serialized = serde_json::to_string(&response).unwrap();
let deserialized: RerankResponse = serde_json::from_str(&serialized).unwrap();
@@ -2322,8 +2339,11 @@ mod tests {
},
];
let mut response =
RerankResponse::new(results, "test-model".to_string(), "req-123".to_string());
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.sort_by_score();
@@ -2358,8 +2378,11 @@ mod tests {
},
];
let mut response =
RerankResponse::new(results, "test-model".to_string(), "req-123".to_string());
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.apply_top_k(2);
@@ -2377,14 +2400,36 @@ mod tests {
meta_info: None,
}];
let mut response =
RerankResponse::new(results, "test-model".to_string(), "req-123".to_string());
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.apply_top_k(5);
assert_eq!(response.results.len(), 1);
}
#[test]
fn test_rerank_response_drop_documents() {
let results = vec![RerankResult {
score: 0.8,
document: Some("doc1".to_string()),
index: 0,
meta_info: None,
}];
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.drop_documents();
assert_eq!(response.results[0].document, None);
}
// ==================================================================
// = RERANK RESULT TESTS =
// ==================================================================
@@ -2570,8 +2615,11 @@ mod tests {
meta_info: None,
}];
let mut response =
RerankResponse::new(results, "test-model".to_string(), "req-123".to_string());
let mut response = RerankResponse::new(
results,
"test-model".to_string(),
Some(StringOrArray::String("req-123".to_string())),
);
response.usage = Some(UsageInfo {
prompt_tokens: 100,
@@ -2645,18 +2693,7 @@ mod tests {
];
// Create response
let mut response = RerankResponse::new(
results,
request.model.clone(),
request
.rid
.as_ref()
.and_then(|r| match r {
StringOrArray::String(s) => Some(s.clone()),
StringOrArray::Array(arr) => arr.first().cloned(),
})
.unwrap_or_else(|| "unknown".to_string()),
);
let mut response = RerankResponse::new(results, request.model.clone(), request.rid.clone());
// Sort by score
response.sort_by_score();