[router] Fix response api related spec (#11621)

This commit is contained in:
Keyang Ru
2025-10-15 09:59:38 -07:00
committed by GitHub
parent 30ea4c462b
commit d2478cd4ff
5 changed files with 50 additions and 140 deletions

View File

@@ -1,5 +1,5 @@
use serde::{Deserialize, Serialize};
use serde_json::{to_value, Map, Number, Value};
use serde_json::{to_value, Map, Value};
use std::collections::HashMap;
use validator::Validate;
@@ -1325,10 +1325,6 @@ impl ResponsesUsage {
}
}
fn generate_request_id() -> String {
format!("resp_{}", uuid::Uuid::new_v4().simple())
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesRequest {
/// Run the request in the background
@@ -1419,8 +1415,8 @@ pub struct ResponsesRequest {
pub user: Option<String>,
/// Request ID
#[serde(default = "generate_request_id")]
pub request_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
/// Request priority
#[serde(default)]
@@ -1438,15 +1434,15 @@ pub struct ResponsesRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Top-k sampling parameter
/// Top-k sampling parameter (SGLang extension)
#[serde(default = "default_top_k")]
pub top_k: i32,
/// Min-p sampling parameter
/// Min-p sampling parameter (SGLang extension)
#[serde(default)]
pub min_p: f32,
/// Repetition penalty
/// Repetition penalty (SGLang extension)
#[serde(default = "default_repetition_penalty")]
pub repetition_penalty: f32,
}
@@ -1491,7 +1487,7 @@ impl Default for ResponsesRequest {
top_p: None,
truncation: None,
user: None,
request_id: generate_request_id(),
request_id: None,
priority: 0,
frequency_penalty: None,
presence_penalty: None,
@@ -1503,101 +1499,6 @@ impl Default for ResponsesRequest {
}
}
impl ResponsesRequest {
/// Default sampling parameters
const DEFAULT_TEMPERATURE: f32 = 0.7;
const DEFAULT_TOP_P: f32 = 1.0;
/// Convert to sampling parameters for generation
pub fn to_sampling_params(
&self,
default_max_tokens: u32,
default_params: Option<HashMap<String, Value>>,
) -> HashMap<String, Value> {
let mut params = HashMap::new();
// Use max_output_tokens if available
let max_tokens = if let Some(max_output) = self.max_output_tokens {
std::cmp::min(max_output, default_max_tokens)
} else {
default_max_tokens
};
// Avoid exceeding context length by minus 1 token
let max_tokens = max_tokens.saturating_sub(1);
// Temperature
let temperature = self.temperature.unwrap_or_else(|| {
default_params
.as_ref()
.and_then(|p| p.get("temperature"))
.and_then(|v| v.as_f64())
.map(|v| v as f32)
.unwrap_or(Self::DEFAULT_TEMPERATURE)
});
// Top-p
let top_p = self.top_p.unwrap_or_else(|| {
default_params
.as_ref()
.and_then(|p| p.get("top_p"))
.and_then(|v| v.as_f64())
.map(|v| v as f32)
.unwrap_or(Self::DEFAULT_TOP_P)
});
params.insert(
"max_new_tokens".to_string(),
Value::Number(Number::from(max_tokens)),
);
params.insert(
"temperature".to_string(),
Value::Number(Number::from_f64(temperature as f64).unwrap()),
);
params.insert(
"top_p".to_string(),
Value::Number(Number::from_f64(top_p as f64).unwrap()),
);
if let Some(fp) = self.frequency_penalty {
params.insert(
"frequency_penalty".to_string(),
Value::Number(Number::from_f64(fp as f64).unwrap()),
);
}
if let Some(pp) = self.presence_penalty {
params.insert(
"presence_penalty".to_string(),
Value::Number(Number::from_f64(pp as f64).unwrap()),
);
}
params.insert("top_k".to_string(), Value::Number(Number::from(self.top_k)));
params.insert(
"min_p".to_string(),
Value::Number(Number::from_f64(self.min_p as f64).unwrap()),
);
params.insert(
"repetition_penalty".to_string(),
Value::Number(Number::from_f64(self.repetition_penalty as f64).unwrap()),
);
if let Some(ref stop) = self.stop {
match to_value(stop) {
Ok(value) => params.insert("stop".to_string(), value),
Err(_) => params.insert("stop".to_string(), Value::Null),
};
}
// Apply any additional default parameters
if let Some(default_params) = default_params {
for (key, value) in default_params {
params.entry(key).or_insert(value);
}
}
params
}
}
impl GenerationRequest for ResponsesRequest {
fn is_stream(&self) -> bool {
self.stream.unwrap_or(false)
@@ -1776,7 +1677,10 @@ impl ResponsesResponse {
usage: Option<UsageInfo>,
) -> Self {
Self {
id: request.request_id.clone(),
id: request
.request_id
.clone()
.expect("request_id should be set by middleware"),
object: "response".to_string(),
created_at: created_time,
status,
@@ -2535,9 +2439,6 @@ pub enum GenerateFinishReason {
Other(Value),
}
// Constants for rerank API
pub const DEFAULT_MODEL_NAME: &str = "default";
/// Rerank request for scoring documents against a query
/// Used for RAG systems and document relevance scoring
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -2549,7 +2450,7 @@ pub struct RerankRequest {
pub documents: Vec<String>,
/// Model to use for reranking
#[serde(default = "default_model_name")]
#[serde(default = "default_model")]
pub model: String,
/// Maximum number of documents to return (optional)
@@ -2567,10 +2468,6 @@ pub struct RerankRequest {
pub user: Option<String>,
}
pub fn default_model_name() -> String {
DEFAULT_MODEL_NAME.to_string()
}
fn default_return_documents() -> bool {
true
}
@@ -2634,7 +2531,7 @@ impl From<V1RerankReqInput> for RerankRequest {
RerankRequest {
query: v1.query,
documents: v1.documents,
model: default_model_name(),
model: default_model(),
top_k: None,
return_documents: true,
rid: None,