[Router] Refactor protocol definitions: split spec.rs into modular files (#11677)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
345
sgl-router/src/protocols/common.rs
Normal file
345
sgl-router/src/protocols/common.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ============================================================================
|
||||
// Default value helpers
|
||||
// ============================================================================
|
||||
|
||||
/// Fallback model name used when a request does not specify one.
///
/// Always returns a freshly allocated `"unknown"`.
pub(crate) fn default_model() -> String {
    String::from("unknown")
}
|
||||
|
||||
/// Serde default helper: yields `true`.
///
/// Intended for `#[serde(default = "...")]` on boolean fields that should be
/// enabled when the field is absent from the input.
pub fn default_true() -> bool {
    true
}
|
||||
|
||||
// ============================================================================
|
||||
// GenerationRequest Trait
|
||||
// ============================================================================
|
||||
|
||||
/// Common read-only view over the different generation request types.
///
/// Per the original note, implementors are ChatCompletionRequest,
/// CompletionRequest, GenerateRequest, EmbeddingRequest, RerankRequest, and
/// ResponsesRequest. Every implementor must be `Send + Sync`.
pub trait GenerationRequest
where
    Self: Send + Sync,
{
    /// Returns `true` when the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// Returns the model name, or `None` when the request does not carry one.
    fn get_model(&self) -> Option<&str>;

    /// Builds the text that routing decisions are based on.
    fn extract_text_for_routing(&self) -> String;
}
|
||||
|
||||
// ============================================================================
|
||||
// String/Array Utilities
|
||||
// ============================================================================
|
||||
|
||||
/// A type that can be either a single string or an array of strings
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum StringOrArray {
|
||||
String(String),
|
||||
Array(Vec<String>),
|
||||
}
|
||||
|
||||
impl StringOrArray {
|
||||
/// Get the number of items in the StringOrArray
|
||||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
StringOrArray::String(_) => 1,
|
||||
StringOrArray::Array(arr) => arr.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if the StringOrArray is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
StringOrArray::String(s) => s.is_empty(),
|
||||
StringOrArray::Array(arr) => arr.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to a vector of strings
|
||||
pub fn to_vec(&self) -> Vec<String> {
|
||||
match self {
|
||||
StringOrArray::String(s) => vec![s.clone()],
|
||||
StringOrArray::Array(arr) => arr.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Content Parts (for multimodal messages)
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentPart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>, // "auto", "low", or "high"
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Format (for structured outputs)
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseFormat {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
#[serde(rename = "json_object")]
|
||||
JsonObject,
|
||||
#[serde(rename = "json_schema")]
|
||||
JsonSchema { json_schema: JsonSchemaFormat },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct JsonSchemaFormat {
|
||||
pub name: String,
|
||||
pub schema: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub strict: Option<bool>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct StreamOptions {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_usage: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolCallDelta {
|
||||
pub index: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function: Option<FunctionCallDelta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FunctionCallDelta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tools and Function Calling
|
||||
// ============================================================================
|
||||
|
||||
/// Tool choice value for simple string options
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolChoiceValue {
|
||||
Auto,
|
||||
Required,
|
||||
None,
|
||||
}
|
||||
|
||||
/// Tool choice for both Chat Completion and Responses APIs
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolChoice {
|
||||
Value(ToolChoiceValue),
|
||||
Function {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String, // "function"
|
||||
function: FunctionChoice,
|
||||
},
|
||||
AllowedTools {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String, // "allowed_tools"
|
||||
mode: String, // "auto" | "required" TODO: need validation
|
||||
tools: Vec<ToolReference>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for ToolChoice {
|
||||
fn default() -> Self {
|
||||
Self::Value(ToolChoiceValue::Auto)
|
||||
}
|
||||
}
|
||||
|
||||
/// Function choice specification for ToolChoice::Function
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FunctionChoice {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Tool reference for ToolChoice::AllowedTools
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolReference {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Function {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub parameters: Value, // JSON Schema
|
||||
/// Whether to enable strict schema adherence (OpenAI structured outputs)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub strict: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub function: FunctionCallResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum FunctionCall {
|
||||
None,
|
||||
Auto,
|
||||
Function { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FunctionCallResponse {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub arguments: Option<String>, // JSON string
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Usage and Logging
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionTokensDetails {
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
/// Usage information (used by rerank and other endpoints)
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct UsageInfo {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct PromptTokenUsageInfo {
|
||||
pub cached_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct LogProbs {
|
||||
pub tokens: Vec<String>,
|
||||
pub token_logprobs: Vec<Option<f32>>,
|
||||
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
||||
pub text_offset: Vec<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatLogProbs {
|
||||
Detailed {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<Vec<ChatLogProbsContent>>,
|
||||
},
|
||||
Raw(Value),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatLogProbsContent {
|
||||
pub token: String,
|
||||
pub logprob: f32,
|
||||
pub bytes: Option<Vec<u8>>,
|
||||
pub top_logprobs: Vec<TopLogProb>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TopLogProb {
|
||||
pub token: String,
|
||||
pub logprob: f32,
|
||||
pub bytes: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: ErrorDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ErrorDetail {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub param: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Input Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InputIds {
|
||||
Single(Vec<i32>),
|
||||
Batch(Vec<Vec<i32>>),
|
||||
}
|
||||
|
||||
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum LoRAPath {
|
||||
Single(Option<String>),
|
||||
Batch(Vec<Option<String>>),
|
||||
}
|
||||
Reference in New Issue
Block a user