[router]restructure protocol modules for better organization (#9321)
This commit is contained in:
@@ -3,9 +3,13 @@ use serde_json::{from_str, to_string, to_value, to_vec};
|
|||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
||||||
use sglang_router_rs::openai_api_types::{
|
use sglang_router_rs::protocols::{
|
||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
common::StringOrArray,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
generate::{GenerateParameters, GenerateRequest, SamplingParams},
|
||||||
|
openai::{
|
||||||
|
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
||||||
|
completions::CompletionRequest,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
|
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,8 @@ use std::collections::HashMap;
|
|||||||
pub mod core;
|
pub mod core;
|
||||||
pub mod metrics;
|
pub mod metrics;
|
||||||
pub mod middleware;
|
pub mod middleware;
|
||||||
pub mod openai_api_types;
|
|
||||||
pub mod policies;
|
pub mod policies;
|
||||||
|
pub mod protocols;
|
||||||
pub mod reasoning_parser;
|
pub mod reasoning_parser;
|
||||||
pub mod routers;
|
pub mod routers;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
|
|||||||
@@ -1,921 +0,0 @@
|
|||||||
// OpenAI-compatible API types for text generation
|
|
||||||
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
|
|
||||||
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
/// Serde `default` hook for boolean fields that must read as `true` when the key is absent.
fn default_true() -> bool {
    true
}
|
|
||||||
|
|
||||||
// ============= SGLang-Specific Types =============
|
|
||||||
|
|
||||||
/// LoRA adapter path - can be single path or batch of paths
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum LoRAPath {
|
|
||||||
Single(Option<String>),
|
|
||||||
Batch(Vec<Option<String>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Behaviour shared by every generation-style request type.
pub trait GenerationRequest: Send + Sync {
    /// Returns `true` when the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// Returns the model name carried by the request, or `None` when there is none.
    fn get_model(&self) -> Option<&str>;

    /// Produces the text content used for routing decisions.
    fn extract_text_for_routing(&self) -> String;
}
|
|
||||||
|
|
||||||
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionRequest {
|
|
||||||
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
|
|
||||||
pub model: String,
|
|
||||||
|
|
||||||
/// The prompt(s) to generate completions for
|
|
||||||
pub prompt: StringOrArray,
|
|
||||||
|
|
||||||
/// The suffix that comes after a completion of inserted text
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub suffix: Option<String>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// What sampling temperature to use, between 0 and 2
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// An alternative to sampling with temperature (nucleus sampling)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// How many completions to generate for each prompt
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub n: Option<u32>,
|
|
||||||
|
|
||||||
/// Whether to stream back partial progress
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Options for streaming response
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
|
|
||||||
/// Include the log probabilities on the logprobs most likely tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<u32>,
|
|
||||||
|
|
||||||
/// Echo back the prompt in addition to the completion
|
|
||||||
#[serde(default)]
|
|
||||||
pub echo: bool,
|
|
||||||
|
|
||||||
/// Up to 4 sequences where the API will stop generating further tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Generates best_of completions server-side and returns the "best"
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub best_of: Option<u32>,
|
|
||||||
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logit_bias: Option<HashMap<String, f32>>,
|
|
||||||
|
|
||||||
/// A unique identifier representing your end-user
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub user: Option<String>,
|
|
||||||
|
|
||||||
/// If specified, our system will make a best effort to sample deterministically
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub seed: Option<i64>,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Top-k sampling parameter (-1 to disable)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<i32>,
|
|
||||||
|
|
||||||
/// Min-p nucleus sampling parameter
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_p: Option<f32>,
|
|
||||||
|
|
||||||
/// Minimum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Repetition penalty for reducing repetitive text
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Regex constraint for output generation
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub regex: Option<String>,
|
|
||||||
|
|
||||||
/// EBNF grammar constraint for structured output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ebnf: Option<String>,
|
|
||||||
|
|
||||||
/// JSON schema constraint for structured output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub json_schema: Option<String>,
|
|
||||||
|
|
||||||
/// Specific token IDs to use as stop conditions
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
|
||||||
|
|
||||||
/// Skip trimming stop tokens from output
|
|
||||||
#[serde(default)]
|
|
||||||
pub no_stop_trim: bool,
|
|
||||||
|
|
||||||
/// Ignore end-of-sequence tokens during generation
|
|
||||||
#[serde(default)]
|
|
||||||
pub ignore_eos: bool,
|
|
||||||
|
|
||||||
/// Skip special tokens during detokenization
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub skip_special_tokens: bool,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub lora_path: Option<LoRAPath>,
|
|
||||||
|
|
||||||
/// Session parameters for continual prompting
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Return model hidden states
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_hidden_states: bool,
|
|
||||||
|
|
||||||
/// Additional fields including bootstrap info for PD routing
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub other: serde_json::Map<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for CompletionRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
Some(&self.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
match &self.prompt {
|
|
||||||
StringOrArray::String(s) => s.clone(),
|
|
||||||
StringOrArray::Array(v) => v.join(" "),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Chat Completions API (v1/chat/completions) =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
/// ID of the model to use
|
|
||||||
pub model: String,
|
|
||||||
|
|
||||||
/// A list of messages comprising the conversation so far
|
|
||||||
pub messages: Vec<ChatMessage>,
|
|
||||||
|
|
||||||
/// What sampling temperature to use, between 0 and 2
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// An alternative to sampling with temperature
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// How many chat completion choices to generate for each input message
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub n: Option<u32>,
|
|
||||||
|
|
||||||
/// If set, partial message deltas will be sent
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Options for streaming response
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
|
|
||||||
/// Up to 4 sequences where the API will stop generating further tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// An upper bound for the number of tokens that can be generated for a completion
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_completion_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logit_bias: Option<HashMap<String, f32>>,
|
|
||||||
|
|
||||||
/// A unique identifier representing your end-user
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub user: Option<String>,
|
|
||||||
|
|
||||||
/// If specified, our system will make a best effort to sample deterministically
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub seed: Option<i64>,
|
|
||||||
|
|
||||||
/// Whether to return log probabilities of the output tokens
|
|
||||||
#[serde(default)]
|
|
||||||
pub logprobs: bool,
|
|
||||||
|
|
||||||
/// An integer between 0 and 20 specifying the number of most likely tokens to return
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_logprobs: Option<u32>,
|
|
||||||
|
|
||||||
/// An object specifying the format that the model must output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub response_format: Option<ResponseFormat>,
|
|
||||||
|
|
||||||
/// A list of tools the model may call
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tools: Option<Vec<Tool>>,
|
|
||||||
|
|
||||||
/// Controls which (if any) tool is called by the model
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_choice: Option<ToolChoice>,
|
|
||||||
|
|
||||||
/// Whether to enable parallel function calling during tool use
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub parallel_tool_calls: Option<bool>,
|
|
||||||
|
|
||||||
/// Deprecated: use tools instead
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub functions: Option<Vec<Function>>,
|
|
||||||
|
|
||||||
/// Deprecated: use tool_choice instead
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function_call: Option<FunctionCall>,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Top-k sampling parameter (-1 to disable)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<i32>,
|
|
||||||
|
|
||||||
/// Min-p nucleus sampling parameter
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_p: Option<f32>,
|
|
||||||
|
|
||||||
/// Minimum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Repetition penalty for reducing repetitive text
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Regex constraint for output generation
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub regex: Option<String>,
|
|
||||||
|
|
||||||
/// EBNF grammar constraint for structured output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ebnf: Option<String>,
|
|
||||||
|
|
||||||
/// Specific token IDs to use as stop conditions
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
|
||||||
|
|
||||||
/// Skip trimming stop tokens from output
|
|
||||||
#[serde(default)]
|
|
||||||
pub no_stop_trim: bool,
|
|
||||||
|
|
||||||
/// Ignore end-of-sequence tokens during generation
|
|
||||||
#[serde(default)]
|
|
||||||
pub ignore_eos: bool,
|
|
||||||
|
|
||||||
/// Continue generating from final assistant message
|
|
||||||
#[serde(default)]
|
|
||||||
pub continue_final_message: bool,
|
|
||||||
|
|
||||||
/// Skip special tokens during detokenization
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub skip_special_tokens: bool,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub lora_path: Option<LoRAPath>,
|
|
||||||
|
|
||||||
/// Session parameters for continual prompting
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Separate reasoning content from final answer (O1-style models)
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub separate_reasoning: bool,
|
|
||||||
|
|
||||||
/// Stream reasoning tokens during generation
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub stream_reasoning: bool,
|
|
||||||
|
|
||||||
/// Return model hidden states
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_hidden_states: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ChatMessage {
|
|
||||||
System {
|
|
||||||
role: String, // "system"
|
|
||||||
content: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
name: Option<String>,
|
|
||||||
},
|
|
||||||
User {
|
|
||||||
role: String, // "user"
|
|
||||||
content: UserMessageContent,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
name: Option<String>,
|
|
||||||
},
|
|
||||||
Assistant {
|
|
||||||
role: String, // "assistant"
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
content: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
name: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
function_call: Option<FunctionCallResponse>,
|
|
||||||
/// Reasoning content for O1-style models (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
reasoning_content: Option<String>,
|
|
||||||
},
|
|
||||||
Tool {
|
|
||||||
role: String, // "tool"
|
|
||||||
content: String,
|
|
||||||
tool_call_id: String,
|
|
||||||
},
|
|
||||||
Function {
|
|
||||||
role: String, // "function"
|
|
||||||
content: String,
|
|
||||||
name: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum UserMessageContent {
|
|
||||||
Text(String),
|
|
||||||
Parts(Vec<ContentPart>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum ContentPart {
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text { text: String },
|
|
||||||
#[serde(rename = "image_url")]
|
|
||||||
ImageUrl { image_url: ImageUrl },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ImageUrl {
|
|
||||||
pub url: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub detail: Option<String>, // "auto", "low", or "high"
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct StreamOptions {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub include_usage: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum ResponseFormat {
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text,
|
|
||||||
#[serde(rename = "json_object")]
|
|
||||||
JsonObject,
|
|
||||||
#[serde(rename = "json_schema")]
|
|
||||||
JsonSchema { json_schema: JsonSchemaFormat },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct JsonSchemaFormat {
|
|
||||||
pub name: String,
|
|
||||||
pub schema: Value,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub strict: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Tool {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: String, // "function"
|
|
||||||
pub function: Function,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Function {
|
|
||||||
pub name: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub parameters: Value, // JSON Schema
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ToolChoice {
|
|
||||||
None,
|
|
||||||
Auto,
|
|
||||||
Required,
|
|
||||||
Function {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
tool_type: String, // "function"
|
|
||||||
function: FunctionChoice,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FunctionChoice {
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ToolCall {
|
|
||||||
pub id: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: String, // "function"
|
|
||||||
pub function: FunctionCallResponse,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum FunctionCall {
|
|
||||||
None,
|
|
||||||
Auto,
|
|
||||||
Function { name: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FunctionCallResponse {
|
|
||||||
pub name: String,
|
|
||||||
pub arguments: String, // JSON string
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for ChatCompletionRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
Some(&self.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
// Extract text from messages for routing decisions
|
|
||||||
self.messages
|
|
||||||
.iter()
|
|
||||||
.filter_map(|msg| match msg {
|
|
||||||
ChatMessage::System { content, .. } => Some(content.clone()),
|
|
||||||
ChatMessage::User { content, .. } => match content {
|
|
||||||
UserMessageContent::Text(text) => Some(text.clone()),
|
|
||||||
UserMessageContent::Parts(parts) => {
|
|
||||||
let texts: Vec<String> = parts
|
|
||||||
.iter()
|
|
||||||
.filter_map(|part| match part {
|
|
||||||
ContentPart::Text { text } => Some(text.clone()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Some(texts.join(" "))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ChatMessage::Assistant {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
// Combine content and reasoning content for routing decisions
|
|
||||||
let main_content = content.clone().unwrap_or_default();
|
|
||||||
let reasoning = reasoning_content.clone().unwrap_or_default();
|
|
||||||
if main_content.is_empty() && reasoning.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(format!("{} {}", main_content, reasoning).trim().to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
|
||||||
ChatMessage::Function { content, .. } => Some(content.clone()),
|
|
||||||
})
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Generate API (/generate) =============
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct GenerateRequest {
|
|
||||||
/// The prompt to generate from (OpenAI style)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub prompt: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// Text input - SGLang native format
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub text: Option<String>,
|
|
||||||
|
|
||||||
/// Input IDs for tokenized input
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub input_ids: Option<InputIds>,
|
|
||||||
|
|
||||||
/// Generation parameters
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub parameters: Option<GenerateParameters>,
|
|
||||||
|
|
||||||
/// Sampling parameters (sglang style)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub sampling_params: Option<SamplingParams>,
|
|
||||||
|
|
||||||
/// Whether to stream the response
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Whether to return logprobs
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_logprob: bool,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub lora_path: Option<LoRAPath>,
|
|
||||||
|
|
||||||
/// Session parameters for continual prompting
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Return model hidden states
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_hidden_states: bool,
|
|
||||||
|
|
||||||
/// Request ID for tracking
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub rid: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum InputIds {
|
|
||||||
Single(Vec<i32>),
|
|
||||||
Batch(Vec<Vec<i32>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
|
||||||
pub struct GenerateParameters {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub best_of: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub decoder_input_details: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub details: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub do_sample: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_new_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub return_full_text: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub seed: Option<u64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub truncate: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub typical_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub watermark: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
|
||||||
pub struct SamplingParams {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_new_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<i32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<StringOrArray>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ignore_eos: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub skip_special_tokens: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub json_schema: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub regex: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ebnf: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub no_stop_trim: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for GenerateRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
// Generate requests typically don't have a model field
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
// Check fields in priority order: text, prompt, inputs
|
|
||||||
if let Some(ref text) = self.text {
|
|
||||||
return text.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref prompt) = self.prompt {
|
|
||||||
return match prompt {
|
|
||||||
StringOrArray::String(s) => s.clone(),
|
|
||||||
StringOrArray::Array(v) => v.join(" "),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref input_ids) = self.input_ids {
|
|
||||||
return match input_ids {
|
|
||||||
InputIds::Single(ids) => ids
|
|
||||||
.iter()
|
|
||||||
.map(|&id| id.to_string())
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" "),
|
|
||||||
InputIds::Batch(batches) => batches
|
|
||||||
.iter()
|
|
||||||
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" "),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// No text input found
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Helper Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum StringOrArray {
|
|
||||||
String(String),
|
|
||||||
Array(Vec<String>),
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Response Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "text_completion"
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<CompletionChoice>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionChoice {
|
|
||||||
pub text: String,
|
|
||||||
pub index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<LogProbs>,
|
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
|
||||||
/// Information about which stop condition was matched
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
|
||||||
/// Hidden states from the model (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub hidden_states: Option<Vec<f32>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct LogProbs {
|
|
||||||
pub tokens: Vec<String>,
|
|
||||||
pub token_logprobs: Vec<Option<f32>>,
|
|
||||||
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
|
||||||
pub text_offset: Vec<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "chat.completion"
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChatChoice>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub message: ChatMessage,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
|
||||||
/// Information about which stop condition was matched
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
|
||||||
/// Hidden states from the model (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub hidden_states: Option<Vec<f32>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatLogProbs {
|
|
||||||
pub content: Option<Vec<ChatLogProbsContent>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatLogProbsContent {
|
|
||||||
pub token: String,
|
|
||||||
pub logprob: f32,
|
|
||||||
pub bytes: Option<Vec<u8>>,
|
|
||||||
pub top_logprobs: Vec<TopLogProb>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct TopLogProb {
|
|
||||||
pub token: String,
|
|
||||||
pub logprob: f32,
|
|
||||||
pub bytes: Option<Vec<u8>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionTokensDetails {
|
|
||||||
pub reasoning_tokens: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Streaming Response Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionStreamResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "text_completion"
|
|
||||||
pub created: u64,
|
|
||||||
pub choices: Vec<CompletionStreamChoice>,
|
|
||||||
pub model: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionStreamChoice {
|
|
||||||
pub text: String,
|
|
||||||
pub index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<LogProbs>,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionStreamResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "chat.completion.chunk"
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
pub choices: Vec<ChatStreamChoice>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatStreamChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: ChatMessageDelta,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatMessageDelta {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub role: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub content: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function_call: Option<FunctionCallDelta>,
|
|
||||||
/// Reasoning content delta for O1-style models (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub reasoning_content: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ToolCallDelta {
|
|
||||||
pub index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub id: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function: Option<FunctionCallDelta>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FunctionCallDelta {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub name: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub arguments: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Error Response Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ErrorResponse {
|
|
||||||
pub error: ErrorDetail,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ErrorDetail {
|
|
||||||
pub message: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub error_type: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub param: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub code: Option<String>,
|
|
||||||
}
|
|
||||||
36
sgl-router/src/protocols/common.rs
Normal file
36
sgl-router/src/protocols/common.rs
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// Common types shared across all protocol implementations
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// Returns `true`.
///
/// Exists so that serde field attributes can name it by path
/// (`#[serde(default = "default_true")]`) for boolean fields whose value
/// should be `true` when the field is absent from the input.
pub fn default_true() -> bool {
    true
}
|
||||||
|
|
||||||
|
/// Behaviour shared by every generation request type, regardless of the
/// API it arrived through.
///
/// Implementors must be `Send + Sync`.
pub trait GenerationRequest: Send + Sync {
    /// Reports whether the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// Returns the model name when the request specifies one.
    fn get_model(&self) -> Option<&str>;

    /// Produces the text that routing decisions are based on.
    fn extract_text_for_routing(&self) -> String;
}
|
||||||
|
|
||||||
|
/// Helper type for string or array of strings
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum StringOrArray {
|
||||||
|
String(String),
|
||||||
|
Array(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum LoRAPath {
|
||||||
|
Single(Option<String>),
|
||||||
|
Batch(Vec<Option<String>>),
|
||||||
|
}
|
||||||
8
sgl-router/src/protocols/generate/mod.rs
Normal file
8
sgl-router/src/protocols/generate/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
// SGLang native Generate API module (/generate)
|
||||||
|
|
||||||
|
pub mod request;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
// Re-export main types for convenience
|
||||||
|
pub use request::GenerateRequest;
|
||||||
|
pub use types::{GenerateParameters, InputIds, SamplingParams};
|
||||||
97
sgl-router/src/protocols/generate/request.rs
Normal file
97
sgl-router/src/protocols/generate/request.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
// Generate API request types (/generate)
|
||||||
|
|
||||||
|
use crate::protocols::common::{GenerationRequest, LoRAPath, StringOrArray};
|
||||||
|
use crate::protocols::generate::types::{GenerateParameters, InputIds, SamplingParams};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateRequest {
|
||||||
|
/// The prompt to generate from (OpenAI style)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt: Option<StringOrArray>,
|
||||||
|
|
||||||
|
/// Text input - SGLang native format
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub text: Option<String>,
|
||||||
|
|
||||||
|
/// Input IDs for tokenized input
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub input_ids: Option<InputIds>,
|
||||||
|
|
||||||
|
/// Generation parameters
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parameters: Option<GenerateParameters>,
|
||||||
|
|
||||||
|
/// Sampling parameters (sglang style)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sampling_params: Option<SamplingParams>,
|
||||||
|
|
||||||
|
/// Whether to stream the response
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Whether to return logprobs
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_logprob: bool,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Path to LoRA adapter(s) for model customization
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub lora_path: Option<LoRAPath>,
|
||||||
|
|
||||||
|
/// Session parameters for continual prompting
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||||
|
|
||||||
|
/// Return model hidden states
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_hidden_states: bool,
|
||||||
|
|
||||||
|
/// Request ID for tracking
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rid: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerationRequest for GenerateRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model(&self) -> Option<&str> {
|
||||||
|
// Generate requests typically don't have a model field
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
// Check fields in priority order: text, prompt, inputs
|
||||||
|
if let Some(ref text) = self.text {
|
||||||
|
return text.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref prompt) = self.prompt {
|
||||||
|
return match prompt {
|
||||||
|
StringOrArray::String(s) => s.clone(),
|
||||||
|
StringOrArray::Array(v) => v.join(" "),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref input_ids) = self.input_ids {
|
||||||
|
return match input_ids {
|
||||||
|
InputIds::Single(ids) => ids
|
||||||
|
.iter()
|
||||||
|
.map(|&id| id.to_string())
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join(" "),
|
||||||
|
InputIds::Batch(batches) => batches
|
||||||
|
.iter()
|
||||||
|
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join(" "),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// No text input found
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
82
sgl-router/src/protocols/generate/types.rs
Normal file
82
sgl-router/src/protocols/generate/types.rs
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
// Types for the SGLang native /generate API
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum InputIds {
|
||||||
|
Single(Vec<i32>),
|
||||||
|
Batch(Vec<Vec<i32>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||||
|
pub struct GenerateParameters {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub best_of: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub decoder_input_details: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub details: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub do_sample: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_new_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub return_full_text: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub truncate: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub typical_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub watermark: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||||
|
pub struct SamplingParams {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_new_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<crate::protocols::common::StringOrArray>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ignore_eos: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub skip_special_tokens: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub regex: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ebnf: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub no_stop_trim: Option<bool>,
|
||||||
|
}
|
||||||
6
sgl-router/src/protocols/mod.rs
Normal file
6
sgl-router/src/protocols/mod.rs
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
// Protocol definitions and validation for various LLM APIs
|
||||||
|
// This module provides a structured approach to handling different API protocols
|
||||||
|
|
||||||
|
pub mod common;
|
||||||
|
pub mod generate;
|
||||||
|
pub mod openai;
|
||||||
12
sgl-router/src/protocols/openai/chat/mod.rs
Normal file
12
sgl-router/src/protocols/openai/chat/mod.rs
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
// Chat Completions API module
|
||||||
|
|
||||||
|
pub mod request;
|
||||||
|
pub mod response;
|
||||||
|
pub mod types;
|
||||||
|
|
||||||
|
// Re-export main types for convenience
|
||||||
|
pub use request::ChatCompletionRequest;
|
||||||
|
pub use response::{
|
||||||
|
ChatChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice,
|
||||||
|
};
|
||||||
|
pub use types::*;
|
||||||
216
sgl-router/src/protocols/openai/chat/request.rs
Normal file
216
sgl-router/src/protocols/openai/chat/request.rs
Normal file
@@ -0,0 +1,216 @@
|
|||||||
|
// Chat Completions API request types
|
||||||
|
|
||||||
|
use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray};
|
||||||
|
use crate::protocols::openai::chat::types::*;
|
||||||
|
use crate::protocols::openai::common::StreamOptions;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
/// ID of the model to use
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// A list of messages comprising the conversation so far
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
|
||||||
|
/// What sampling temperature to use, between 0 and 2
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// An alternative to sampling with temperature
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// How many chat completion choices to generate for each input message
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
|
||||||
|
/// If set, partial message deltas will be sent
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Options for streaming response
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StringOrArray>,
|
||||||
|
|
||||||
|
/// The maximum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// An upper bound for the number of tokens that can be generated for a completion
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_completion_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Modify the likelihood of specified tokens appearing in the completion
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<HashMap<String, f32>>,
|
||||||
|
|
||||||
|
/// A unique identifier representing your end-user
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
|
||||||
|
/// If specified, our system will make a best effort to sample deterministically
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<i64>,
|
||||||
|
|
||||||
|
/// Whether to return log probabilities of the output tokens
|
||||||
|
#[serde(default)]
|
||||||
|
pub logprobs: bool,
|
||||||
|
|
||||||
|
/// An integer between 0 and 20 specifying the number of most likely tokens to return
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_logprobs: Option<u32>,
|
||||||
|
|
||||||
|
/// An object specifying the format that the model must output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
|
||||||
|
/// A list of tools the model may call
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
|
||||||
|
/// Controls which (if any) tool is called by the model
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<ToolChoice>,
|
||||||
|
|
||||||
|
/// Whether to enable parallel function calling during tool use
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
|
||||||
|
/// Deprecated: use tools instead
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub functions: Option<Vec<Function>>,
|
||||||
|
|
||||||
|
/// Deprecated: use tool_choice instead
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function_call: Option<FunctionCall>,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Top-k sampling parameter (-1 to disable)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
|
||||||
|
/// Min-p nucleus sampling parameter
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Minimum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Repetition penalty for reducing repetitive text
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Regex constraint for output generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub regex: Option<String>,
|
||||||
|
|
||||||
|
/// EBNF grammar constraint for structured output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ebnf: Option<String>,
|
||||||
|
|
||||||
|
/// Specific token IDs to use as stop conditions
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
|
|
||||||
|
/// Skip trimming stop tokens from output
|
||||||
|
#[serde(default)]
|
||||||
|
pub no_stop_trim: bool,
|
||||||
|
|
||||||
|
/// Ignore end-of-sequence tokens during generation
|
||||||
|
#[serde(default)]
|
||||||
|
pub ignore_eos: bool,
|
||||||
|
|
||||||
|
/// Continue generating from final assistant message
|
||||||
|
#[serde(default)]
|
||||||
|
pub continue_final_message: bool,
|
||||||
|
|
||||||
|
/// Skip special tokens during detokenization
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub skip_special_tokens: bool,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Path to LoRA adapter(s) for model customization
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub lora_path: Option<LoRAPath>,
|
||||||
|
|
||||||
|
/// Session parameters for continual prompting
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||||
|
|
||||||
|
/// Separate reasoning content from final answer (O1-style models)
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub separate_reasoning: bool,
|
||||||
|
|
||||||
|
/// Stream reasoning tokens during generation
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub stream_reasoning: bool,
|
||||||
|
|
||||||
|
/// Return model hidden states
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_hidden_states: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerationRequest for ChatCompletionRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model(&self) -> Option<&str> {
|
||||||
|
Some(&self.model)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
// Extract text from messages for routing decisions
|
||||||
|
self.messages
|
||||||
|
.iter()
|
||||||
|
.filter_map(|msg| match msg {
|
||||||
|
ChatMessage::System { content, .. } => Some(content.clone()),
|
||||||
|
ChatMessage::User { content, .. } => match content {
|
||||||
|
UserMessageContent::Text(text) => Some(text.clone()),
|
||||||
|
UserMessageContent::Parts(parts) => {
|
||||||
|
let texts: Vec<String> = parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|part| match part {
|
||||||
|
ContentPart::Text { text } => Some(text.clone()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Some(texts.join(" "))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ChatMessage::Assistant {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
// Combine content and reasoning content for routing decisions
|
||||||
|
let main_content = content.clone().unwrap_or_default();
|
||||||
|
let reasoning = reasoning_content.clone().unwrap_or_default();
|
||||||
|
if main_content.is_empty() && reasoning.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(format!("{} {}", main_content, reasoning).trim().to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
||||||
|
ChatMessage::Function { content, .. } => Some(content.clone()),
|
||||||
|
})
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join(" ")
|
||||||
|
}
|
||||||
|
}
|
||||||
59
sgl-router/src/protocols/openai/chat/response.rs
Normal file
59
sgl-router/src/protocols/openai/chat/response.rs
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
// Chat Completions API response types
|
||||||
|
|
||||||
|
use crate::protocols::openai::chat::types::{ChatMessage, ChatMessageDelta};
|
||||||
|
use crate::protocols::openai::common::{ChatLogProbs, Usage};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// ============= Regular Response =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "chat.completion"
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: ChatMessage,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
|
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
||||||
|
/// Information about which stop condition was matched
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
||||||
|
/// Hidden states from the model (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hidden_states: Option<Vec<f32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Streaming Response =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionStreamResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "chat.completion.chunk"
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
pub choices: Vec<ChatStreamChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatStreamChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ChatMessageDelta,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
185
sgl-router/src/protocols/openai/chat/types.rs
Normal file
185
sgl-router/src/protocols/openai/chat/types.rs
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
// Types specific to the Chat Completions API
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// ============= Message Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ChatMessage {
|
||||||
|
System {
|
||||||
|
role: String, // "system"
|
||||||
|
content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
User {
|
||||||
|
role: String, // "user"
|
||||||
|
content: UserMessageContent,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
Assistant {
|
||||||
|
role: String, // "assistant"
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
function_call: Option<FunctionCallResponse>,
|
||||||
|
/// Reasoning content for O1-style models (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
},
|
||||||
|
Tool {
|
||||||
|
role: String, // "tool"
|
||||||
|
content: String,
|
||||||
|
tool_call_id: String,
|
||||||
|
},
|
||||||
|
Function {
|
||||||
|
role: String, // "function"
|
||||||
|
content: String,
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum UserMessageContent {
|
||||||
|
Text(String),
|
||||||
|
Parts(Vec<ContentPart>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ContentPart {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text { text: String },
|
||||||
|
#[serde(rename = "image_url")]
|
||||||
|
ImageUrl { image_url: ImageUrl },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ImageUrl {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub detail: Option<String>, // "auto", "low", or "high"
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Response Format Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ResponseFormat {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text,
|
||||||
|
#[serde(rename = "json_object")]
|
||||||
|
JsonObject,
|
||||||
|
#[serde(rename = "json_schema")]
|
||||||
|
JsonSchema { json_schema: JsonSchemaFormat },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct JsonSchemaFormat {
|
||||||
|
pub name: String,
|
||||||
|
pub schema: Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub strict: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Tool/Function Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Tool {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String, // "function"
|
||||||
|
pub function: Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Function {
|
||||||
|
pub name: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Value, // JSON Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ToolChoice {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
Required,
|
||||||
|
Function {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
tool_type: String, // "function"
|
||||||
|
function: FunctionChoice,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FunctionChoice {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String, // "function"
|
||||||
|
pub function: FunctionCallResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum FunctionCall {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
Function { name: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FunctionCallResponse {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String, // JSON string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Streaming Delta Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatMessageDelta {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub role: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function_call: Option<FunctionCallDelta>,
|
||||||
|
/// Reasoning content delta for O1-style models (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ToolCallDelta {
|
||||||
|
pub index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function: Option<FunctionCallDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FunctionCallDelta {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub arguments: Option<String>,
|
||||||
|
}
|
||||||
58
sgl-router/src/protocols/openai/common.rs
Normal file
58
sgl-router/src/protocols/openai/common.rs
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
// Common types shared across OpenAI API implementations
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
// ============= Shared Request Components =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct StreamOptions {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub include_usage: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Usage Tracking =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionTokensDetails {
|
||||||
|
pub reasoning_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Logprobs Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct LogProbs {
|
||||||
|
pub tokens: Vec<String>,
|
||||||
|
pub token_logprobs: Vec<Option<f32>>,
|
||||||
|
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
||||||
|
pub text_offset: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatLogProbs {
|
||||||
|
pub content: Option<Vec<ChatLogProbsContent>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatLogProbsContent {
|
||||||
|
pub token: String,
|
||||||
|
pub logprob: f32,
|
||||||
|
pub bytes: Option<Vec<u8>>,
|
||||||
|
pub top_logprobs: Vec<TopLogProb>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct TopLogProb {
|
||||||
|
pub token: String,
|
||||||
|
pub logprob: f32,
|
||||||
|
pub bytes: Option<Vec<u8>>,
|
||||||
|
}
|
||||||
10
sgl-router/src/protocols/openai/completions/mod.rs
Normal file
10
sgl-router/src/protocols/openai/completions/mod.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
// Completions API module (v1/completions)
|
||||||
|
|
||||||
|
pub mod request;
|
||||||
|
pub mod response;
|
||||||
|
|
||||||
|
// Re-export main types for convenience
|
||||||
|
pub use request::CompletionRequest;
|
||||||
|
pub use response::{
|
||||||
|
CompletionChoice, CompletionResponse, CompletionStreamChoice, CompletionStreamResponse,
|
||||||
|
};
|
||||||
158
sgl-router/src/protocols/openai/completions/request.rs
Normal file
158
sgl-router/src/protocols/openai/completions/request.rs
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
||||||
|
|
||||||
|
use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray};
|
||||||
|
use crate::protocols::openai::common::StreamOptions;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// The prompt(s) to generate completions for
|
||||||
|
pub prompt: StringOrArray,
|
||||||
|
|
||||||
|
/// The suffix that comes after a completion of inserted text
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
|
||||||
|
/// The maximum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// What sampling temperature to use, between 0 and 2
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// An alternative to sampling with temperature (nucleus sampling)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// How many completions to generate for each prompt
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
|
||||||
|
/// Whether to stream back partial progress
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Options for streaming response
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stream_options: Option<StreamOptions>,
|
||||||
|
|
||||||
|
/// Include the log probabilities on the logprobs most likely tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<u32>,
|
||||||
|
|
||||||
|
/// Echo back the prompt in addition to the completion
|
||||||
|
#[serde(default)]
|
||||||
|
pub echo: bool,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StringOrArray>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Generates best_of completions server-side and returns the "best"
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub best_of: Option<u32>,
|
||||||
|
|
||||||
|
/// Modify the likelihood of specified tokens appearing in the completion
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<HashMap<String, f32>>,
|
||||||
|
|
||||||
|
/// A unique identifier representing your end-user
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
|
||||||
|
/// If specified, our system will make a best effort to sample deterministically
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<i64>,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Top-k sampling parameter (-1 to disable)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
|
||||||
|
/// Min-p nucleus sampling parameter
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Minimum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Repetition penalty for reducing repetitive text
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Regex constraint for output generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub regex: Option<String>,
|
||||||
|
|
||||||
|
/// EBNF grammar constraint for structured output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ebnf: Option<String>,
|
||||||
|
|
||||||
|
/// JSON schema constraint for structured output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<String>,
|
||||||
|
|
||||||
|
/// Specific token IDs to use as stop conditions
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
|
|
||||||
|
/// Skip trimming stop tokens from output
|
||||||
|
#[serde(default)]
|
||||||
|
pub no_stop_trim: bool,
|
||||||
|
|
||||||
|
/// Ignore end-of-sequence tokens during generation
|
||||||
|
#[serde(default)]
|
||||||
|
pub ignore_eos: bool,
|
||||||
|
|
||||||
|
/// Skip special tokens during detokenization
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub skip_special_tokens: bool,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Path to LoRA adapter(s) for model customization
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub lora_path: Option<LoRAPath>,
|
||||||
|
|
||||||
|
/// Session parameters for continual prompting
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||||
|
|
||||||
|
/// Return model hidden states
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_hidden_states: bool,
|
||||||
|
|
||||||
|
/// Additional fields including bootstrap info for PD routing
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub other: serde_json::Map<String, serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerationRequest for CompletionRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model(&self) -> Option<&str> {
|
||||||
|
Some(&self.model)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
match &self.prompt {
|
||||||
|
StringOrArray::String(s) => s.clone(),
|
||||||
|
StringOrArray::Array(v) => v.join(" "),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
56
sgl-router/src/protocols/openai/completions/response.rs
Normal file
56
sgl-router/src/protocols/openai/completions/response.rs
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
// Completions API response types
|
||||||
|
|
||||||
|
use crate::protocols::openai::common::{LogProbs, Usage};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// ============= Regular Response =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "text_completion"
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<CompletionChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionChoice {
|
||||||
|
pub text: String,
|
||||||
|
pub index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<LogProbs>,
|
||||||
|
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
||||||
|
/// Information about which stop condition was matched
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
||||||
|
/// Hidden states from the model (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hidden_states: Option<Vec<f32>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Streaming Response =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionStreamResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "text_completion"
|
||||||
|
pub created: u64,
|
||||||
|
pub choices: Vec<CompletionStreamChoice>,
|
||||||
|
pub model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionStreamChoice {
|
||||||
|
pub text: String,
|
||||||
|
pub index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<LogProbs>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
19
sgl-router/src/protocols/openai/errors.rs
Normal file
19
sgl-router/src/protocols/openai/errors.rs
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
// OpenAI API error response types
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub error: ErrorDetail,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ErrorDetail {
|
||||||
|
pub message: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub error_type: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub param: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub code: Option<String>,
|
||||||
|
}
|
||||||
7
sgl-router/src/protocols/openai/mod.rs
Normal file
7
sgl-router/src/protocols/openai/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
// OpenAI protocol module
|
||||||
|
// This module contains all OpenAI API-compatible types and future validation logic
|
||||||
|
|
||||||
|
pub mod chat;
|
||||||
|
pub mod common;
|
||||||
|
pub mod completions;
|
||||||
|
pub mod errors;
|
||||||
@@ -9,7 +9,10 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::protocols::{
|
||||||
|
generate::GenerateRequest,
|
||||||
|
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
|
||||||
|
};
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod header_utils;
|
pub mod header_utils;
|
||||||
|
|||||||
@@ -11,8 +11,15 @@ use crate::core::{
|
|||||||
RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
RetryExecutor, Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
|
use crate::protocols::{
|
||||||
|
common::StringOrArray,
|
||||||
|
generate::GenerateRequest,
|
||||||
|
openai::{
|
||||||
|
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
||||||
|
completions::CompletionRequest,
|
||||||
|
},
|
||||||
|
};
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
@@ -616,7 +623,7 @@ impl PDRouter {
|
|||||||
// Helper to determine batch size from a GenerateRequest
|
// Helper to determine batch size from a GenerateRequest
|
||||||
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
|
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
|
||||||
// Check prompt array
|
// Check prompt array
|
||||||
if let Some(crate::openai_api_types::StringOrArray::Array(arr)) = &req.prompt {
|
if let Some(StringOrArray::Array(arr)) = &req.prompt {
|
||||||
if !arr.is_empty() {
|
if !arr.is_empty() {
|
||||||
return Some(arr.len());
|
return Some(arr.len());
|
||||||
}
|
}
|
||||||
@@ -645,7 +652,7 @@ impl PDRouter {
|
|||||||
// Helper to determine batch size from a CompletionRequest
|
// Helper to determine batch size from a CompletionRequest
|
||||||
fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
|
fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
|
||||||
// Check prompt array
|
// Check prompt array
|
||||||
if let crate::openai_api_types::StringOrArray::Array(arr) = &req.prompt {
|
if let StringOrArray::Array(arr) = &req.prompt {
|
||||||
if !arr.is_empty() {
|
if !arr.is_empty() {
|
||||||
return Some(arr.len());
|
return Some(arr.len());
|
||||||
}
|
}
|
||||||
@@ -1724,10 +1731,8 @@ impl RouterTrait for PDRouter {
|
|||||||
.as_deref()
|
.as_deref()
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
body.prompt.as_ref().and_then(|p| match p {
|
body.prompt.as_ref().and_then(|p| match p {
|
||||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
|
StringOrArray::String(s) => Some(s.as_str()),
|
||||||
crate::openai_api_types::StringOrArray::Array(v) => {
|
StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
|
||||||
v.first().map(|s| s.as_str())
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -1763,13 +1768,11 @@ impl RouterTrait for PDRouter {
|
|||||||
// Extract text for cache-aware routing
|
// Extract text for cache-aware routing
|
||||||
let request_text = if self.policies_need_request_text() {
|
let request_text = if self.policies_need_request_text() {
|
||||||
body.messages.first().and_then(|msg| match msg {
|
body.messages.first().and_then(|msg| match msg {
|
||||||
crate::openai_api_types::ChatMessage::User { content, .. } => match content {
|
ChatMessage::User { content, .. } => match content {
|
||||||
crate::openai_api_types::UserMessageContent::Text(text) => Some(text.clone()),
|
UserMessageContent::Text(text) => Some(text.clone()),
|
||||||
crate::openai_api_types::UserMessageContent::Parts(_) => None,
|
UserMessageContent::Parts(_) => None,
|
||||||
},
|
},
|
||||||
crate::openai_api_types::ChatMessage::System { content, .. } => {
|
ChatMessage::System { content, .. } => Some(content.clone()),
|
||||||
Some(content.clone())
|
|
||||||
}
|
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
@@ -1804,10 +1807,8 @@ impl RouterTrait for PDRouter {
|
|||||||
// Extract text for cache-aware routing
|
// Extract text for cache-aware routing
|
||||||
let request_text = if self.policies_need_request_text() {
|
let request_text = if self.policies_need_request_text() {
|
||||||
match &body.prompt {
|
match &body.prompt {
|
||||||
crate::openai_api_types::StringOrArray::String(s) => Some(s.clone()),
|
StringOrArray::String(s) => Some(s.clone()),
|
||||||
crate::openai_api_types::StringOrArray::Array(v) => {
|
StringOrArray::Array(v) => v.first().map(|s| s.to_string()),
|
||||||
v.first().map(|s| s.to_string())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|||||||
@@ -8,8 +8,12 @@ use crate::core::{
|
|||||||
RetryExecutor, Worker, WorkerFactory, WorkerType,
|
RetryExecutor, Worker, WorkerFactory, WorkerType,
|
||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
|
use crate::protocols::{
|
||||||
|
common::GenerationRequest,
|
||||||
|
generate::GenerateRequest,
|
||||||
|
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
|
||||||
|
};
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
@@ -453,9 +457,7 @@ impl Router {
|
|||||||
Some(available[idx].clone_worker())
|
Some(available[idx].clone_worker())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn route_typed_request<
|
pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
|
||||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
|
||||||
>(
|
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
typed_req: &T,
|
typed_req: &T,
|
||||||
|
|||||||
@@ -1,7 +1,10 @@
|
|||||||
use crate::config::RouterConfig;
|
use crate::config::RouterConfig;
|
||||||
use crate::logging::{self, LoggingConfig};
|
use crate::logging::{self, LoggingConfig};
|
||||||
use crate::metrics::{self, PrometheusConfig};
|
use crate::metrics::{self, PrometheusConfig};
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::protocols::{
|
||||||
|
generate::GenerateRequest,
|
||||||
|
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
|
||||||
|
};
|
||||||
use crate::routers::{RouterFactory, RouterTrait};
|
use crate::routers::{RouterFactory, RouterTrait};
|
||||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|||||||
@@ -5,9 +5,13 @@
|
|||||||
|
|
||||||
use serde_json::{from_str, to_string, to_value};
|
use serde_json::{from_str, to_string, to_value};
|
||||||
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
||||||
use sglang_router_rs::openai_api_types::{
|
use sglang_router_rs::protocols::{
|
||||||
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
common::StringOrArray,
|
||||||
SamplingParams, StringOrArray, UserMessageContent,
|
generate::{GenerateParameters, GenerateRequest, SamplingParams},
|
||||||
|
openai::{
|
||||||
|
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
||||||
|
completions::CompletionRequest,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||||
|
|||||||
Reference in New Issue
Block a user