[router][protocols] Add Axum validate extractor and use it for /v1/chat/completions endpoint (#11588)

This commit is contained in:
Chang Su
2025-10-13 22:51:15 -07:00
committed by GitHub
parent e4358a4585
commit 27ef1459e6
21 changed files with 1982 additions and 2003 deletions

View File

@@ -301,13 +301,7 @@ impl SglangSchedulerClient {
) -> Result<proto::SamplingParams, String> {
let stop_sequences = self.extract_stop_strings(request);
// Handle max tokens: prefer max_completion_tokens (new) over max_tokens (deprecated)
// If neither is specified, use None to let the backend decide the default
#[allow(deprecated)]
let max_new_tokens = request
.max_completion_tokens
.or(request.max_tokens)
.map(|v| v as i32);
let max_new_tokens = request.max_completion_tokens.map(|v| v as i32);
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
let skip_special_tokens = if request.tools.is_some() {
@@ -322,7 +316,6 @@ impl SglangSchedulerClient {
request.skip_special_tokens
};
#[allow(deprecated)]
Ok(proto::SamplingParams {
temperature: request.temperature.unwrap_or(1.0),
top_p: request.top_p.unwrap_or(1.0),
@@ -485,10 +478,10 @@ impl SglangSchedulerClient {
})?);
}
// Handle min_tokens with conversion
if let Some(min_tokens) = p.min_tokens {
sampling.min_new_tokens = i32::try_from(min_tokens)
.map_err(|_| "min_tokens must fit into a 32-bit signed integer".to_string())?;
// Handle min_new_tokens with conversion
if let Some(min_new_tokens) = p.min_new_tokens {
sampling.min_new_tokens = i32::try_from(min_new_tokens)
.map_err(|_| "min_new_tokens must fit into a 32-bit signed integer".to_string())?;
}
// Handle n with conversion

View File

@@ -2,5 +2,5 @@
// This module provides a structured approach to handling different API protocols
pub mod spec;
pub mod validation;
pub mod validated;
pub mod worker_spec;

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
// Validated JSON extractor for automatic request normalization and validation
//
// This module provides a ValidatedJson extractor that deserializes the request
// body, normalizes it via the Normalizable trait, and then validates it using
// the validator crate's Validate trait.
use axum::{
extract::{rejection::JsonRejection, FromRequest, Request},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::de::DeserializeOwned;
use serde_json::json;
use validator::Validate;
/// Hook for request types that need adjusting after they have been deserialized.
pub trait Normalizable {
    /// Applies defaults and transformations to the request in place.
    ///
    /// The provided implementation leaves the request untouched, so implementors
    /// only override it when they actually have something to normalize.
    fn normalize(&mut self) {}
}
/// A JSON extractor that deserializes, normalizes, and validates the request body.
///
/// After the body has been deserialized, the extractor calls
/// `Normalizable::normalize` on it and then `.validate()` from the `Validate`
/// trait. If deserialization or validation fails, the handler is not invoked and
/// a 400 Bad Request with an `invalid_request_error` JSON payload is returned.
///
/// The wrapper derives the common value traits so it can be logged, cloned, and
/// defaulted whenever the inner type allows it.
///
/// # Example
///
/// ```rust,ignore
/// async fn create_chat(
///     ValidatedJson(request): ValidatedJson<ChatCompletionRequest>,
/// ) -> Response {
///     // request is guaranteed to be valid here
///     process_request(request).await
/// }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
impl<S, T> FromRequest<S> for ValidatedJson<T>
where
T: DeserializeOwned + Validate + Normalizable + Send,
S: Send + Sync,
{
type Rejection = Response;
async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
// First, extract and deserialize the JSON
let Json(mut data) =
Json::<T>::from_request(req, state)
.await
.map_err(|err: JsonRejection| {
let error_message = match err {
JsonRejection::JsonDataError(e) => {
format!("Invalid JSON data: {}", e)
}
JsonRejection::JsonSyntaxError(e) => {
format!("JSON syntax error: {}", e)
}
JsonRejection::MissingJsonContentType(_) => {
"Missing Content-Type: application/json header".to_string()
}
_ => format!("Failed to parse JSON: {}", err),
};
(
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": error_message,
"type": "invalid_request_error",
"code": "json_parse_error"
}
})),
)
.into_response()
})?;
// Normalize the request (apply defaults based on other fields)
data.normalize();
// Then, automatically validate the data
data.validate().map_err(|validation_errors| {
// Extract the first error message from the validation errors
let error_message = validation_errors
.field_errors()
.values()
.flat_map(|errors| errors.iter())
.find_map(|e| e.message.as_ref())
.map(|m| m.to_string())
.unwrap_or_else(|| "Validation failed".to_string());
(
StatusCode::BAD_REQUEST,
Json(json!({
"error": {
"message": error_message,
"type": "invalid_request_error",
"code": 400
}
})),
)
.into_response()
})?;
Ok(ValidatedJson(data))
}
}
// Implement Deref to allow transparent access to the inner value
impl<T> std::ops::Deref for ValidatedJson<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> std::ops::DerefMut for ValidatedJson<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde::{Deserialize, Serialize};
    use validator::Validate;

    /// Minimal request type carrying one range rule and one length rule.
    #[derive(Debug, Deserialize, Serialize, Validate)]
    struct TestRequest {
        #[validate(range(min = 0.0, max = 1.0))]
        value: f32,
        #[validate(length(min = 1))]
        name: String,
    }

    impl Normalizable for TestRequest {
        // Use default no-op implementation
    }

    /// Builds a `TestRequest` from the two fields under test.
    fn request(value: f32, name: &str) -> TestRequest {
        TestRequest {
            value,
            name: name.to_string(),
        }
    }

    // These tests only exercise the `Validate` rules on `TestRequest`; driving the
    // extractor itself would require an Axum test harness. Nothing here awaits, so
    // plain `#[test]` is used instead of starting an async runtime per test.

    #[test]
    fn test_validated_json_valid() {
        assert!(request(0.5, "test").validate().is_ok());
    }

    #[test]
    fn test_validated_json_invalid_range() {
        // 1.5 is above the allowed maximum of 1.0.
        assert!(request(1.5, "test").validate().is_err());
    }

    #[test]
    fn test_validated_json_invalid_length() {
        // An empty name violates the minimum length of 1.
        assert!(request(0.5, "").validate().is_err());
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -959,7 +959,6 @@ mod tests {
#[test]
fn test_transform_messages_string_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Hello".to_string(),
@@ -993,7 +992,6 @@ mod tests {
#[test]
fn test_transform_messages_openai_format() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "Describe this image:".to_string(),
@@ -1028,7 +1026,6 @@ mod tests {
#[test]
fn test_transform_messages_simple_string_content() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Simple text message".to_string()),
name: None,
}];
@@ -1049,12 +1046,10 @@ mod tests {
fn test_transform_messages_multiple_messages() {
let messages = vec![
ChatMessage::System {
role: "system".to_string(),
content: "System prompt".to_string(),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "User message".to_string(),
@@ -1086,7 +1081,6 @@ mod tests {
#[test]
fn test_transform_messages_empty_text_parts() {
let messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![ContentPart::ImageUrl {
image_url: ImageUrl {
url: "https://example.com/image.jpg".to_string(),
@@ -1109,12 +1103,10 @@ mod tests {
fn test_transform_messages_mixed_content_types() {
let messages = vec![
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Plain text".to_string()),
name: None,
},
ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Parts(vec![
ContentPart::Text {
text: "With image".to_string(),

View File

@@ -16,6 +16,7 @@ use crate::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest,
RerankRequest, ResponsesGetParams, ResponsesRequest, V1RerankReqInput,
},
validated::ValidatedJson,
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
},
reasoning_parser::ParserFactory as ReasoningParserFactory,
@@ -291,7 +292,7 @@ async fn generate(
async fn v1_chat_completions(
State(state): State<Arc<AppState>>,
headers: http::HeaderMap,
Json(body): Json<ChatCompletionRequest>,
ValidatedJson(body): ValidatedJson<ChatCompletionRequest>,
) -> Response {
state.router.route_chat(Some(&headers), &body, None).await
}