//! OpenAI router implementation
//!
//! Proxies chat-completion and Responses API traffic to an OpenAI-compatible
//! upstream, guards outbound calls with a circuit breaker, and stores
//! responses locally so `previous_response_id` chaining can be resolved
//! without relying on upstream storage.
use crate::config::CircuitBreakerConfig;
use crate::core::{CircuitBreaker, CircuitBreakerConfig as CoreCircuitBreakerConfig};
use crate::data_connector::{ResponseId, SharedResponseStorage, StoredResponse};
use crate::protocols::spec::{
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
ResponseContentPart, ResponseInput, ResponseInputOutputItem, ResponseOutputItem,
ResponseStatus, ResponseTextFormat, ResponsesGetParams, ResponsesRequest, ResponsesResponse,
TextFormatType,
};
use crate::routers::header_utils::{apply_request_headers, preserve_response_headers};
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode},
response::{IntoResponse, Response},
};
use futures_util::StreamExt;
use serde_json::{json, to_value, Value};
use std::{
any::Any,
collections::HashMap,
sync::atomic::{AtomicBool, Ordering},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{error, info, warn};
/// Router for OpenAI backend
pub struct OpenAIRouter {
/// HTTP client for upstream OpenAI-compatible API
client: reqwest::Client,
/// Base URL for identification (no trailing slash)
base_url: String,
/// Circuit breaker
circuit_breaker: CircuitBreaker,
/// Health status
healthy: AtomicBool,
/// Response storage for managing conversation history
response_storage: SharedResponseStorage,
}
impl std::fmt::Debug for OpenAIRouter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OpenAIRouter")
.field("base_url", &self.base_url)
.field("healthy", &self.healthy)
.finish()
}
}
impl OpenAIRouter {
/// Create a new OpenAI router
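    ///
    /// A minimal construction sketch (illustrative only; `storage` stands in
    /// for however a `SharedResponseStorage` handle is obtained in this crate):
    ///
    /// ```ignore
    /// let router = OpenAIRouter::new(
    ///     "https://api.openai.com".to_string(),
    ///     None, // default circuit-breaker settings
    ///     storage,
    /// )
    /// .await?;
    /// ```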
pub async fn new(
base_url: String,
circuit_breaker_config: Option<CircuitBreakerConfig>,
response_storage: SharedResponseStorage,
) -> Result<Self, String> {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let base_url = base_url.trim_end_matches('/').to_string();
// Convert circuit breaker config
let core_cb_config = circuit_breaker_config
.map(|cb| CoreCircuitBreakerConfig {
failure_threshold: cb.failure_threshold,
success_threshold: cb.success_threshold,
timeout_duration: std::time::Duration::from_secs(cb.timeout_duration_secs),
window_duration: std::time::Duration::from_secs(cb.window_duration_secs),
})
.unwrap_or_default();
let circuit_breaker = CircuitBreaker::with_config(core_cb_config);
Ok(Self {
client,
base_url,
circuit_breaker,
healthy: AtomicBool::new(true),
response_storage,
})
}
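    /// Forward a non-streaming Responses request upstream, then patch the
    /// returned JSON (re-attaching `previous_response_id`, `instructions`,
    /// `metadata`, and the client's `store` preference) and optionally
    /// persist it locally.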
async fn handle_non_streaming_response(
&self,
url: String,
headers: Option<&HeaderMap>,
payload: Value,
original_body: &ResponsesRequest,
original_previous_response_id: Option<String>,
) -> Response {
let request_builder = self.client.post(&url).json(&payload);
// Apply headers with filtering
let request_builder = if let Some(headers) = headers {
apply_request_headers(headers, request_builder, true)
} else {
request_builder
};
match request_builder.send().await {
Ok(response) => {
let status = response.status();
if !status.is_success() {
let error_text = response
.text()
.await
.unwrap_or_else(|e| format!("Failed to get error body: {}", e));
return (status, error_text).into_response();
}
// Parse the response
match response.json::<Value>().await {
Ok(mut openai_response_json) => {
if let Some(prev_id) = original_previous_response_id {
if let Some(obj) = openai_response_json.as_object_mut() {
let should_insert = obj
.get("previous_response_id")
.map(|v| v.is_null())
.unwrap_or(true);
if should_insert {
obj.insert(
"previous_response_id".to_string(),
Value::String(prev_id),
);
}
}
}
if let Some(obj) = openai_response_json.as_object_mut() {
if !obj.contains_key("instructions") {
if let Some(instructions) = &original_body.instructions {
obj.insert(
"instructions".to_string(),
Value::String(instructions.clone()),
);
}
}
if !obj.contains_key("metadata") {
if let Some(metadata) = &original_body.metadata {
let metadata_map: serde_json::Map<String, Value> = metadata
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect();
obj.insert("metadata".to_string(), Value::Object(metadata_map));
}
}
// Reflect the client's requested store preference in the response body
obj.insert("store".to_string(), Value::Bool(original_body.store));
}
if original_body.store {
if let Err(e) = self
.store_response_internal(&openai_response_json, original_body)
.await
{
warn!("Failed to store response: {}", e);
}
}
match serde_json::to_string(&openai_response_json) {
Ok(json_str) => (
StatusCode::OK,
[("content-type", "application/json")],
json_str,
)
.into_response(),
Err(e) => {
error!("Failed to serialize response: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
json!({"error": {"message": "Failed to serialize response", "type": "internal_error"}}).to_string(),
)
.into_response()
}
}
}
Err(e) => {
error!("Failed to parse OpenAI response: {}", e);
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to parse response: {}", e),
)
.into_response()
}
}
}
Err(e) => (
StatusCode::BAD_GATEWAY,
format!("Failed to forward request to OpenAI: {}", e),
)
.into_response(),
}
}
async fn handle_streaming_response(
&self,
_url: String,
_headers: Option<&HeaderMap>,
_payload: Value,
_original_body: &ResponsesRequest,
_original_previous_response_id: Option<String>,
) -> Response {
(
StatusCode::NOT_IMPLEMENTED,
"Streaming responses not yet implemented",
)
.into_response()
}
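    /// Persist a completed response locally; a no-op unless the client
    /// requested `store=true`.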
async fn store_response_internal(
&self,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<(), String> {
if !original_body.store {
return Ok(());
}
match Self::store_response_impl(&self.response_storage, response_json, original_body).await
{
Ok(response_id) => {
info!(response_id = %response_id.0, "Stored response locally");
Ok(())
}
Err(e) => Err(e),
}
}
async fn store_response_impl(
response_storage: &SharedResponseStorage,
response_json: &Value,
original_body: &ResponsesRequest,
) -> Result<ResponseId, String> {
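        // Prefer fields echoed back by the upstream response; fall back to the
        // values from the original request when the upstream omits them.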
let input_text = match &original_body.input {
ResponseInput::Text(text) => text.clone(),
ResponseInput::Items(_) => "complex input".to_string(),
};
let output_text = Self::extract_primary_output_text(response_json).unwrap_or_default();
let mut stored_response = StoredResponse::new(input_text, output_text, None);
stored_response.instructions = response_json
.get("instructions")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.instructions.clone());
stored_response.model = response_json
.get("model")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.model.clone());
stored_response.user = response_json
.get("user")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
.or_else(|| original_body.user.clone());
stored_response.metadata = response_json
.get("metadata")
.and_then(|v| v.as_object())
.map(|m| {
m.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<HashMap<_, _>>()
})
.unwrap_or_else(|| original_body.metadata.clone().unwrap_or_default());
stored_response.previous_response_id = response_json
.get("previous_response_id")
.and_then(|v| v.as_str())
.map(|s| ResponseId::from_string(s.to_string()))
.or_else(|| {
original_body
.previous_response_id
.as_ref()
.map(|id| ResponseId::from_string(id.clone()))
});
if let Some(id_str) = response_json.get("id").and_then(|v| v.as_str()) {
stored_response.id = ResponseId::from_string(id_str.to_string());
}
stored_response.raw_response = response_json.clone();
response_storage
.store_response(stored_response)
.await
.map_err(|e| format!("Failed to store response: {}", e))
}
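    /// Extract the first `output_text` part from a Responses API payload.
    ///
    /// Illustrative sketch of the shape this walks (abridged):
    ///
    /// ```ignore
    /// let json = serde_json::json!({
    ///     "output": [
    ///         { "content": [{ "type": "output_text", "text": "hello" }] }
    ///     ]
    /// });
    /// assert_eq!(
    ///     OpenAIRouter::extract_primary_output_text(&json).as_deref(),
    ///     Some("hello"),
    /// );
    /// ```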
fn extract_primary_output_text(response_json: &Value) -> Option<String> {
if let Some(items) = response_json.get("output").and_then(|v| v.as_array()) {
for item in items {
if let Some(content) = item.get("content").and_then(|v| v.as_array()) {
for part in content {
if part
.get("type")
.and_then(|v| v.as_str())
.map(|t| t == "output_text")
.unwrap_or(false)
{
if let Some(text) = part.get("text").and_then(|v| v.as_str()) {
return Some(text.to_string());
}
}
}
}
}
}
None
}
}
#[async_trait]
impl super::super::RouterTrait for OpenAIRouter {
fn as_any(&self) -> &dyn Any {
self
}
async fn health(&self, _req: Request<Body>) -> Response {
// Simple upstream probe: GET {base}/v1/models without auth
let url = format!("{}/v1/models", self.base_url);
match self
.client
.get(&url)
.timeout(std::time::Duration::from_secs(2))
.send()
.await
{
Ok(resp) => {
let code = resp.status();
// Treat success and auth-required as healthy (endpoint reachable)
if code.is_success() || code.as_u16() == 401 || code.as_u16() == 403 {
(StatusCode::OK, "OK").into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
format!("Upstream status: {}", code),
)
.into_response()
}
}
Err(e) => (
StatusCode::SERVICE_UNAVAILABLE,
format!("Upstream error: {}", e),
)
.into_response(),
}
}
async fn health_generate(&self, _req: Request<Body>) -> Response {
// For OpenAI, health_generate is the same as health
self.health(_req).await
}
async fn get_server_info(&self, _req: Request<Body>) -> Response {
let info = json!({
"router_type": "openai",
"workers": 1,
"base_url": &self.base_url
});
(StatusCode::OK, info.to_string()).into_response()
}
async fn get_models(&self, req: Request<Body>) -> Response {
// Proxy to upstream /v1/models; forward Authorization header if provided
let headers = req.headers();
let mut upstream = self.client.get(format!("{}/v1/models", self.base_url));
        // `HeaderMap` lookups are case-insensitive, so a single lookup suffices
        if let Some(auth) = headers.get("authorization") {
            upstream = upstream.header("Authorization", auth);
        }
match upstream.send().await {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
let content_type = res.headers().get(CONTENT_TYPE).cloned();
match res.bytes().await {
Ok(body) => {
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read upstream response: {}", e),
)
.into_response(),
}
}
Err(e) => (
StatusCode::BAD_GATEWAY,
format!("Failed to contact upstream: {}", e),
)
.into_response(),
}
}
async fn get_model_info(&self, _req: Request<Body>) -> Response {
// Not directly supported without model param; return 501
(
StatusCode::NOT_IMPLEMENTED,
"get_model_info not implemented for OpenAI router",
)
.into_response()
}
async fn route_generate(
&self,
_headers: Option<&HeaderMap>,
_body: &GenerateRequest,
_model_id: Option<&str>,
) -> Response {
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
(
StatusCode::NOT_IMPLEMENTED,
"Generate endpoint not supported for OpenAI backend",
)
.into_response()
}
async fn route_chat(
&self,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
_model_id: Option<&str>,
) -> Response {
if !self.circuit_breaker.can_execute() {
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
}
// Serialize request body, removing SGLang-only fields
let mut payload = match to_value(body) {
Ok(v) => v,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
format!("Failed to serialize request: {}", e),
)
.into_response();
}
};
if let Some(obj) = payload.as_object_mut() {
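            // SGLang-specific extensions; strip them so a strict OpenAI
            // upstream does not reject the request for unknown fields.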
for key in [
"top_k",
"min_p",
"min_tokens",
"regex",
"ebnf",
"stop_token_ids",
"no_stop_trim",
"ignore_eos",
"continue_final_message",
"skip_special_tokens",
"lora_path",
"session_params",
"separate_reasoning",
"stream_reasoning",
"chat_template_kwargs",
"return_hidden_states",
"repetition_penalty",
"sampling_seed",
] {
obj.remove(key);
}
}
let url = format!("{}/v1/chat/completions", self.base_url);
let mut req = self.client.post(&url).json(&payload);
        // Forward Authorization header if provided (`HeaderMap` lookups are case-insensitive)
        if let Some(h) = headers {
            if let Some(auth) = h.get("authorization") {
                req = req.header("Authorization", auth);
            }
        }
// Accept SSE when stream=true
if body.stream {
req = req.header("Accept", "text/event-stream");
}
let resp = match req.send().await {
Ok(r) => r,
Err(e) => {
self.circuit_breaker.record_failure();
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("Failed to contact upstream: {}", e),
)
.into_response();
}
};
let status = StatusCode::from_u16(resp.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
if !body.stream {
// Capture Content-Type before consuming response body
let content_type = resp.headers().get(CONTENT_TYPE).cloned();
match resp.bytes().await {
Ok(body) => {
self.circuit_breaker.record_success();
let mut response = Response::new(axum::body::Body::from(body));
*response.status_mut() = status;
if let Some(ct) = content_type {
response.headers_mut().insert(CONTENT_TYPE, ct);
}
response
}
Err(e) => {
self.circuit_breaker.record_failure();
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response: {}", e),
)
.into_response()
}
}
} else {
// Stream SSE bytes to client
let stream = resp.bytes_stream();
let (tx, rx) = mpsc::unbounded_channel();
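            // Relay task: forward upstream SSE chunks to the client. If the
            // client disconnects, the receiver is dropped, send() fails, and
            // the task exits, dropping the upstream stream with it.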
tokio::spawn(async move {
let mut s = stream;
while let Some(chunk) = s.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let mut response = Response::new(Body::from_stream(UnboundedReceiverStream::new(rx)));
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
}
async fn route_completion(
&self,
_headers: Option<&HeaderMap>,
_body: &CompletionRequest,
_model_id: Option<&str>,
) -> Response {
// Completion endpoint not implemented for OpenAI backend
(
StatusCode::NOT_IMPLEMENTED,
"Completion endpoint not implemented for OpenAI backend",
)
.into_response()
}
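    /// Route a Responses API request.
    ///
    /// When `previous_response_id` is set, the locally stored chain is
    /// flattened into explicit input items and the id is cleared before
    /// forwarding. An illustrative client-side sketch (the id is hypothetical):
    ///
    /// ```ignore
    /// let follow_up = ResponsesRequest {
    ///     previous_response_id: Some("resp_abc123".to_string()),
    ///     ..first_request
    /// };
    /// ```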
async fn route_responses(
&self,
headers: Option<&HeaderMap>,
body: &ResponsesRequest,
model_id: Option<&str>,
) -> Response {
let url = format!("{}/v1/responses", self.base_url);
info!(
requested_store = body.store,
is_streaming = body.stream,
"openai_responses_request"
);
// Clone the body and override model if needed
let mut request_body = body.clone();
if let Some(model) = model_id {
request_body.model = Some(model.to_string());
}
// Store the original previous_response_id for the response
let original_previous_response_id = request_body.previous_response_id.clone();
// Handle previous_response_id by loading prior context
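        // Each stored turn is re-expanded into an explicit user/assistant
        // message pair (ids "msg_u_<id>" / "msg_a_<id>") so the upstream sees
        // the full conversation without needing its own stored state.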
let mut conversation_items: Option<Vec<ResponseInputOutputItem>> = None;
if let Some(prev_id_str) = request_body.previous_response_id.clone() {
let prev_id = ResponseId::from_string(prev_id_str.clone());
match self
.response_storage
.get_response_chain(&prev_id, None)
.await
{
Ok(chain) => {
if !chain.responses.is_empty() {
let mut items = Vec::new();
for stored in chain.responses.iter() {
let trimmed_id = stored.id.0.trim_start_matches("resp_");
if !stored.input.is_empty() {
items.push(ResponseInputOutputItem::Message {
id: format!("msg_u_{}", trimmed_id),
role: "user".to_string(),
status: Some("completed".to_string()),
content: vec![ResponseContentPart::InputText {
text: stored.input.clone(),
}],
});
}
if !stored.output.is_empty() {
items.push(ResponseInputOutputItem::Message {
id: format!("msg_a_{}", trimmed_id),
role: "assistant".to_string(),
status: Some("completed".to_string()),
content: vec![ResponseContentPart::OutputText {
text: stored.output.clone(),
annotations: vec![],
logprobs: None,
}],
});
}
}
conversation_items = Some(items);
} else {
info!(previous_response_id = %prev_id_str, "previous chain empty");
}
}
Err(err) => {
warn!(previous_response_id = %prev_id_str, %err, "failed to fetch previous response chain");
}
}
// Clear previous_response_id from request since we're converting to conversation
request_body.previous_response_id = None;
}
if let Some(mut items) = conversation_items {
match &request_body.input {
ResponseInput::Text(text) => {
items.push(ResponseInputOutputItem::Message {
id: format!("msg_u_current_{}", items.len()),
role: "user".to_string(),
status: Some("completed".to_string()),
content: vec![ResponseContentPart::InputText { text: text.clone() }],
});
}
ResponseInput::Items(existing) => {
items.extend(existing.clone());
}
}
request_body.input = ResponseInput::Items(items);
}
// Always set store=false for OpenAI (we store internally)
request_body.store = false;
// Convert to JSON payload and strip SGLang-specific fields before forwarding
let mut payload = match to_value(&request_body) {
Ok(value) => value,
Err(err) => {
return (
StatusCode::BAD_REQUEST,
format!("Failed to serialize responses request: {}", err),
)
.into_response();
}
};
if let Some(obj) = payload.as_object_mut() {
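            // SGLang-only knobs that the upstream Responses API does not accept.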
for key in [
"request_id",
"priority",
"frequency_penalty",
"presence_penalty",
"stop",
"top_k",
"min_p",
"repetition_penalty",
] {
obj.remove(key);
}
}
// Check if streaming is requested
if body.stream {
// Handle streaming response
self.handle_streaming_response(
url,
headers,
payload,
body,
original_previous_response_id,
)
.await
} else {
// Handle non-streaming response
self.handle_non_streaming_response(
url,
headers,
payload,
body,
original_previous_response_id,
)
.await
}
}
async fn get_response(
&self,
_headers: Option<&HeaderMap>,
response_id: &str,
params: &ResponsesGetParams,
) -> Response {
let stored_id = ResponseId::from_string(response_id.to_string());
if let Ok(Some(stored_response)) = self.response_storage.get_response(&stored_id).await {
let stream_requested = params.stream.unwrap_or(false);
let raw_value = stored_response.raw_response.clone();
if !raw_value.is_null() {
if stream_requested {
return (
StatusCode::NOT_IMPLEMENTED,
"Streaming retrieval not yet implemented",
)
.into_response();
}
return (
StatusCode::OK,
[("content-type", "application/json")],
raw_value.to_string(),
)
.into_response();
}
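            // No raw payload was stored; synthesize a ResponsesResponse from
            // the stored fields. The literals below (model "gpt-4o",
            // temperature 1.0, top_p 1.0, ...) are fallback defaults, not
            // values recovered from the original request.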
let openai_response = ResponsesResponse {
id: stored_response.id.0.clone(),
object: "response".to_string(),
created_at: stored_response.created_at.timestamp(),
status: ResponseStatus::Completed,
error: None,
incomplete_details: None,
instructions: stored_response.instructions.clone(),
max_output_tokens: None,
model: stored_response
.model
.unwrap_or_else(|| "gpt-4o".to_string()),
output: vec![ResponseOutputItem::Message {
id: format!("msg_{}", stored_response.id.0),
role: "assistant".to_string(),
status: "completed".to_string(),
content: vec![ResponseContentPart::OutputText {
text: stored_response.output,
annotations: vec![],
logprobs: None,
}],
}],
parallel_tool_calls: true,
previous_response_id: stored_response.previous_response_id.map(|id| id.0),
reasoning: None,
store: true,
temperature: Some(1.0),
text: Some(ResponseTextFormat {
format: TextFormatType {
format_type: "text".to_string(),
},
}),
tool_choice: "auto".to_string(),
tools: vec![],
top_p: Some(1.0),
truncation: Some("disabled".to_string()),
usage: None,
user: stored_response.user.clone(),
metadata: stored_response.metadata.clone(),
};
if stream_requested {
return (
StatusCode::NOT_IMPLEMENTED,
"Streaming retrieval not yet implemented",
)
.into_response();
}
return (
StatusCode::OK,
[("content-type", "application/json")],
serde_json::to_string(&openai_response).unwrap_or_else(|e| {
format!("{{\"error\": \"Failed to serialize response: {}\"}}", e)
}),
)
.into_response();
}
(
StatusCode::NOT_FOUND,
format!(
"Response with id '{}' not found in local storage",
response_id
),
)
.into_response()
}
async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
// Forward to OpenAI's cancel endpoint
let url = format!("{}/v1/responses/{}/cancel", self.base_url, response_id);
let request_builder = self.client.post(&url);
// Apply headers with filtering (skip content headers for POST without body)
let request_builder = if let Some(headers) = headers {
apply_request_headers(headers, request_builder, true)
} else {
request_builder
};
match request_builder.send().await {
Ok(response) => {
let status = response.status();
let headers = response.headers().clone();
match response.text().await {
Ok(body_text) => {
let mut response = (status, body_text).into_response();
*response.headers_mut() = preserve_response_headers(&headers);
response
}
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to read response body: {}", e),
)
.into_response(),
}
}
Err(e) => (
StatusCode::BAD_GATEWAY,
format!("Failed to cancel response on OpenAI: {}", e),
)
.into_response(),
}
}
async fn flush_cache(&self) -> Response {
(
StatusCode::FORBIDDEN,
"flush_cache not supported for OpenAI router",
)
.into_response()
}
async fn get_worker_loads(&self) -> Response {
(
StatusCode::FORBIDDEN,
"get_worker_loads not supported for OpenAI router",
)
.into_response()
}
fn router_type(&self) -> &'static str {
"openai"
}
fn readiness(&self) -> Response {
if self.healthy.load(Ordering::Acquire) && self.circuit_breaker.can_execute() {
(StatusCode::OK, "Ready").into_response()
} else {
(StatusCode::SERVICE_UNAVAILABLE, "Not ready").into_response()
}
}
async fn route_embeddings(
&self,
_headers: Option<&HeaderMap>,
_body: &EmbeddingRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::FORBIDDEN,
"Embeddings endpoint not supported for OpenAI backend",
)
.into_response()
}
async fn route_rerank(
&self,
_headers: Option<&HeaderMap>,
_body: &RerankRequest,
_model_id: Option<&str>,
) -> Response {
(
StatusCode::FORBIDDEN,
"Rerank endpoint not supported for OpenAI backend",
)
.into_response()
}
}