Files
sglang/sgl-router/src/protocols/common.rs

347 lines
10 KiB
Rust

use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ============================================================================
// Default value helpers
// ============================================================================
/// Fallback model name used when a request does not specify one.
///
/// Always returns the literal string `"unknown"` as a freshly allocated `String`.
pub(crate) fn default_model() -> String {
    String::from("unknown")
}
/// Serde default-value helper that always yields `true`.
///
/// Meant to be referenced by path from a `#[serde(default = "...")]` attribute
/// on boolean fields whose absence should be treated as `true`.
pub fn default_true() -> bool {
    true
}
// ============================================================================
// GenerationRequest Trait
// ============================================================================
/// Uniform, read-only view over the different generation request types.
///
/// Per the original note, this is implemented by `ChatCompletionRequest`,
/// `CompletionRequest`, `GenerateRequest`, `EmbeddingRequest`, `RerankRequest`
/// and `ResponsesRequest` (none of which are defined in this file).
///
/// The `Send + Sync` supertraits mean every implementor can be moved to and
/// shared between threads.
pub trait GenerationRequest: Send + Sync {
    /// Returns `true` when the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// Returns the model name carried by the request, or `None` if it has none.
    fn get_model(&self) -> Option<&str>;

    /// Builds the text that routing decisions should be based on.
    fn extract_text_for_routing(&self) -> String;
}
// ============================================================================
// String/Array Utilities
// ============================================================================
/// A value that is either one string or a list of strings.
///
/// The enum is `#[serde(untagged)]`, so no discriminator is written: the
/// `String` variant is encoded as a bare string and the `Array` variant as an
/// array of strings.
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    /// A single string value.
    String(String),
    /// Zero or more string values.
    Array(Vec<String>),
}
impl StringOrArray {
/// Get the number of items in the StringOrArray
pub fn len(&self) -> usize {
match self {
StringOrArray::String(_) => 1,
StringOrArray::Array(arr) => arr.len(),
}
}
/// Check if the StringOrArray is empty
pub fn is_empty(&self) -> bool {
match self {
StringOrArray::String(s) => s.is_empty(),
StringOrArray::Array(arr) => arr.is_empty(),
}
}
/// Convert to a vector of strings
pub fn to_vec(&self) -> Vec<String> {
match self {
StringOrArray::String(s) => vec![s.clone()],
StringOrArray::Array(arr) => arr.clone(),
}
}
}
// ============================================================================
// Content Parts (for multimodal messages)
// ============================================================================
/// One part of a multimodal message body.
///
/// Internally tagged on the `type` field; the tag values are the snake_case
/// variant names (`text`, `image_url`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentPart {
    /// Plain text content (`type = "text"`).
    Text { text: String },
    /// An image reference (`type = "image_url"`).
    ImageUrl { image_url: ImageUrl },
}
/// Image reference carried by [`ContentPart::ImageUrl`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
    /// Location of the image.
    pub url: String,
    /// Optional detail hint; per the original note one of `"auto"`, `"low"`
    /// or `"high"` (not validated here). Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>,
}
// ============================================================================
// Response Format (for structured outputs)
// ============================================================================
/// Requested output format for structured outputs.
///
/// Internally tagged on the `type` field; the tag values are the snake_case
/// variant names (`text`, `json_object`, `json_schema`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseFormat {
    /// `type = "text"`.
    Text,
    /// `type = "json_object"`.
    JsonObject,
    /// `type = "json_schema"`, carrying the schema description.
    JsonSchema { json_schema: JsonSchemaFormat },
}
/// Payload of [`ResponseFormat::JsonSchema`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
    /// Name given to the schema.
    pub name: String,
    /// The schema itself, kept as arbitrary JSON.
    pub schema: Value,
    /// Optional strictness flag; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
// ============================================================================
// Streaming
// ============================================================================
/// Options that apply to streamed responses.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct StreamOptions {
    /// Optional flag controlling usage reporting; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_usage: Option<bool>,
}
/// Incremental tool-call fragment used while streaming.
///
/// Every optional field is left out of the serialized form when it is `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
    /// Position of the tool call this fragment belongs to.
    pub index: u32,
    /// Tool call identifier, when present in this fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Tool kind; carried on the wire under the key `type`.
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub tool_type: Option<String>,
    /// Function name/arguments fragment, when present.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
/// Function portion of a [`ToolCallDelta`].
///
/// Both fields are optional and omitted from output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
    /// Function name, when present in this fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Arguments text, when present in this fragment.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
// ============================================================================
// Tools and Function Calling
// ============================================================================
/// The plain-string forms of a tool choice.
///
/// Variants are written in snake_case on the wire: `"auto"`, `"required"`
/// and `"none"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoiceValue {
    /// Encoded as `"auto"`.
    Auto,
    /// Encoded as `"required"`.
    Required,
    /// Encoded as `"none"`.
    None,
}
/// Tool choice shared by the Chat Completion and Responses APIs.
///
/// The enum is `#[serde(untagged)]`: when deserializing, variants are tried in
/// declaration order and the first one that fits the input wins.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    /// One of the simple string options.
    Value(ToolChoiceValue),
    /// Selection of one specific function.
    Function {
        /// Carried on the wire as `type`; expected to be `"function"`
        /// (not validated here).
        #[serde(rename = "type")]
        tool_type: String,
        /// The function being selected.
        function: FunctionChoice,
    },
    /// Restriction to a set of tools.
    AllowedTools {
        /// Carried on the wire as `type`; expected to be `"allowed_tools"`
        /// (not validated here).
        #[serde(rename = "type")]
        tool_type: String,
        /// Expected to be `"auto"` or `"required"`.
        // TODO: need validation
        mode: String,
        /// The tools that may be used.
        tools: Vec<ToolReference>,
    },
}
impl Default for ToolChoice {
fn default() -> Self {
Self::Value(ToolChoiceValue::Auto)
}
}
/// Names the function selected by [`ToolChoice::Function`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
    /// Name of the selected function.
    pub name: String,
}
/// Entry in the tool list of [`ToolChoice::AllowedTools`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolReference {
    /// Carried on the wire as `type`; expected to be `"function"`
    /// (not validated here).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Name of the referenced tool.
    pub name: String,
}
/// A tool definition, pairing a tool kind with its function description.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
    /// Carried on the wire as `type`; expected to be `"function"`
    /// (not validated here).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// Description of the callable function.
    pub function: Function,
}
/// Description of a function exposed through a [`Tool`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
    /// Function name.
    pub name: String,
    /// Optional description; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Parameter description as a JSON Schema, kept as arbitrary JSON.
    /// This field is required when deserializing.
    pub parameters: Value,
    /// Whether to enable strict schema adherence (OpenAI structured outputs).
    /// Omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
/// A complete (non-streamed) tool call.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
    /// Identifier of this tool call.
    pub id: String,
    /// Carried on the wire as `type`; expected to be `"function"`
    /// (not validated here).
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The function name and arguments of the call.
    pub function: FunctionCallResponse,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
None,
Auto,
Function { name: String },
}
/// Function name and arguments carried by a [`ToolCall`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
    /// Name of the function.
    pub name: String,
    /// Arguments as a JSON-encoded string; defaults to `None` when the field
    /// is missing from the input.
    #[serde(default)]
    pub arguments: Option<String>,
}
// ============================================================================
// Usage and Logging
// ============================================================================
/// Token usage counters.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    /// Count reported for the prompt.
    pub prompt_tokens: u32,
    /// Count reported for the completion.
    pub completion_tokens: u32,
    /// Reported total (stored as given, not computed here).
    pub total_tokens: u32,
    /// Optional completion breakdown; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
/// Breakdown of completion tokens attached to [`Usage`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
    /// Reasoning token count, if reported. There is no skip attribute on this
    /// field, so `None` is written out as `null`.
    pub reasoning_tokens: Option<u32>,
}
/// Token usage information (used by rerank and other endpoints).
///
/// Optional fields are omitted from output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    /// Count reported for the prompt.
    pub prompt_tokens: u32,
    /// Count reported for the completion.
    pub completion_tokens: u32,
    /// Reported total (stored as given, not computed here).
    pub total_tokens: u32,
    /// Reasoning token count, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Prompt token breakdown, if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
/// Prompt token breakdown attached to [`UsageInfo`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    /// Cached token count.
    pub cached_tokens: u32,
}
/// Log-probability data stored as four separate lists.
///
/// NOTE(review): the lists are presumably index-aligned per token; nothing in
/// this type enforces equal lengths, so confirm against the producer.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
    /// Token strings.
    pub tokens: Vec<String>,
    /// Log probability for each entry; an entry may be `None`.
    pub token_logprobs: Vec<Option<f32>>,
    /// Optional map from candidate token to log probability for each entry.
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    /// Text offset for each entry.
    pub text_offset: Vec<u32>,
}
/// Chat log-probability payload, either in a typed shape or as raw JSON.
///
/// The enum is `#[serde(untagged)]`: when deserializing, `Detailed` is tried
/// first and `Raw` is the fallback for input that does not fit it.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatLogProbs {
    /// Typed form with an optional list of per-token entries.
    Detailed {
        /// Per-token entries; omitted from output when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<Vec<ChatLogProbsContent>>,
    },
    /// Arbitrary JSON kept as is.
    Raw(Value),
}
/// One per-token entry of [`ChatLogProbs::Detailed`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
    /// Token text.
    pub token: String,
    /// Log probability of the token.
    pub logprob: f32,
    /// Optional byte values for the token. There is no skip attribute on this
    /// field, so `None` is written out as `null`.
    pub bytes: Option<Vec<u8>>,
    /// Alternative candidates for this entry.
    pub top_logprobs: Vec<TopLogProb>,
}
/// A candidate token listed in [`ChatLogProbsContent::top_logprobs`].
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
    /// Token text.
    pub token: String,
    /// Log probability of the token.
    pub logprob: f32,
    /// Optional byte values for the token. There is no skip attribute on this
    /// field, so `None` is written out as `null`.
    pub bytes: Option<Vec<u8>>,
}
// ============================================================================
// Error Types
// ============================================================================
/// Error envelope: wraps the details under a top-level `error` key.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
    /// The error details.
    pub error: ErrorDetail,
}
/// Details of an error carried by [`ErrorResponse`].
///
/// Optional fields are omitted from output when `None`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
    /// Human-readable error message.
    pub message: String,
    /// Error category; carried on the wire under the key `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Name of the offending parameter, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    /// Error code, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
// ============================================================================
// Input Types
// ============================================================================
/// Integer id input, given either as one list or as a list of lists.
///
/// The enum is `#[serde(untagged)]`: `Single` is tried first when
/// deserializing, then `Batch`.
///
/// NOTE(review): the ids are presumably token ids; confirm against callers.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
    /// A flat list of ids.
    Single(Vec<i32>),
    /// A list of id lists.
    Batch(Vec<Vec<i32>>),
}
/// LoRA adapter path: one path or a batch of paths (SGLang extension).
///
/// The enum is `#[serde(untagged)]`: `Single` is tried first when
/// deserializing, then `Batch`. Paths are optional in both forms.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum LoRAPath {
    /// One optional path.
    Single(Option<String>),
    /// A list of optional paths.
    Batch(Vec<Option<String>>),
}