[router] Fix response api related spec (#11621)

This commit is contained in:
Keyang Ru
2025-10-15 09:59:38 -07:00
committed by GitHub
parent 30ea4c462b
commit d2478cd4ff
5 changed files with 50 additions and 140 deletions

View File

@@ -2156,7 +2156,7 @@ mod rerank_tests {
assert!(body_json.get("model").is_some());
// V1 API should use default model name
assert_eq!(body_json["model"], "default");
assert_eq!(body_json["model"], "unknown");
let results = body_json["results"].as_array().unwrap();
assert_eq!(results.len(), 3); // All documents should be returned

View File

@@ -115,7 +115,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
top_p: None,
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_test_mcp_e2e".to_string(),
request_id: Some("resp_test_mcp_e2e".to_string()),
priority: 0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
@@ -361,7 +361,7 @@ fn test_responses_request_creation() {
top_p: Some(0.9),
truncation: Some(Truncation::Disabled),
user: Some("test-user".to_string()),
request_id: "resp_test123".to_string(),
request_id: Some("resp_test123".to_string()),
priority: 0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
@@ -379,7 +379,8 @@ fn test_responses_request_creation() {
}
#[test]
fn test_sampling_params_conversion() {
fn test_responses_request_sglang_extensions() {
// Test that SGLang-specific sampling parameters are present and serializable
let request = ResponsesRequest {
background: Some(false),
include: None,
@@ -389,37 +390,44 @@ fn test_sampling_params_conversion() {
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: Some(true), // Use default true
parallel_tool_calls: Some(true),
previous_response_id: None,
reasoning: None,
service_tier: Some(ServiceTier::Auto),
store: Some(true), // Use default true
store: Some(true),
stream: Some(false),
temperature: Some(0.8),
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
tools: Some(vec![]),
top_logprobs: Some(0), // Use default 0
top_logprobs: Some(0),
top_p: Some(0.95),
truncation: Some(Truncation::Auto),
user: None,
request_id: "resp_test456".to_string(),
request_id: Some("resp_test456".to_string()),
priority: 0,
frequency_penalty: Some(0.1),
presence_penalty: Some(0.2),
stop: None,
// SGLang-specific extensions:
top_k: 10,
min_p: 0.05,
repetition_penalty: 1.1,
conversation: None,
};
let params = request.to_sampling_params(1000, None);
// Verify SGLang extensions are present
assert_eq!(request.top_k, 10);
assert_eq!(request.min_p, 0.05);
assert_eq!(request.repetition_penalty, 1.1);
// Check that parameters are converted correctly
assert!(params.contains_key("temperature"));
assert!(params.contains_key("top_p"));
assert!(params.contains_key("frequency_penalty"));
assert!(params.contains_key("max_new_tokens"));
// Verify serialization works with SGLang extensions
let json = serde_json::to_string(&request).expect("Serialization should work");
let parsed: ResponsesRequest =
serde_json::from_str(&json).expect("Deserialization should work");
assert_eq!(parsed.top_k, 10);
assert_eq!(parsed.min_p, 0.05);
assert_eq!(parsed.repetition_penalty, 1.1);
}
#[test]
@@ -516,7 +524,7 @@ fn test_json_serialization() {
top_p: Some(0.8),
truncation: Some(Truncation::Auto),
user: Some("test_user".to_string()),
request_id: "resp_comprehensive_test".to_string(),
request_id: Some("resp_comprehensive_test".to_string()),
priority: 1,
frequency_penalty: Some(0.3),
presence_penalty: Some(0.4),
@@ -531,7 +539,10 @@ fn test_json_serialization() {
let parsed: ResponsesRequest =
serde_json::from_str(&json).expect("Deserialization should work");
assert_eq!(parsed.request_id, "resp_comprehensive_test");
assert_eq!(
parsed.request_id,
Some("resp_comprehensive_test".to_string())
);
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert_eq!(parsed.background, Some(true));
assert_eq!(parsed.stream, Some(true));
@@ -643,7 +654,7 @@ async fn test_multi_turn_loop_with_mcp() {
top_p: Some(1.0),
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_multi_turn_test".to_string(),
request_id: Some("resp_multi_turn_test".to_string()),
priority: 0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
@@ -816,7 +827,7 @@ async fn test_max_tool_calls_limit() {
top_p: Some(1.0),
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_max_calls_test".to_string(),
request_id: Some("resp_max_calls_test".to_string()),
priority: 0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
@@ -1011,7 +1022,7 @@ async fn test_streaming_with_mcp_tool_calls() {
top_p: Some(1.0),
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_streaming_mcp_test".to_string(),
request_id: Some("resp_streaming_mcp_test".to_string()),
priority: 0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),
@@ -1290,7 +1301,7 @@ async fn test_streaming_multi_turn_with_mcp() {
top_p: Some(1.0),
truncation: Some(Truncation::Disabled),
user: None,
request_id: "resp_streaming_multiturn_test".to_string(),
request_id: Some("resp_streaming_multiturn_test".to_string()),
priority: 0,
frequency_penalty: Some(0.0),
presence_penalty: Some(0.0),

View File

@@ -1,7 +1,7 @@
use serde_json::{from_str, to_string, Number, Value};
use sglang_router_rs::protocols::spec::{
default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult,
StringOrArray, UsageInfo, V1RerankReqInput,
GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo,
V1RerankReqInput,
};
use std::collections::HashMap;
@@ -40,7 +40,7 @@ fn test_rerank_request_deserialization_with_defaults() {
assert_eq!(request.query, "test query");
assert_eq!(request.documents, vec!["doc1", "doc2"]);
assert_eq!(request.model, default_model_name());
assert_eq!(request.model, "unknown");
assert_eq!(request.top_k, None);
assert!(request.return_documents);
assert_eq!(request.rid, None);
@@ -414,7 +414,7 @@ fn test_v1_to_rerank_request_conversion() {
assert_eq!(request.query, "test query");
assert_eq!(request.documents, vec!["doc1", "doc2"]);
assert_eq!(request.model, default_model_name());
assert_eq!(request.model, "unknown");
assert_eq!(request.top_k, None);
assert!(request.return_documents);
assert_eq!(request.rid, None);