347 lines
10 KiB
Rust
347 lines
10 KiB
Rust
use std::collections::HashMap;
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
|
|
// ============================================================================
|
|
// Default value helpers
|
|
// ============================================================================
|
|
|
|
/// Default model name used when a request does not specify one.
///
/// Returns the placeholder string `"unknown"`.
pub(crate) fn default_model() -> String {
    String::from("unknown")
}
|
|
|
|
/// Serde default-value helper that always yields `true`.
///
/// A plain `#[serde(default)]` on a `bool` would fall back to
/// `bool::default()`, which is `false`. Fields that should default to
/// enabled can point at this function instead.
pub fn default_true() -> bool {
    true
}
|
|
|
|
// ============================================================================
|
|
// GenerationRequest Trait
|
|
// ============================================================================
|
|
|
|
/// Uniform, read-only view over the properties of a generation request.
///
/// Implemented by `ChatCompletionRequest`, `CompletionRequest`,
/// `GenerateRequest`, `EmbeddingRequest`, `RerankRequest`, and
/// `ResponsesRequest`. Implementors must be `Send + Sync`.
pub trait GenerationRequest: Send + Sync {
    /// Returns `true` when the request asks for a streamed response.
    fn is_stream(&self) -> bool;

    /// Returns the requested model name, or `None` if the request names none.
    fn get_model(&self) -> Option<&str>;

    /// Returns the request text that routing decisions should be based on.
    fn extract_text_for_routing(&self) -> String;
}
|
|
|
|
// ============================================================================
|
|
// String/Array Utilities
|
|
// ============================================================================
|
|
|
|
/// A type that can be either a single string or an array of strings
|
|
#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum StringOrArray {
|
|
String(String),
|
|
Array(Vec<String>),
|
|
}
|
|
|
|
impl StringOrArray {
|
|
/// Get the number of items in the StringOrArray
|
|
pub fn len(&self) -> usize {
|
|
match self {
|
|
StringOrArray::String(_) => 1,
|
|
StringOrArray::Array(arr) => arr.len(),
|
|
}
|
|
}
|
|
|
|
/// Check if the StringOrArray is empty
|
|
pub fn is_empty(&self) -> bool {
|
|
match self {
|
|
StringOrArray::String(s) => s.is_empty(),
|
|
StringOrArray::Array(arr) => arr.is_empty(),
|
|
}
|
|
}
|
|
|
|
/// Convert to a vector of strings
|
|
pub fn to_vec(&self) -> Vec<String> {
|
|
match self {
|
|
StringOrArray::String(s) => vec![s.clone()],
|
|
StringOrArray::Array(arr) => arr.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Content Parts (for multimodal messages)
|
|
// ============================================================================
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(tag = "type")]
|
|
pub enum ContentPart {
|
|
#[serde(rename = "text")]
|
|
Text { text: String },
|
|
#[serde(rename = "image_url")]
|
|
ImageUrl { image_url: ImageUrl },
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ImageUrl {
|
|
pub url: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub detail: Option<String>, // "auto", "low", or "high"
|
|
}
|
|
|
|
// ============================================================================
|
|
// Response Format (for structured outputs)
|
|
// ============================================================================
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(tag = "type")]
|
|
pub enum ResponseFormat {
|
|
#[serde(rename = "text")]
|
|
Text,
|
|
#[serde(rename = "json_object")]
|
|
JsonObject,
|
|
#[serde(rename = "json_schema")]
|
|
JsonSchema { json_schema: JsonSchemaFormat },
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct JsonSchemaFormat {
|
|
pub name: String,
|
|
pub schema: Value,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub strict: Option<bool>,
|
|
}
|
|
|
|
// ============================================================================
|
|
// Streaming
|
|
// ============================================================================
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct StreamOptions {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub include_usage: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ToolCallDelta {
|
|
pub index: u32,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub id: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
#[serde(rename = "type")]
|
|
pub tool_type: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub function: Option<FunctionCallDelta>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct FunctionCallDelta {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub name: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub arguments: Option<String>,
|
|
}
|
|
|
|
// ============================================================================
|
|
// Tools and Function Calling
|
|
// ============================================================================
|
|
|
|
/// Tool choice value for simple string options
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(rename_all = "snake_case")]
|
|
pub enum ToolChoiceValue {
|
|
Auto,
|
|
Required,
|
|
None,
|
|
}
|
|
|
|
/// Tool choice for both Chat Completion and Responses APIs
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum ToolChoice {
|
|
Value(ToolChoiceValue),
|
|
Function {
|
|
#[serde(rename = "type")]
|
|
tool_type: String, // "function"
|
|
function: FunctionChoice,
|
|
},
|
|
AllowedTools {
|
|
#[serde(rename = "type")]
|
|
tool_type: String, // "allowed_tools"
|
|
mode: String, // "auto" | "required" TODO: need validation
|
|
tools: Vec<ToolReference>,
|
|
},
|
|
}
|
|
|
|
impl Default for ToolChoice {
|
|
fn default() -> Self {
|
|
Self::Value(ToolChoiceValue::Auto)
|
|
}
|
|
}
|
|
|
|
/// Function choice specification for ToolChoice::Function
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct FunctionChoice {
|
|
pub name: String,
|
|
}
|
|
|
|
/// Tool reference for ToolChoice::AllowedTools
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ToolReference {
|
|
#[serde(rename = "type")]
|
|
pub tool_type: String, // "function"
|
|
pub name: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct Tool {
|
|
#[serde(rename = "type")]
|
|
pub tool_type: String, // "function"
|
|
pub function: Function,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct Function {
|
|
pub name: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub description: Option<String>,
|
|
pub parameters: Value, // JSON Schema
|
|
/// Whether to enable strict schema adherence (OpenAI structured outputs)
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub strict: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ToolCall {
|
|
pub id: String,
|
|
#[serde(rename = "type")]
|
|
pub tool_type: String, // "function"
|
|
pub function: FunctionCallResponse,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum FunctionCall {
|
|
None,
|
|
Auto,
|
|
Function { name: String },
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct FunctionCallResponse {
|
|
pub name: String,
|
|
#[serde(default)]
|
|
pub arguments: Option<String>, // JSON string
|
|
}
|
|
|
|
// ============================================================================
|
|
// Usage and Logging
|
|
// ============================================================================
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct Usage {
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct CompletionTokensDetails {
|
|
pub reasoning_tokens: Option<u32>,
|
|
}
|
|
|
|
/// Usage information (used by rerank and other endpoints)
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct UsageInfo {
|
|
pub prompt_tokens: u32,
|
|
pub completion_tokens: u32,
|
|
pub total_tokens: u32,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub reasoning_tokens: Option<u32>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct PromptTokenUsageInfo {
|
|
pub cached_tokens: u32,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct LogProbs {
|
|
pub tokens: Vec<String>,
|
|
pub token_logprobs: Vec<Option<f32>>,
|
|
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
|
pub text_offset: Vec<u32>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum ChatLogProbs {
|
|
Detailed {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
content: Option<Vec<ChatLogProbsContent>>,
|
|
},
|
|
Raw(Value),
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ChatLogProbsContent {
|
|
pub token: String,
|
|
pub logprob: f32,
|
|
pub bytes: Option<Vec<u8>>,
|
|
pub top_logprobs: Vec<TopLogProb>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct TopLogProb {
|
|
pub token: String,
|
|
pub logprob: f32,
|
|
pub bytes: Option<Vec<u8>>,
|
|
}
|
|
|
|
// ============================================================================
|
|
// Error Types
|
|
// ============================================================================
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ErrorResponse {
|
|
pub error: ErrorDetail,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
pub struct ErrorDetail {
|
|
pub message: String,
|
|
#[serde(rename = "type")]
|
|
pub error_type: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub param: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
pub code: Option<String>,
|
|
}
|
|
|
|
// ============================================================================
|
|
// Input Types
|
|
// ============================================================================
|
|
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum InputIds {
|
|
Single(Vec<i32>),
|
|
Batch(Vec<Vec<i32>>),
|
|
}
|
|
|
|
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
|
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
#[serde(untagged)]
|
|
pub enum LoRAPath {
|
|
Single(Option<String>),
|
|
Batch(Vec<Option<String>>),
|
|
}
|