[router] Fix response api related spec (#11621)
This commit is contained in:
@@ -69,6 +69,8 @@ fn generate_request_id(path: &str) -> String {
|
||||
"cmpl-"
|
||||
} else if path.contains("/generate") {
|
||||
"gnt-"
|
||||
} else if path.contains("/responses") {
|
||||
"resp-"
|
||||
} else {
|
||||
"req-"
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{to_value, Map, Number, Value};
|
||||
use serde_json::{to_value, Map, Value};
|
||||
use std::collections::HashMap;
|
||||
use validator::Validate;
|
||||
|
||||
@@ -1325,10 +1325,6 @@ impl ResponsesUsage {
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_request_id() -> String {
|
||||
format!("resp_{}", uuid::Uuid::new_v4().simple())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResponsesRequest {
|
||||
/// Run the request in the background
|
||||
@@ -1419,8 +1415,8 @@ pub struct ResponsesRequest {
|
||||
pub user: Option<String>,
|
||||
|
||||
/// Request ID
|
||||
#[serde(default = "generate_request_id")]
|
||||
pub request_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub request_id: Option<String>,
|
||||
|
||||
/// Request priority
|
||||
#[serde(default)]
|
||||
@@ -1438,15 +1434,15 @@ pub struct ResponsesRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
|
||||
/// Top-k sampling parameter
|
||||
/// Top-k sampling parameter (SGLang extension)
|
||||
#[serde(default = "default_top_k")]
|
||||
pub top_k: i32,
|
||||
|
||||
/// Min-p sampling parameter
|
||||
/// Min-p sampling parameter (SGLang extension)
|
||||
#[serde(default)]
|
||||
pub min_p: f32,
|
||||
|
||||
/// Repetition penalty
|
||||
/// Repetition penalty (SGLang extension)
|
||||
#[serde(default = "default_repetition_penalty")]
|
||||
pub repetition_penalty: f32,
|
||||
}
|
||||
@@ -1491,7 +1487,7 @@ impl Default for ResponsesRequest {
|
||||
top_p: None,
|
||||
truncation: None,
|
||||
user: None,
|
||||
request_id: generate_request_id(),
|
||||
request_id: None,
|
||||
priority: 0,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
@@ -1503,101 +1499,6 @@ impl Default for ResponsesRequest {
|
||||
}
|
||||
}
|
||||
|
||||
impl ResponsesRequest {
|
||||
/// Default sampling parameters
|
||||
const DEFAULT_TEMPERATURE: f32 = 0.7;
|
||||
const DEFAULT_TOP_P: f32 = 1.0;
|
||||
|
||||
/// Convert to sampling parameters for generation
|
||||
pub fn to_sampling_params(
|
||||
&self,
|
||||
default_max_tokens: u32,
|
||||
default_params: Option<HashMap<String, Value>>,
|
||||
) -> HashMap<String, Value> {
|
||||
let mut params = HashMap::new();
|
||||
|
||||
// Use max_output_tokens if available
|
||||
let max_tokens = if let Some(max_output) = self.max_output_tokens {
|
||||
std::cmp::min(max_output, default_max_tokens)
|
||||
} else {
|
||||
default_max_tokens
|
||||
};
|
||||
|
||||
// Avoid exceeding context length by minus 1 token
|
||||
let max_tokens = max_tokens.saturating_sub(1);
|
||||
|
||||
// Temperature
|
||||
let temperature = self.temperature.unwrap_or_else(|| {
|
||||
default_params
|
||||
.as_ref()
|
||||
.and_then(|p| p.get("temperature"))
|
||||
.and_then(|v| v.as_f64())
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(Self::DEFAULT_TEMPERATURE)
|
||||
});
|
||||
|
||||
// Top-p
|
||||
let top_p = self.top_p.unwrap_or_else(|| {
|
||||
default_params
|
||||
.as_ref()
|
||||
.and_then(|p| p.get("top_p"))
|
||||
.and_then(|v| v.as_f64())
|
||||
.map(|v| v as f32)
|
||||
.unwrap_or(Self::DEFAULT_TOP_P)
|
||||
});
|
||||
|
||||
params.insert(
|
||||
"max_new_tokens".to_string(),
|
||||
Value::Number(Number::from(max_tokens)),
|
||||
);
|
||||
params.insert(
|
||||
"temperature".to_string(),
|
||||
Value::Number(Number::from_f64(temperature as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"top_p".to_string(),
|
||||
Value::Number(Number::from_f64(top_p as f64).unwrap()),
|
||||
);
|
||||
if let Some(fp) = self.frequency_penalty {
|
||||
params.insert(
|
||||
"frequency_penalty".to_string(),
|
||||
Value::Number(Number::from_f64(fp as f64).unwrap()),
|
||||
);
|
||||
}
|
||||
if let Some(pp) = self.presence_penalty {
|
||||
params.insert(
|
||||
"presence_penalty".to_string(),
|
||||
Value::Number(Number::from_f64(pp as f64).unwrap()),
|
||||
);
|
||||
}
|
||||
params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
|
||||
params.insert(
|
||||
"min_p".to_string(),
|
||||
Value::Number(Number::from_f64(self.min_p as f64).unwrap()),
|
||||
);
|
||||
params.insert(
|
||||
"repetition_penalty".to_string(),
|
||||
Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()),
|
||||
);
|
||||
|
||||
if let Some(ref stop) = self.stop {
|
||||
match to_value(stop) {
|
||||
Ok(value) => params.insert("stop".to_string(), value),
|
||||
Err(_) => params.insert("stop".to_string(), Value::Null),
|
||||
};
|
||||
}
|
||||
|
||||
// Apply any additional default parameters
|
||||
if let Some(default_params) = default_params {
|
||||
for (key, value) in default_params {
|
||||
params.entry(key).or_insert(value);
|
||||
}
|
||||
}
|
||||
|
||||
params
|
||||
}
|
||||
}
|
||||
|
||||
impl GenerationRequest for ResponsesRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream.unwrap_or(false)
|
||||
@@ -1776,7 +1677,10 @@ impl ResponsesResponse {
|
||||
usage: Option<UsageInfo>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: request.request_id.clone(),
|
||||
id: request
|
||||
.request_id
|
||||
.clone()
|
||||
.expect("request_id should be set by middleware"),
|
||||
object: "response".to_string(),
|
||||
created_at: created_time,
|
||||
status,
|
||||
@@ -2535,9 +2439,6 @@ pub enum GenerateFinishReason {
|
||||
Other(Value),
|
||||
}
|
||||
|
||||
// Constants for rerank API
|
||||
pub const DEFAULT_MODEL_NAME: &str = "default";
|
||||
|
||||
/// Rerank request for scoring documents against a query
|
||||
/// Used for RAG systems and document relevance scoring
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -2549,7 +2450,7 @@ pub struct RerankRequest {
|
||||
pub documents: Vec<String>,
|
||||
|
||||
/// Model to use for reranking
|
||||
#[serde(default = "default_model_name")]
|
||||
#[serde(default = "default_model")]
|
||||
pub model: String,
|
||||
|
||||
/// Maximum number of documents to return (optional)
|
||||
@@ -2567,10 +2468,6 @@ pub struct RerankRequest {
|
||||
pub user: Option<String>,
|
||||
}
|
||||
|
||||
pub fn default_model_name() -> String {
|
||||
DEFAULT_MODEL_NAME.to_string()
|
||||
}
|
||||
|
||||
/// Default for the "return documents" switch: always `true`.
///
/// NOTE(review): presumably referenced from a `#[serde(default = ...)]`
/// attribute on `RerankRequest::return_documents`; that attribute is not
/// visible in this excerpt, so confirm against the struct definition.
fn default_return_documents() -> bool {
    const RETURN_DOCUMENTS_BY_DEFAULT: bool = true;
    RETURN_DOCUMENTS_BY_DEFAULT
}
|
||||
@@ -2634,7 +2531,7 @@ impl From<V1RerankReqInput> for RerankRequest {
|
||||
RerankRequest {
|
||||
query: v1.query,
|
||||
documents: v1.documents,
|
||||
model: default_model_name(),
|
||||
model: default_model(),
|
||||
top_k: None,
|
||||
return_documents: true,
|
||||
rid: None,
|
||||
|
||||
@@ -2156,7 +2156,7 @@ mod rerank_tests {
|
||||
assert!(body_json.get("model").is_some());
|
||||
|
||||
// V1 API should use default model name
|
||||
assert_eq!(body_json["model"], "default");
|
||||
assert_eq!(body_json["model"], "unknown");
|
||||
|
||||
let results = body_json["results"].as_array().unwrap();
|
||||
assert_eq!(results.len(), 3); // All documents should be returned
|
||||
|
||||
@@ -115,7 +115,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
|
||||
top_p: None,
|
||||
truncation: Some(Truncation::Disabled),
|
||||
user: None,
|
||||
request_id: "resp_test_mcp_e2e".to_string(),
|
||||
request_id: Some("resp_test_mcp_e2e".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
@@ -361,7 +361,7 @@ fn test_responses_request_creation() {
|
||||
top_p: Some(0.9),
|
||||
truncation: Some(Truncation::Disabled),
|
||||
user: Some("test-user".to_string()),
|
||||
request_id: "resp_test123".to_string(),
|
||||
request_id: Some("resp_test123".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
@@ -379,7 +379,8 @@ fn test_responses_request_creation() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sampling_params_conversion() {
|
||||
fn test_responses_request_sglang_extensions() {
|
||||
// Test that SGLang-specific sampling parameters are present and serializable
|
||||
let request = ResponsesRequest {
|
||||
background: Some(false),
|
||||
include: None,
|
||||
@@ -389,37 +390,44 @@ fn test_sampling_params_conversion() {
|
||||
max_tool_calls: None,
|
||||
metadata: None,
|
||||
model: Some("test-model".to_string()),
|
||||
parallel_tool_calls: Some(true), // Use default true
|
||||
parallel_tool_calls: Some(true),
|
||||
previous_response_id: None,
|
||||
reasoning: None,
|
||||
service_tier: Some(ServiceTier::Auto),
|
||||
store: Some(true), // Use default true
|
||||
store: Some(true),
|
||||
stream: Some(false),
|
||||
temperature: Some(0.8),
|
||||
tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)),
|
||||
tools: Some(vec![]),
|
||||
top_logprobs: Some(0), // Use default 0
|
||||
top_logprobs: Some(0),
|
||||
top_p: Some(0.95),
|
||||
truncation: Some(Truncation::Auto),
|
||||
user: None,
|
||||
request_id: "resp_test456".to_string(),
|
||||
request_id: Some("resp_test456".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.1),
|
||||
presence_penalty: Some(0.2),
|
||||
stop: None,
|
||||
// SGLang-specific extensions:
|
||||
top_k: 10,
|
||||
min_p: 0.05,
|
||||
repetition_penalty: 1.1,
|
||||
conversation: None,
|
||||
};
|
||||
|
||||
let params = request.to_sampling_params(1000, None);
|
||||
// Verify SGLang extensions are present
|
||||
assert_eq!(request.top_k, 10);
|
||||
assert_eq!(request.min_p, 0.05);
|
||||
assert_eq!(request.repetition_penalty, 1.1);
|
||||
|
||||
// Check that parameters are converted correctly
|
||||
assert!(params.contains_key("temperature"));
|
||||
assert!(params.contains_key("top_p"));
|
||||
assert!(params.contains_key("frequency_penalty"));
|
||||
assert!(params.contains_key("max_new_tokens"));
|
||||
// Verify serialization works with SGLang extensions
|
||||
let json = serde_json::to_string(&request).expect("Serialization should work");
|
||||
let parsed: ResponsesRequest =
|
||||
serde_json::from_str(&json).expect("Deserialization should work");
|
||||
|
||||
assert_eq!(parsed.top_k, 10);
|
||||
assert_eq!(parsed.min_p, 0.05);
|
||||
assert_eq!(parsed.repetition_penalty, 1.1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -516,7 +524,7 @@ fn test_json_serialization() {
|
||||
top_p: Some(0.8),
|
||||
truncation: Some(Truncation::Auto),
|
||||
user: Some("test_user".to_string()),
|
||||
request_id: "resp_comprehensive_test".to_string(),
|
||||
request_id: Some("resp_comprehensive_test".to_string()),
|
||||
priority: 1,
|
||||
frequency_penalty: Some(0.3),
|
||||
presence_penalty: Some(0.4),
|
||||
@@ -531,7 +539,10 @@ fn test_json_serialization() {
|
||||
let parsed: ResponsesRequest =
|
||||
serde_json::from_str(&json).expect("Deserialization should work");
|
||||
|
||||
assert_eq!(parsed.request_id, "resp_comprehensive_test");
|
||||
assert_eq!(
|
||||
parsed.request_id,
|
||||
Some("resp_comprehensive_test".to_string())
|
||||
);
|
||||
assert_eq!(parsed.model, Some("gpt-4".to_string()));
|
||||
assert_eq!(parsed.background, Some(true));
|
||||
assert_eq!(parsed.stream, Some(true));
|
||||
@@ -643,7 +654,7 @@ async fn test_multi_turn_loop_with_mcp() {
|
||||
top_p: Some(1.0),
|
||||
truncation: Some(Truncation::Disabled),
|
||||
user: None,
|
||||
request_id: "resp_multi_turn_test".to_string(),
|
||||
request_id: Some("resp_multi_turn_test".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
@@ -816,7 +827,7 @@ async fn test_max_tool_calls_limit() {
|
||||
top_p: Some(1.0),
|
||||
truncation: Some(Truncation::Disabled),
|
||||
user: None,
|
||||
request_id: "resp_max_calls_test".to_string(),
|
||||
request_id: Some("resp_max_calls_test".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
@@ -1011,7 +1022,7 @@ async fn test_streaming_with_mcp_tool_calls() {
|
||||
top_p: Some(1.0),
|
||||
truncation: Some(Truncation::Disabled),
|
||||
user: None,
|
||||
request_id: "resp_streaming_mcp_test".to_string(),
|
||||
request_id: Some("resp_streaming_mcp_test".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
@@ -1290,7 +1301,7 @@ async fn test_streaming_multi_turn_with_mcp() {
|
||||
top_p: Some(1.0),
|
||||
truncation: Some(Truncation::Disabled),
|
||||
user: None,
|
||||
request_id: "resp_streaming_multiturn_test".to_string(),
|
||||
request_id: Some("resp_streaming_multiturn_test".to_string()),
|
||||
priority: 0,
|
||||
frequency_penalty: Some(0.0),
|
||||
presence_penalty: Some(0.0),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use serde_json::{from_str, to_string, Number, Value};
|
||||
use sglang_router_rs::protocols::spec::{
|
||||
default_model_name, GenerationRequest, RerankRequest, RerankResponse, RerankResult,
|
||||
StringOrArray, UsageInfo, V1RerankReqInput,
|
||||
GenerationRequest, RerankRequest, RerankResponse, RerankResult, StringOrArray, UsageInfo,
|
||||
V1RerankReqInput,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
@@ -40,7 +40,7 @@ fn test_rerank_request_deserialization_with_defaults() {
|
||||
|
||||
assert_eq!(request.query, "test query");
|
||||
assert_eq!(request.documents, vec!["doc1", "doc2"]);
|
||||
assert_eq!(request.model, default_model_name());
|
||||
assert_eq!(request.model, "unknown");
|
||||
assert_eq!(request.top_k, None);
|
||||
assert!(request.return_documents);
|
||||
assert_eq!(request.rid, None);
|
||||
@@ -414,7 +414,7 @@ fn test_v1_to_rerank_request_conversion() {
|
||||
|
||||
assert_eq!(request.query, "test query");
|
||||
assert_eq!(request.documents, vec!["doc1", "doc2"]);
|
||||
assert_eq!(request.model, default_model_name());
|
||||
assert_eq!(request.model, "unknown");
|
||||
assert_eq!(request.top_k, None);
|
||||
assert!(request.return_documents);
|
||||
assert_eq!(request.rid, None);
|
||||
|
||||
Reference in New Issue
Block a user