[router] Move all protocols to spec.rs file (#9519)
This commit is contained in:
@@ -3,13 +3,9 @@ use serde_json::{from_str, to_string, to_value, to_vec};
|
|||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
||||||
use sglang_router_rs::protocols::{
|
use sglang_router_rs::protocols::spec::{
|
||||||
common::StringOrArray,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
generate::{GenerateParameters, GenerateRequest, SamplingParams},
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
openai::{
|
|
||||||
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
|
||||||
completions::CompletionRequest,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
|
use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap};
|
||||||
|
|
||||||
|
|||||||
@@ -1,61 +0,0 @@
|
|||||||
// Common types shared across all protocol implementations
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
/// Helper function for serde default value
|
|
||||||
pub fn default_true() -> bool {
|
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Common trait for all generation requests across different APIs
|
|
||||||
pub trait GenerationRequest: Send + Sync {
|
|
||||||
/// Check if the request is for streaming
|
|
||||||
fn is_stream(&self) -> bool;
|
|
||||||
|
|
||||||
/// Get the model name if specified
|
|
||||||
fn get_model(&self) -> Option<&str>;
|
|
||||||
|
|
||||||
/// Extract text content for routing decisions
|
|
||||||
fn extract_text_for_routing(&self) -> String;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper type for string or array of strings
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum StringOrArray {
|
|
||||||
String(String),
|
|
||||||
Array(Vec<String>),
|
|
||||||
}
|
|
||||||
impl StringOrArray {
|
|
||||||
/// Get the number of items in the StringOrArray
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
match self {
|
|
||||||
StringOrArray::String(_) => 1,
|
|
||||||
StringOrArray::Array(arr) => arr.len(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the StringOrArray is empty
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
match self {
|
|
||||||
StringOrArray::String(s) => s.is_empty(),
|
|
||||||
StringOrArray::Array(arr) => arr.is_empty(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert to a vector of strings
|
|
||||||
pub fn to_vec(&self) -> Vec<String> {
|
|
||||||
match self {
|
|
||||||
StringOrArray::String(s) => vec![s.clone()],
|
|
||||||
StringOrArray::Array(arr) => arr.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// LoRA adapter path - can be single path or batch of paths (SGLang extension)
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum LoRAPath {
|
|
||||||
Single(Option<String>),
|
|
||||||
Batch(Vec<Option<String>>),
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
// SGLang native Generate API module (/generate)
|
|
||||||
|
|
||||||
pub mod request;
|
|
||||||
pub mod types;
|
|
||||||
|
|
||||||
// Re-export main types for convenience
|
|
||||||
pub use request::GenerateRequest;
|
|
||||||
pub use types::{GenerateParameters, InputIds, SamplingParams};
|
|
||||||
@@ -1,97 +0,0 @@
|
|||||||
// Generate API request types (/generate)
|
|
||||||
|
|
||||||
use crate::protocols::common::{GenerationRequest, LoRAPath, StringOrArray};
|
|
||||||
use crate::protocols::generate::types::{GenerateParameters, InputIds, SamplingParams};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
||||||
pub struct GenerateRequest {
|
|
||||||
/// The prompt to generate from (OpenAI style)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub prompt: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// Text input - SGLang native format
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub text: Option<String>,
|
|
||||||
|
|
||||||
/// Input IDs for tokenized input
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub input_ids: Option<InputIds>,
|
|
||||||
|
|
||||||
/// Generation parameters
|
|
||||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
|
||||||
pub parameters: Option<GenerateParameters>,
|
|
||||||
|
|
||||||
/// Sampling parameters (sglang style)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub sampling_params: Option<SamplingParams>,
|
|
||||||
|
|
||||||
/// Whether to stream the response
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Whether to return logprobs
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_logprob: bool,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub lora_path: Option<LoRAPath>,
|
|
||||||
|
|
||||||
/// Session parameters for continual prompting
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Return model hidden states
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_hidden_states: bool,
|
|
||||||
|
|
||||||
/// Request ID for tracking
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub rid: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for GenerateRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
// Generate requests typically don't have a model field
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
// Check fields in priority order: text, prompt, inputs
|
|
||||||
if let Some(ref text) = self.text {
|
|
||||||
return text.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref prompt) = self.prompt {
|
|
||||||
return match prompt {
|
|
||||||
StringOrArray::String(s) => s.clone(),
|
|
||||||
StringOrArray::Array(v) => v.join(" "),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref input_ids) = self.input_ids {
|
|
||||||
return match input_ids {
|
|
||||||
InputIds::Single(ids) => ids
|
|
||||||
.iter()
|
|
||||||
.map(|&id| id.to_string())
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" "),
|
|
||||||
InputIds::Batch(batches) => batches
|
|
||||||
.iter()
|
|
||||||
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" "),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// No text input found
|
|
||||||
String::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
// Types for the SGLang native /generate API
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum InputIds {
|
|
||||||
Single(Vec<i32>),
|
|
||||||
Batch(Vec<Vec<i32>>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
|
||||||
pub struct GenerateParameters {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub best_of: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub decoder_input_details: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub details: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub do_sample: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_new_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub return_full_text: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub seed: Option<u64>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<Vec<String>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub truncate: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub typical_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub watermark: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
|
||||||
pub struct SamplingParams {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_new_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<i32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<crate::protocols::common::StringOrArray>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ignore_eos: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub skip_special_tokens: Option<bool>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub json_schema: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub regex: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ebnf: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_p: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_tokens: Option<u32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub no_stop_trim: Option<bool>,
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
// Protocol definitions and validation for various LLM APIs
|
// Protocol definitions and validation for various LLM APIs
|
||||||
// This module provides a structured approach to handling different API protocols
|
// This module provides a structured approach to handling different API protocols
|
||||||
|
|
||||||
pub mod common;
|
pub mod spec;
|
||||||
pub mod generate;
|
|
||||||
pub mod openai;
|
|
||||||
pub mod validation;
|
pub mod validation;
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
// Chat Completions API module
|
|
||||||
|
|
||||||
pub mod request;
|
|
||||||
pub mod response;
|
|
||||||
pub mod types;
|
|
||||||
pub mod validation;
|
|
||||||
|
|
||||||
// Re-export main types for convenience
|
|
||||||
pub use request::ChatCompletionRequest;
|
|
||||||
pub use response::{
|
|
||||||
ChatChoice, ChatCompletionResponse, ChatCompletionStreamResponse, ChatStreamChoice,
|
|
||||||
};
|
|
||||||
pub use types::*;
|
|
||||||
@@ -1,216 +0,0 @@
|
|||||||
// Chat Completions API request types
|
|
||||||
|
|
||||||
use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray};
|
|
||||||
use crate::protocols::openai::chat::types::*;
|
|
||||||
use crate::protocols::openai::common::StreamOptions;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionRequest {
|
|
||||||
/// ID of the model to use
|
|
||||||
pub model: String,
|
|
||||||
|
|
||||||
/// A list of messages comprising the conversation so far
|
|
||||||
pub messages: Vec<ChatMessage>,
|
|
||||||
|
|
||||||
/// What sampling temperature to use, between 0 and 2
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// An alternative to sampling with temperature
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// How many chat completion choices to generate for each input message
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub n: Option<u32>,
|
|
||||||
|
|
||||||
/// If set, partial message deltas will be sent
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Options for streaming response
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
|
|
||||||
/// Up to 4 sequences where the API will stop generating further tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// An upper bound for the number of tokens that can be generated for a completion
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_completion_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logit_bias: Option<HashMap<String, f32>>,
|
|
||||||
|
|
||||||
/// A unique identifier representing your end-user
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub user: Option<String>,
|
|
||||||
|
|
||||||
/// If specified, our system will make a best effort to sample deterministically
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub seed: Option<i64>,
|
|
||||||
|
|
||||||
/// Whether to return log probabilities of the output tokens
|
|
||||||
#[serde(default)]
|
|
||||||
pub logprobs: bool,
|
|
||||||
|
|
||||||
/// An integer between 0 and 20 specifying the number of most likely tokens to return
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_logprobs: Option<u32>,
|
|
||||||
|
|
||||||
/// An object specifying the format that the model must output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub response_format: Option<ResponseFormat>,
|
|
||||||
|
|
||||||
/// A list of tools the model may call
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tools: Option<Vec<Tool>>,
|
|
||||||
|
|
||||||
/// Controls which (if any) tool is called by the model
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_choice: Option<ToolChoice>,
|
|
||||||
|
|
||||||
/// Whether to enable parallel function calling during tool use
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub parallel_tool_calls: Option<bool>,
|
|
||||||
|
|
||||||
/// Deprecated: use tools instead
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub functions: Option<Vec<Function>>,
|
|
||||||
|
|
||||||
/// Deprecated: use tool_choice instead
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function_call: Option<FunctionCall>,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Top-k sampling parameter (-1 to disable)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<i32>,
|
|
||||||
|
|
||||||
/// Min-p nucleus sampling parameter
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_p: Option<f32>,
|
|
||||||
|
|
||||||
/// Minimum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Repetition penalty for reducing repetitive text
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Regex constraint for output generation
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub regex: Option<String>,
|
|
||||||
|
|
||||||
/// EBNF grammar constraint for structured output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ebnf: Option<String>,
|
|
||||||
|
|
||||||
/// Specific token IDs to use as stop conditions
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
|
||||||
|
|
||||||
/// Skip trimming stop tokens from output
|
|
||||||
#[serde(default)]
|
|
||||||
pub no_stop_trim: bool,
|
|
||||||
|
|
||||||
/// Ignore end-of-sequence tokens during generation
|
|
||||||
#[serde(default)]
|
|
||||||
pub ignore_eos: bool,
|
|
||||||
|
|
||||||
/// Continue generating from final assistant message
|
|
||||||
#[serde(default)]
|
|
||||||
pub continue_final_message: bool,
|
|
||||||
|
|
||||||
/// Skip special tokens during detokenization
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub skip_special_tokens: bool,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub lora_path: Option<LoRAPath>,
|
|
||||||
|
|
||||||
/// Session parameters for continual prompting
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Separate reasoning content from final answer (O1-style models)
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub separate_reasoning: bool,
|
|
||||||
|
|
||||||
/// Stream reasoning tokens during generation
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub stream_reasoning: bool,
|
|
||||||
|
|
||||||
/// Return model hidden states
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_hidden_states: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for ChatCompletionRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
Some(&self.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
// Extract text from messages for routing decisions
|
|
||||||
self.messages
|
|
||||||
.iter()
|
|
||||||
.filter_map(|msg| match msg {
|
|
||||||
ChatMessage::System { content, .. } => Some(content.clone()),
|
|
||||||
ChatMessage::User { content, .. } => match content {
|
|
||||||
UserMessageContent::Text(text) => Some(text.clone()),
|
|
||||||
UserMessageContent::Parts(parts) => {
|
|
||||||
let texts: Vec<String> = parts
|
|
||||||
.iter()
|
|
||||||
.filter_map(|part| match part {
|
|
||||||
ContentPart::Text { text } => Some(text.clone()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
Some(texts.join(" "))
|
|
||||||
}
|
|
||||||
},
|
|
||||||
ChatMessage::Assistant {
|
|
||||||
content,
|
|
||||||
reasoning_content,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
// Combine content and reasoning content for routing decisions
|
|
||||||
let main_content = content.clone().unwrap_or_default();
|
|
||||||
let reasoning = reasoning_content.clone().unwrap_or_default();
|
|
||||||
if main_content.is_empty() && reasoning.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(format!("{} {}", main_content, reasoning).trim().to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
|
||||||
ChatMessage::Function { content, .. } => Some(content.clone()),
|
|
||||||
})
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" ")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,59 +0,0 @@
|
|||||||
// Chat Completions API response types
|
|
||||||
|
|
||||||
use crate::protocols::openai::chat::types::{ChatMessage, ChatMessageDelta};
|
|
||||||
use crate::protocols::openai::common::{ChatLogProbs, Usage};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
// ============= Regular Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "chat.completion"
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<ChatChoice>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub message: ChatMessage,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
|
||||||
/// Information about which stop condition was matched
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
|
||||||
/// Hidden states from the model (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub hidden_states: Option<Vec<f32>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Streaming Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatCompletionStreamResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "chat.completion.chunk"
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
pub choices: Vec<ChatStreamChoice>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatStreamChoice {
|
|
||||||
pub index: u32,
|
|
||||||
pub delta: ChatMessageDelta,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
// Types specific to the Chat Completions API
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use serde_json::Value;
|
|
||||||
|
|
||||||
// ============= Message Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ChatMessage {
|
|
||||||
System {
|
|
||||||
role: String, // "system"
|
|
||||||
content: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
name: Option<String>,
|
|
||||||
},
|
|
||||||
User {
|
|
||||||
role: String, // "user"
|
|
||||||
content: UserMessageContent,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
name: Option<String>,
|
|
||||||
},
|
|
||||||
Assistant {
|
|
||||||
role: String, // "assistant"
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
content: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
name: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tool_calls: Option<Vec<ToolCall>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
function_call: Option<FunctionCallResponse>,
|
|
||||||
/// Reasoning content for O1-style models (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
reasoning_content: Option<String>,
|
|
||||||
},
|
|
||||||
Tool {
|
|
||||||
role: String, // "tool"
|
|
||||||
content: String,
|
|
||||||
tool_call_id: String,
|
|
||||||
},
|
|
||||||
Function {
|
|
||||||
role: String, // "function"
|
|
||||||
content: String,
|
|
||||||
name: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum UserMessageContent {
|
|
||||||
Text(String),
|
|
||||||
Parts(Vec<ContentPart>),
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum ContentPart {
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text { text: String },
|
|
||||||
#[serde(rename = "image_url")]
|
|
||||||
ImageUrl { image_url: ImageUrl },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ImageUrl {
|
|
||||||
pub url: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub detail: Option<String>, // "auto", "low", or "high"
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Response Format Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
pub enum ResponseFormat {
|
|
||||||
#[serde(rename = "text")]
|
|
||||||
Text,
|
|
||||||
#[serde(rename = "json_object")]
|
|
||||||
JsonObject,
|
|
||||||
#[serde(rename = "json_schema")]
|
|
||||||
JsonSchema { json_schema: JsonSchemaFormat },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct JsonSchemaFormat {
|
|
||||||
pub name: String,
|
|
||||||
pub schema: Value,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub strict: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Tool/Function Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Tool {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: String, // "function"
|
|
||||||
pub function: Function,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Function {
|
|
||||||
pub name: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub parameters: Value, // JSON Schema
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum ToolChoice {
|
|
||||||
None,
|
|
||||||
Auto,
|
|
||||||
Required,
|
|
||||||
Function {
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
tool_type: String, // "function"
|
|
||||||
function: FunctionChoice,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FunctionChoice {
|
|
||||||
pub name: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ToolCall {
|
|
||||||
pub id: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: String, // "function"
|
|
||||||
pub function: FunctionCallResponse,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(untagged)]
|
|
||||||
pub enum FunctionCall {
|
|
||||||
None,
|
|
||||||
Auto,
|
|
||||||
Function { name: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FunctionCallResponse {
|
|
||||||
pub name: String,
|
|
||||||
pub arguments: String, // JSON string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Streaming Delta Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatMessageDelta {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub role: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub content: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function_call: Option<FunctionCallDelta>,
|
|
||||||
/// Reasoning content delta for O1-style models (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub reasoning_content: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ToolCallDelta {
|
|
||||||
pub index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub id: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub tool_type: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub function: Option<FunctionCallDelta>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct FunctionCallDelta {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub name: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub arguments: Option<String>,
|
|
||||||
}
|
|
||||||
@@ -1,477 +0,0 @@
|
|||||||
// Validation implementation for Chat Completions API
|
|
||||||
|
|
||||||
use crate::protocols::common::StringOrArray;
|
|
||||||
use crate::protocols::openai::chat::request::ChatCompletionRequest;
|
|
||||||
use crate::protocols::openai::chat::types::{ChatMessage, ResponseFormat, UserMessageContent};
|
|
||||||
use crate::protocols::validation::{
|
|
||||||
utils::{
|
|
||||||
validate_common_request_params, validate_conflicting_parameters,
|
|
||||||
validate_mutually_exclusive_options, validate_non_empty_array,
|
|
||||||
},
|
|
||||||
CompletionCountProvider, LogProbsProvider, SGLangExtensionsProvider, SamplingOptionsProvider,
|
|
||||||
StopConditionsProvider, TokenLimitsProvider, ValidatableRequest, ValidationError,
|
|
||||||
};
|
|
||||||
|
|
||||||
impl SamplingOptionsProvider for ChatCompletionRequest {
|
|
||||||
fn get_temperature(&self) -> Option<f32> {
|
|
||||||
self.temperature
|
|
||||||
}
|
|
||||||
fn get_top_p(&self) -> Option<f32> {
|
|
||||||
self.top_p
|
|
||||||
}
|
|
||||||
fn get_frequency_penalty(&self) -> Option<f32> {
|
|
||||||
self.frequency_penalty
|
|
||||||
}
|
|
||||||
fn get_presence_penalty(&self) -> Option<f32> {
|
|
||||||
self.presence_penalty
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StopConditionsProvider for ChatCompletionRequest {
|
|
||||||
fn get_stop_sequences(&self) -> Option<&StringOrArray> {
|
|
||||||
self.stop.as_ref()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Token-limit accessors for chat requests.
impl TokenLimitsProvider for ChatCompletionRequest {
    /// Returns the effective upper token limit: `max_completion_tokens` wins
    /// when present, otherwise `max_tokens` is used.
    fn get_max_tokens(&self) -> Option<u32> {
        match self.max_completion_tokens {
            Some(limit) => Some(limit),
            None => self.max_tokens,
        }
    }

    /// Returns the request's `min_tokens` field.
    fn get_min_tokens(&self) -> Option<u32> {
        self.min_tokens
    }
}
|
|
||||||
|
|
||||||
/// Log-probability accessors for chat requests.
impl LogProbsProvider for ChatCompletionRequest {
    /// Maps the chat API's boolean `logprobs` flag onto the numeric form the
    /// trait expects: `Some(1)` when the flag is set, `None` otherwise.
    fn get_logprobs(&self) -> Option<u32> {
        match self.logprobs {
            true => Some(1),
            false => None,
        }
    }

    /// Returns the request's `top_logprobs` field.
    fn get_top_logprobs(&self) -> Option<u32> {
        self.top_logprobs
    }
}
|
|
||||||
|
|
||||||
/// Accessors for the SGLang extension sampling fields of a chat request.
impl SGLangExtensionsProvider for ChatCompletionRequest {
    /// Returns the request's `repetition_penalty` field.
    fn get_repetition_penalty(&self) -> Option<f32> {
        self.repetition_penalty
    }

    /// Returns the request's `min_p` field.
    fn get_min_p(&self) -> Option<f32> {
        self.min_p
    }

    /// Returns the request's `top_k` field.
    fn get_top_k(&self) -> Option<i32> {
        self.top_k
    }
}
|
|
||||||
|
|
||||||
/// Completion-count accessor for chat requests.
impl CompletionCountProvider for ChatCompletionRequest {
    /// Returns the request's `n` field.
    fn get_n(&self) -> Option<u32> {
        let Self { n, .. } = self;
        *n
    }
}
|
|
||||||
|
|
||||||
impl ChatCompletionRequest {
|
|
||||||
/// Validate message-specific requirements
|
|
||||||
pub fn validate_messages(&self) -> Result<(), ValidationError> {
|
|
||||||
// Ensure messages array is not empty
|
|
||||||
validate_non_empty_array(&self.messages, "messages")?;
|
|
||||||
|
|
||||||
// Validate message content is not empty
|
|
||||||
for (i, msg) in self.messages.iter().enumerate() {
|
|
||||||
if let ChatMessage::User { content, .. } = msg {
|
|
||||||
match content {
|
|
||||||
UserMessageContent::Text(text) if text.is_empty() => {
|
|
||||||
return Err(ValidationError::InvalidValue {
|
|
||||||
parameter: format!("messages[{}].content", i),
|
|
||||||
value: "empty".to_string(),
|
|
||||||
reason: "message content cannot be empty".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
UserMessageContent::Parts(parts) if parts.is_empty() => {
|
|
||||||
return Err(ValidationError::InvalidValue {
|
|
||||||
parameter: format!("messages[{}].content", i),
|
|
||||||
value: "empty array".to_string(),
|
|
||||||
reason: "message content parts cannot be empty".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
_ => {}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate response format if specified
|
|
||||||
pub fn validate_response_format(&self) -> Result<(), ValidationError> {
|
|
||||||
if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format {
|
|
||||||
if json_schema.name.is_empty() {
|
|
||||||
return Err(ValidationError::InvalidValue {
|
|
||||||
parameter: "response_format.json_schema.name".to_string(),
|
|
||||||
value: "empty".to_string(),
|
|
||||||
reason: "JSON schema name cannot be empty".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate chat API specific logprobs requirements
|
|
||||||
pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> {
|
|
||||||
// In chat API, if logprobs=true, top_logprobs must be specified
|
|
||||||
if self.logprobs && self.top_logprobs.is_none() {
|
|
||||||
return Err(ValidationError::MissingRequired {
|
|
||||||
parameter: "top_logprobs".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// If top_logprobs is specified, logprobs should be true
|
|
||||||
if self.top_logprobs.is_some() && !self.logprobs {
|
|
||||||
return Err(ValidationError::InvalidValue {
|
|
||||||
parameter: "logprobs".to_string(),
|
|
||||||
value: "false".to_string(),
|
|
||||||
reason: "must be true when top_logprobs is specified".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate cross-parameter relationships specific to chat completions
|
|
||||||
pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> {
|
|
||||||
// Validate that both max_tokens and max_completion_tokens aren't set
|
|
||||||
validate_conflicting_parameters(
|
|
||||||
"max_tokens",
|
|
||||||
self.max_tokens.is_some(),
|
|
||||||
"max_completion_tokens",
|
|
||||||
self.max_completion_tokens.is_some(),
|
|
||||||
"cannot specify both max_tokens and max_completion_tokens",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Validate that tools and functions aren't both specified (deprecated)
|
|
||||||
validate_conflicting_parameters(
|
|
||||||
"tools",
|
|
||||||
self.tools.is_some(),
|
|
||||||
"functions",
|
|
||||||
self.functions.is_some(),
|
|
||||||
"functions is deprecated, use tools instead",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Validate structured output constraints don't conflict with JSON response format
|
|
||||||
let has_json_format = matches!(
|
|
||||||
self.response_format,
|
|
||||||
Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
|
|
||||||
);
|
|
||||||
|
|
||||||
validate_conflicting_parameters(
|
|
||||||
"response_format",
|
|
||||||
has_json_format,
|
|
||||||
"regex",
|
|
||||||
self.regex.is_some(),
|
|
||||||
"cannot use regex constraint with JSON response format",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
validate_conflicting_parameters(
|
|
||||||
"response_format",
|
|
||||||
has_json_format,
|
|
||||||
"ebnf",
|
|
||||||
self.ebnf.is_some(),
|
|
||||||
"cannot use EBNF constraint with JSON response format",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Only one structured output constraint should be active
|
|
||||||
let structured_constraints = [
|
|
||||||
("regex", self.regex.is_some()),
|
|
||||||
("ebnf", self.ebnf.is_some()),
|
|
||||||
(
|
|
||||||
"json_schema",
|
|
||||||
matches!(
|
|
||||||
self.response_format,
|
|
||||||
Some(ResponseFormat::JsonSchema { .. })
|
|
||||||
),
|
|
||||||
),
|
|
||||||
];
|
|
||||||
|
|
||||||
validate_mutually_exclusive_options(
|
|
||||||
&structured_constraints,
|
|
||||||
"Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ValidatableRequest for ChatCompletionRequest {
|
|
||||||
fn validate(&self) -> Result<(), ValidationError> {
|
|
||||||
// Call the common validation function from the validation module
|
|
||||||
validate_common_request_params(self)?;
|
|
||||||
|
|
||||||
// Then validate chat-specific parameters
|
|
||||||
self.validate_messages()?;
|
|
||||||
self.validate_response_format()?;
|
|
||||||
self.validate_chat_logprobs()?;
|
|
||||||
self.validate_chat_cross_parameters()?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocols::openai::chat::types::*;

    /// Baseline request that passes validation; each test mutates a copy of it.
    fn create_valid_request() -> ChatCompletionRequest {
        let greeting = ChatMessage::User {
            role: "user".to_string(),
            content: UserMessageContent::Text("Hello".to_string()),
            name: None,
        };

        ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![greeting],
            temperature: Some(1.0),
            top_p: Some(0.9),
            n: Some(1),
            stream: false,
            stream_options: None,
            stop: None,
            max_tokens: Some(100),
            max_completion_tokens: None,
            presence_penalty: Some(0.0),
            frequency_penalty: Some(0.0),
            logit_bias: None,
            user: None,
            seed: None,
            logprobs: false,
            top_logprobs: None,
            response_format: None,
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
            functions: None,
            function_call: None,
            // SGLang extensions
            top_k: None,
            min_p: None,
            min_tokens: None,
            repetition_penalty: None,
            regex: None,
            ebnf: None,
            stop_token_ids: None,
            no_stop_trim: false,
            ignore_eos: false,
            continue_final_message: false,
            skip_special_tokens: true,
            lora_path: None,
            session_params: None,
            separate_reasoning: true,
            stream_reasoning: true,
            return_hidden_states: false,
        }
    }

    #[test]
    fn test_valid_chat_request() {
        assert!(create_valid_request().validate().is_ok());
    }

    #[test]
    fn test_invalid_temperature() {
        let mut request = create_valid_request();
        request.temperature = Some(3.0); // above the accepted range

        let outcome = request.validate();
        assert!(outcome.is_err());
        if let Err(ValidationError::OutOfRange { parameter, .. }) = outcome {
            assert_eq!(parameter, "temperature");
        } else {
            panic!("Expected OutOfRange error");
        }
    }

    #[test]
    fn test_invalid_top_p() {
        let mut request = create_valid_request();
        request.top_p = Some(1.5); // above the accepted range

        assert!(request.validate().is_err());
    }

    #[test]
    fn test_too_many_stop_sequences() {
        let mut request = create_valid_request();
        // Five sequences ("stop1" ..= "stop5") is one more than allowed.
        let sequences: Vec<String> = (1..=5).map(|i| format!("stop{}", i)).collect();
        request.stop = Some(StringOrArray::Array(sequences));

        assert!(request.validate().is_err());
    }

    #[test]
    fn test_empty_stop_sequence() {
        let mut request = create_valid_request();
        request.stop = Some(StringOrArray::String(String::new()));

        let outcome = request.validate();
        assert!(outcome.is_err());
        if let Err(ValidationError::InvalidValue {
            parameter, reason, ..
        }) = outcome
        {
            assert_eq!(parameter, "stop");
            assert!(reason.contains("empty"));
        } else {
            panic!("Expected InvalidValue error");
        }
    }

    #[test]
    fn test_empty_messages() {
        let mut request = create_valid_request();
        request.messages = Vec::new();

        let outcome = request.validate();
        assert!(outcome.is_err());
        if let Err(ValidationError::MissingRequired { parameter }) = outcome {
            assert_eq!(parameter, "messages");
        } else {
            panic!("Expected MissingRequired error");
        }
    }

    #[test]
    fn test_invalid_n_parameter() {
        let mut request = create_valid_request();

        // Zero is rejected, and so is a count that is too high.
        for &count in [0, 20].iter() {
            request.n = Some(count);
            assert!(request.validate().is_err());
        }
    }

    #[test]
    fn test_conflicting_max_tokens() {
        let mut request = create_valid_request();
        request.max_tokens = Some(100);
        request.max_completion_tokens = Some(200);

        let outcome = request.validate();
        assert!(outcome.is_err());
        if let Err(ValidationError::ConflictingParameters {
            parameter1,
            parameter2,
            ..
        }) = outcome
        {
            assert!(parameter1.contains("max_tokens"));
            assert!(parameter2.contains("max_completion_tokens"));
        } else {
            panic!("Expected ConflictingParameters error");
        }
    }

    #[test]
    fn test_logprobs_without_top_logprobs() {
        let mut request = create_valid_request();
        request.logprobs = true;
        request.top_logprobs = None;

        assert!(request.validate().is_err());
    }

    #[test]
    fn test_sglang_extensions() {
        let mut request = create_valid_request();

        // top_k: -1 disables it, positive values are fine, zero is rejected.
        for &(top_k, accepted) in [(-1, true), (50, true), (0, false)].iter() {
            request.top_k = Some(top_k);
            assert_eq!(request.validate().is_ok(), accepted);
        }
        request.top_k = None;

        // min_p: 0.1 is fine, 1.5 is too high.
        for &(min_p, accepted) in [(0.1_f32, true), (1.5, false)].iter() {
            request.min_p = Some(min_p);
            assert_eq!(request.validate().is_ok(), accepted);
        }
        request.min_p = None;

        // repetition_penalty: 0.0 and 2.0 are the inclusive bounds.
        let penalty_cases = [
            (1.2_f32, true),
            (0.0, true),
            (2.0, true),
            (2.1, false),
            (-0.1, false),
        ];
        for &(penalty, accepted) in penalty_cases.iter() {
            request.repetition_penalty = Some(penalty);
            assert_eq!(request.validate().is_ok(), accepted);
        }
    }

    #[test]
    fn test_structured_output_conflicts() {
        let mut request = create_valid_request();

        // A JSON response format combined with a regex constraint conflicts.
        request.response_format = Some(ResponseFormat::JsonObject);
        request.regex = Some(".*".to_string());
        assert!(request.validate().is_err());

        // Two structured constraints at once conflict as well.
        request.response_format = None;
        request.regex = Some(".*".to_string());
        request.ebnf = Some("grammar".to_string());
        assert!(request.validate().is_err());
    }

    #[test]
    fn test_min_max_tokens_validation() {
        let mut request = create_valid_request();

        // min_tokens above max_tokens is rejected.
        request.min_tokens = Some(100);
        request.max_tokens = Some(50);
        assert!(request.validate().is_err());

        // The same bound is honoured through max_completion_tokens.
        request.max_tokens = None;
        request.max_completion_tokens = Some(200);
        request.min_tokens = Some(100);
        assert!(request.validate().is_ok());
    }
}
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
// Common types shared across OpenAI API implementations
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
// ============= Shared Request Components =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct StreamOptions {
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub include_usage: Option<bool>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Usage Tracking =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct Usage {
|
|
||||||
pub prompt_tokens: u32,
|
|
||||||
pub completion_tokens: u32,
|
|
||||||
pub total_tokens: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionTokensDetails {
|
|
||||||
pub reasoning_tokens: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Logprobs Types =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct LogProbs {
|
|
||||||
pub tokens: Vec<String>,
|
|
||||||
pub token_logprobs: Vec<Option<f32>>,
|
|
||||||
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
|
||||||
pub text_offset: Vec<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatLogProbs {
|
|
||||||
pub content: Option<Vec<ChatLogProbsContent>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ChatLogProbsContent {
|
|
||||||
pub token: String,
|
|
||||||
pub logprob: f32,
|
|
||||||
pub bytes: Option<Vec<u8>>,
|
|
||||||
pub top_logprobs: Vec<TopLogProb>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct TopLogProb {
|
|
||||||
pub token: String,
|
|
||||||
pub logprob: f32,
|
|
||||||
pub bytes: Option<Vec<u8>>,
|
|
||||||
}
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
// Completions API module (v1/completions)
|
|
||||||
|
|
||||||
pub mod request;
|
|
||||||
pub mod response;
|
|
||||||
|
|
||||||
// Re-export main types for convenience
|
|
||||||
pub use request::CompletionRequest;
|
|
||||||
pub use response::{
|
|
||||||
CompletionChoice, CompletionResponse, CompletionStreamChoice, CompletionStreamResponse,
|
|
||||||
};
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
// Completions API request types (v1/completions) - DEPRECATED but still supported
|
|
||||||
|
|
||||||
use crate::protocols::common::{default_true, GenerationRequest, LoRAPath, StringOrArray};
|
|
||||||
use crate::protocols::openai::common::StreamOptions;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionRequest {
|
|
||||||
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
|
|
||||||
pub model: String,
|
|
||||||
|
|
||||||
/// The prompt(s) to generate completions for
|
|
||||||
pub prompt: StringOrArray,
|
|
||||||
|
|
||||||
/// The suffix that comes after a completion of inserted text
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub suffix: Option<String>,
|
|
||||||
|
|
||||||
/// The maximum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// What sampling temperature to use, between 0 and 2
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// An alternative to sampling with temperature (nucleus sampling)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// How many completions to generate for each prompt
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub n: Option<u32>,
|
|
||||||
|
|
||||||
/// Whether to stream back partial progress
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Options for streaming response
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stream_options: Option<StreamOptions>,
|
|
||||||
|
|
||||||
/// Include the log probabilities on the logprobs most likely tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<u32>,
|
|
||||||
|
|
||||||
/// Echo back the prompt in addition to the completion
|
|
||||||
#[serde(default)]
|
|
||||||
pub echo: bool,
|
|
||||||
|
|
||||||
/// Up to 4 sequences where the API will stop generating further tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub presence_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub frequency_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Generates best_of completions server-side and returns the "best"
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub best_of: Option<u32>,
|
|
||||||
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logit_bias: Option<HashMap<String, f32>>,
|
|
||||||
|
|
||||||
/// A unique identifier representing your end-user
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub user: Option<String>,
|
|
||||||
|
|
||||||
/// If specified, our system will make a best effort to sample deterministically
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub seed: Option<i64>,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Top-k sampling parameter (-1 to disable)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_k: Option<i32>,
|
|
||||||
|
|
||||||
/// Min-p nucleus sampling parameter
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_p: Option<f32>,
|
|
||||||
|
|
||||||
/// Minimum number of tokens to generate
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub min_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Repetition penalty for reducing repetitive text
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub repetition_penalty: Option<f32>,
|
|
||||||
|
|
||||||
/// Regex constraint for output generation
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub regex: Option<String>,
|
|
||||||
|
|
||||||
/// EBNF grammar constraint for structured output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub ebnf: Option<String>,
|
|
||||||
|
|
||||||
/// JSON schema constraint for structured output
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub json_schema: Option<String>,
|
|
||||||
|
|
||||||
/// Specific token IDs to use as stop conditions
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop_token_ids: Option<Vec<i32>>,
|
|
||||||
|
|
||||||
/// Skip trimming stop tokens from output
|
|
||||||
#[serde(default)]
|
|
||||||
pub no_stop_trim: bool,
|
|
||||||
|
|
||||||
/// Ignore end-of-sequence tokens during generation
|
|
||||||
#[serde(default)]
|
|
||||||
pub ignore_eos: bool,
|
|
||||||
|
|
||||||
/// Skip special tokens during detokenization
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub skip_special_tokens: bool,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Path to LoRA adapter(s) for model customization
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub lora_path: Option<LoRAPath>,
|
|
||||||
|
|
||||||
/// Session parameters for continual prompting
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Return model hidden states
|
|
||||||
#[serde(default)]
|
|
||||||
pub return_hidden_states: bool,
|
|
||||||
|
|
||||||
/// Additional fields including bootstrap info for PD routing
|
|
||||||
#[serde(flatten)]
|
|
||||||
pub other: serde_json::Map<String, serde_json::Value>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for CompletionRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
Some(&self.model)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
match &self.prompt {
|
|
||||||
StringOrArray::String(s) => s.clone(),
|
|
||||||
StringOrArray::Array(v) => v.join(" "),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
// Completions API response types
|
|
||||||
|
|
||||||
use crate::protocols::openai::common::{LogProbs, Usage};
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
// ============= Regular Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "text_completion"
|
|
||||||
pub created: u64,
|
|
||||||
pub model: String,
|
|
||||||
pub choices: Vec<CompletionChoice>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub usage: Option<Usage>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionChoice {
|
|
||||||
pub text: String,
|
|
||||||
pub index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<LogProbs>,
|
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
|
||||||
/// Information about which stop condition was matched
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
|
||||||
/// Hidden states from the model (SGLang extension)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub hidden_states: Option<Vec<f32>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Streaming Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionStreamResponse {
|
|
||||||
pub id: String,
|
|
||||||
pub object: String, // "text_completion"
|
|
||||||
pub created: u64,
|
|
||||||
pub choices: Vec<CompletionStreamChoice>,
|
|
||||||
pub model: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub system_fingerprint: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct CompletionStreamChoice {
|
|
||||||
pub text: String,
|
|
||||||
pub index: u32,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub logprobs: Option<LogProbs>,
|
|
||||||
pub finish_reason: Option<String>,
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
// OpenAI API error response types
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ErrorResponse {
|
|
||||||
pub error: ErrorDetail,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ErrorDetail {
|
|
||||||
pub message: String,
|
|
||||||
#[serde(rename = "type")]
|
|
||||||
pub error_type: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub param: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub code: Option<String>,
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
// OpenAI protocol module
|
|
||||||
// This module contains all OpenAI API-compatible types and future validation logic
|
|
||||||
|
|
||||||
pub mod chat;
|
|
||||||
pub mod common;
|
|
||||||
pub mod completions;
|
|
||||||
pub mod errors;
|
|
||||||
pub mod responses;
|
|
||||||
@@ -1,10 +0,0 @@
|
|||||||
// Responses API module
|
|
||||||
|
|
||||||
pub mod request;
|
|
||||||
pub mod response;
|
|
||||||
pub mod types;
|
|
||||||
|
|
||||||
// Re-export main types for convenience
|
|
||||||
pub use request::ResponsesRequest;
|
|
||||||
pub use response::ResponsesResponse;
|
|
||||||
pub use types::*;
|
|
||||||
@@ -1,300 +0,0 @@
|
|||||||
// Responses API request types
|
|
||||||
|
|
||||||
use crate::protocols::common::{GenerationRequest, StringOrArray};
|
|
||||||
use crate::protocols::openai::responses::types::*;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
fn generate_request_id() -> String {
|
|
||||||
format!("resp_{}", uuid::Uuid::new_v4().simple())
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
pub struct ResponsesRequest {
|
|
||||||
// ============= Core OpenAI API fields =============
|
|
||||||
/// Run the request in the background
|
|
||||||
#[serde(default)]
|
|
||||||
pub background: bool,
|
|
||||||
|
|
||||||
/// Fields to include in the response
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub include: Option<Vec<IncludeField>>,
|
|
||||||
|
|
||||||
/// Input content - can be string or structured items
|
|
||||||
pub input: ResponseInput,
|
|
||||||
|
|
||||||
/// System instructions for the model
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub instructions: Option<String>,
|
|
||||||
|
|
||||||
/// Maximum number of output tokens
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_output_tokens: Option<u32>,
|
|
||||||
|
|
||||||
/// Maximum number of tool calls
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub max_tool_calls: Option<u32>,
|
|
||||||
|
|
||||||
/// Additional metadata
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub metadata: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
|
|
||||||
/// Model to use (optional to match vLLM)
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub model: Option<String>,
|
|
||||||
|
|
||||||
/// Whether to enable parallel tool calls
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub parallel_tool_calls: bool,
|
|
||||||
|
|
||||||
/// ID of previous response to continue from
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub previous_response_id: Option<String>,
|
|
||||||
|
|
||||||
/// Reasoning configuration
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub reasoning: Option<ResponseReasoningParam>,
|
|
||||||
|
|
||||||
/// Service tier
|
|
||||||
#[serde(default)]
|
|
||||||
pub service_tier: ServiceTier,
|
|
||||||
|
|
||||||
/// Whether to store the response
|
|
||||||
#[serde(default = "default_true")]
|
|
||||||
pub store: bool,
|
|
||||||
|
|
||||||
/// Whether to stream the response
|
|
||||||
#[serde(default)]
|
|
||||||
pub stream: bool,
|
|
||||||
|
|
||||||
/// Temperature for sampling
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub temperature: Option<f32>,
|
|
||||||
|
|
||||||
/// Tool choice behavior
|
|
||||||
#[serde(default)]
|
|
||||||
pub tool_choice: ToolChoice,
|
|
||||||
|
|
||||||
/// Available tools
|
|
||||||
#[serde(default)]
|
|
||||||
pub tools: Vec<ResponseTool>,
|
|
||||||
|
|
||||||
/// Number of top logprobs to return
|
|
||||||
#[serde(default)]
|
|
||||||
pub top_logprobs: u32,
|
|
||||||
|
|
||||||
/// Top-p sampling parameter
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
|
|
||||||
/// Truncation behavior
|
|
||||||
#[serde(default)]
|
|
||||||
pub truncation: Truncation,
|
|
||||||
|
|
||||||
/// User identifier
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub user: Option<String>,
|
|
||||||
|
|
||||||
// ============= SGLang Extensions =============
|
|
||||||
/// Request ID
|
|
||||||
#[serde(default = "generate_request_id")]
|
|
||||||
pub request_id: String,
|
|
||||||
|
|
||||||
/// Request priority
|
|
||||||
#[serde(default)]
|
|
||||||
pub priority: i32,
|
|
||||||
|
|
||||||
/// Frequency penalty
|
|
||||||
#[serde(default)]
|
|
||||||
pub frequency_penalty: f32,
|
|
||||||
|
|
||||||
/// Presence penalty
|
|
||||||
#[serde(default)]
|
|
||||||
pub presence_penalty: f32,
|
|
||||||
|
|
||||||
/// Stop sequences
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub stop: Option<StringOrArray>,
|
|
||||||
|
|
||||||
/// Top-k sampling parameter
|
|
||||||
#[serde(default = "default_top_k")]
|
|
||||||
pub top_k: i32,
|
|
||||||
|
|
||||||
/// Min-p sampling parameter
|
|
||||||
#[serde(default)]
|
|
||||||
pub min_p: f32,
|
|
||||||
|
|
||||||
/// Repetition penalty
|
|
||||||
#[serde(default = "default_repetition_penalty")]
|
|
||||||
pub repetition_penalty: f32,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Input payload of a Responses API request.
///
/// The enum is `untagged`, so no discriminator is written to or read from the
/// JSON: a bare string maps to [`ResponseInput::Text`] and an array of items
/// maps to [`ResponseInput::Items`].
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ResponseInput {
    /// Plain text input.
    Text(String),
    /// Structured list of input/output items.
    Items(Vec<ResponseInputOutputItem>),
}
|
|
||||||
|
|
||||||
/// Serde default for `top_k` when the request omits the field.
///
/// NOTE(review): `-1` presumably means "top-k disabled" for the backend — confirm.
fn default_top_k() -> i32 {
    -1
}
|
|
||||||
|
|
||||||
/// Serde default for `repetition_penalty` when the request omits the field.
fn default_repetition_penalty() -> f32 {
    1.0
}
|
|
||||||
|
|
||||||
/// Serde default helper for boolean fields that should be `true` when omitted
/// (used by `parallel_tool_calls` and `store`).
fn default_true() -> bool {
    true
}
|
|
||||||
|
|
||||||
impl ResponsesRequest {
|
|
||||||
/// Default sampling parameters
|
|
||||||
const DEFAULT_TEMPERATURE: f32 = 0.7;
|
|
||||||
const DEFAULT_TOP_P: f32 = 1.0;
|
|
||||||
|
|
||||||
/// Convert to sampling parameters for generation
|
|
||||||
pub fn to_sampling_params(
|
|
||||||
&self,
|
|
||||||
default_max_tokens: u32,
|
|
||||||
default_params: Option<HashMap<String, serde_json::Value>>,
|
|
||||||
) -> HashMap<String, serde_json::Value> {
|
|
||||||
let mut params = HashMap::new();
|
|
||||||
|
|
||||||
// Use max_output_tokens if available
|
|
||||||
let max_tokens = if let Some(max_output) = self.max_output_tokens {
|
|
||||||
std::cmp::min(max_output, default_max_tokens)
|
|
||||||
} else {
|
|
||||||
default_max_tokens
|
|
||||||
};
|
|
||||||
|
|
||||||
// Avoid exceeding context length by minus 1 token
|
|
||||||
let max_tokens = max_tokens.saturating_sub(1);
|
|
||||||
|
|
||||||
// Temperature
|
|
||||||
let temperature = self.temperature.unwrap_or_else(|| {
|
|
||||||
default_params
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|p| p.get("temperature"))
|
|
||||||
.and_then(|v| v.as_f64())
|
|
||||||
.map(|v| v as f32)
|
|
||||||
.unwrap_or(Self::DEFAULT_TEMPERATURE)
|
|
||||||
});
|
|
||||||
|
|
||||||
// Top-p
|
|
||||||
let top_p = self.top_p.unwrap_or_else(|| {
|
|
||||||
default_params
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|p| p.get("top_p"))
|
|
||||||
.and_then(|v| v.as_f64())
|
|
||||||
.map(|v| v as f32)
|
|
||||||
.unwrap_or(Self::DEFAULT_TOP_P)
|
|
||||||
});
|
|
||||||
|
|
||||||
params.insert(
|
|
||||||
"max_new_tokens".to_string(),
|
|
||||||
serde_json::Value::Number(serde_json::Number::from(max_tokens)),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"temperature".to_string(),
|
|
||||||
serde_json::Value::Number(serde_json::Number::from_f64(temperature as f64).unwrap()),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"top_p".to_string(),
|
|
||||||
serde_json::Value::Number(serde_json::Number::from_f64(top_p as f64).unwrap()),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"frequency_penalty".to_string(),
|
|
||||||
serde_json::Value::Number(
|
|
||||||
serde_json::Number::from_f64(self.frequency_penalty as f64).unwrap(),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"presence_penalty".to_string(),
|
|
||||||
serde_json::Value::Number(
|
|
||||||
serde_json::Number::from_f64(self.presence_penalty as f64).unwrap(),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"top_k".to_string(),
|
|
||||||
serde_json::Value::Number(serde_json::Number::from(self.top_k)),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"min_p".to_string(),
|
|
||||||
serde_json::Value::Number(serde_json::Number::from_f64(self.min_p as f64).unwrap()),
|
|
||||||
);
|
|
||||||
params.insert(
|
|
||||||
"repetition_penalty".to_string(),
|
|
||||||
serde_json::Value::Number(
|
|
||||||
serde_json::Number::from_f64(self.repetition_penalty as f64).unwrap(),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
|
|
||||||
if let Some(ref stop) = self.stop {
|
|
||||||
match serde_json::to_value(stop) {
|
|
||||||
Ok(value) => params.insert("stop".to_string(), value),
|
|
||||||
Err(_) => params.insert("stop".to_string(), serde_json::Value::Null),
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply any additional default parameters
|
|
||||||
if let Some(default_params) = default_params {
|
|
||||||
for (key, value) in default_params {
|
|
||||||
params.entry(key).or_insert(value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
params
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GenerationRequest for ResponsesRequest {
|
|
||||||
fn is_stream(&self) -> bool {
|
|
||||||
self.stream
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_model(&self) -> Option<&str> {
|
|
||||||
self.model.as_deref()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn extract_text_for_routing(&self) -> String {
|
|
||||||
match &self.input {
|
|
||||||
ResponseInput::Text(text) => text.clone(),
|
|
||||||
ResponseInput::Items(items) => items
|
|
||||||
.iter()
|
|
||||||
.filter_map(|item| match item {
|
|
||||||
ResponseInputOutputItem::Message { content, .. } => {
|
|
||||||
let texts: Vec<String> = content
|
|
||||||
.iter()
|
|
||||||
.map(|part| match part {
|
|
||||||
ResponseContentPart::OutputText { text, .. } => text.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
if texts.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(texts.join(" "))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ResponseInputOutputItem::Reasoning { content, .. } => {
|
|
||||||
let texts: Vec<String> = content
|
|
||||||
.iter()
|
|
||||||
.map(|part| match part {
|
|
||||||
ResponseReasoningContent::ReasoningText { text } => text.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
if texts.is_empty() {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(texts.join(" "))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ResponseInputOutputItem::FunctionToolCall { arguments, .. } => {
|
|
||||||
Some(arguments.clone())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Vec<String>>()
|
|
||||||
.join(" "),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,280 +0,0 @@
|
|||||||
// Responses API response types
|
|
||||||
|
|
||||||
use crate::protocols::openai::responses::request::ResponsesRequest;
|
|
||||||
use crate::protocols::openai::responses::types::*;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
fn generate_response_id() -> String {
|
|
||||||
format!("resp_{}", uuid::Uuid::new_v4().simple())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the current Unix time in whole seconds.
///
/// Falls back to `0` when the system clock reports a time before the epoch.
fn current_timestamp() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    let elapsed_secs = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    };
    elapsed_secs as i64
}
|
|
||||||
|
|
||||||
/// Response object returned by the Responses API.
///
/// Fields carrying a `#[serde(default ...)]` attribute are filled in
/// automatically when they are missing from the deserialized JSON.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponsesResponse {
    /// Response ID; defaults to a freshly generated `resp_<uuid>` value.
    #[serde(default = "generate_response_id")]
    pub id: String,

    /// Object type; defaults to `"response"`.
    #[serde(default = "default_object_type")]
    pub object: String,

    /// Creation timestamp in Unix seconds; defaults to the current time.
    #[serde(default = "current_timestamp")]
    pub created_at: i64,

    /// Model name (required).
    pub model: String,

    /// Output items; defaults to an empty list.
    #[serde(default)]
    pub output: Vec<ResponseOutputItem>,

    /// Response status (required).
    pub status: ResponseStatus,

    /// Usage information; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<UsageInfo>,

    /// Whether parallel tool calls are enabled; defaults to `true`.
    #[serde(default = "default_true")]
    pub parallel_tool_calls: bool,

    /// Tool choice setting as a string; defaults to `"auto"`.
    #[serde(default = "default_tool_choice")]
    pub tool_choice: String,

    /// Available tools; defaults to an empty list.
    #[serde(default)]
    pub tools: Vec<ResponseTool>,
}
|
|
||||||
|
|
||||||
/// Serde default for `ResponsesResponse::object`.
fn default_object_type() -> String {
    String::from("response")
}
|
|
||||||
|
|
||||||
/// Serde default for `ResponsesResponse::parallel_tool_calls`.
fn default_true() -> bool {
    true
}
|
|
||||||
|
|
||||||
/// Serde default for `ResponsesResponse::tool_choice`.
fn default_tool_choice() -> String {
    String::from("auto")
}
|
|
||||||
|
|
||||||
impl ResponsesResponse {
|
|
||||||
/// Create a response from a request
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
pub fn from_request(
|
|
||||||
request: &ResponsesRequest,
|
|
||||||
_sampling_params: &HashMap<String, serde_json::Value>,
|
|
||||||
model_name: String,
|
|
||||||
created_time: i64,
|
|
||||||
output: Vec<ResponseOutputItem>,
|
|
||||||
status: ResponseStatus,
|
|
||||||
usage: Option<UsageInfo>,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
id: request.request_id.clone(),
|
|
||||||
object: "response".to_string(),
|
|
||||||
created_at: created_time,
|
|
||||||
model: model_name,
|
|
||||||
output,
|
|
||||||
status,
|
|
||||||
usage,
|
|
||||||
parallel_tool_calls: request.parallel_tool_calls,
|
|
||||||
tool_choice: match request.tool_choice {
|
|
||||||
ToolChoice::Auto => "auto".to_string(),
|
|
||||||
ToolChoice::Required => "required".to_string(),
|
|
||||||
ToolChoice::None => "none".to_string(),
|
|
||||||
},
|
|
||||||
tools: request.tools.clone(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new response with default values
|
|
||||||
pub fn new(request_id: String, model: String, status: ResponseStatus) -> Self {
|
|
||||||
Self {
|
|
||||||
id: request_id,
|
|
||||||
object: "response".to_string(),
|
|
||||||
created_at: current_timestamp(),
|
|
||||||
model,
|
|
||||||
output: Vec::new(),
|
|
||||||
status,
|
|
||||||
usage: None,
|
|
||||||
parallel_tool_calls: true,
|
|
||||||
tool_choice: "auto".to_string(),
|
|
||||||
tools: Vec::new(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Add an output item to the response
|
|
||||||
pub fn add_output(&mut self, item: ResponseOutputItem) {
|
|
||||||
self.output.push(item);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the usage information
|
|
||||||
pub fn set_usage(&mut self, usage: UsageInfo) {
|
|
||||||
self.usage = Some(usage);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update the status
|
|
||||||
pub fn set_status(&mut self, status: ResponseStatus) {
|
|
||||||
self.status = status;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the response is complete
|
|
||||||
pub fn is_complete(&self) -> bool {
|
|
||||||
matches!(self.status, ResponseStatus::Completed)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the response is in progress
|
|
||||||
pub fn is_in_progress(&self) -> bool {
|
|
||||||
matches!(self.status, ResponseStatus::InProgress)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the response failed
|
|
||||||
pub fn is_failed(&self) -> bool {
|
|
||||||
matches!(self.status, ResponseStatus::Failed)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the response was cancelled
|
|
||||||
pub fn is_cancelled(&self) -> bool {
|
|
||||||
matches!(self.status, ResponseStatus::Cancelled)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the response is queued
|
|
||||||
pub fn is_queued(&self) -> bool {
|
|
||||||
matches!(self.status, ResponseStatus::Queued)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert usage to OpenAI Responses API format
|
|
||||||
pub fn usage_in_response_format(
|
|
||||||
&self,
|
|
||||||
) -> Option<crate::protocols::openai::responses::types::ResponseUsage> {
|
|
||||||
self.usage.as_ref().map(|usage| usage.to_response_usage())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the response as a JSON value with usage in response format
|
|
||||||
pub fn to_response_format(&self) -> serde_json::Value {
|
|
||||||
let mut response = serde_json::to_value(self).unwrap_or(serde_json::Value::Null);
|
|
||||||
|
|
||||||
// Convert usage to response format if present
|
|
||||||
if let Some(usage) = &self.usage {
|
|
||||||
if let Ok(usage_value) = serde_json::to_value(usage.to_response_usage()) {
|
|
||||||
response["usage"] = usage_value;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
response
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Helper Functions =============
|
|
||||||
|
|
||||||
impl ResponseOutputItem {
|
|
||||||
/// Create a new message output item
|
|
||||||
pub fn new_message(
|
|
||||||
id: String,
|
|
||||||
role: String,
|
|
||||||
content: Vec<ResponseContentPart>,
|
|
||||||
status: String,
|
|
||||||
) -> Self {
|
|
||||||
Self::Message {
|
|
||||||
id,
|
|
||||||
role,
|
|
||||||
content,
|
|
||||||
status,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new reasoning output item
|
|
||||||
pub fn new_reasoning(
|
|
||||||
id: String,
|
|
||||||
summary: Vec<String>,
|
|
||||||
content: Vec<ResponseReasoningContent>,
|
|
||||||
status: Option<String>,
|
|
||||||
) -> Self {
|
|
||||||
Self::Reasoning {
|
|
||||||
id,
|
|
||||||
summary,
|
|
||||||
content,
|
|
||||||
status,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create a new function tool call output item
|
|
||||||
pub fn new_function_tool_call(
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
arguments: String,
|
|
||||||
output: Option<String>,
|
|
||||||
status: String,
|
|
||||||
) -> Self {
|
|
||||||
Self::FunctionToolCall {
|
|
||||||
id,
|
|
||||||
name,
|
|
||||||
arguments,
|
|
||||||
output,
|
|
||||||
status,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResponseContentPart {
|
|
||||||
/// Create a new text content part
|
|
||||||
pub fn new_text(
|
|
||||||
text: String,
|
|
||||||
annotations: Vec<String>,
|
|
||||||
logprobs: Option<crate::protocols::openai::common::ChatLogProbs>,
|
|
||||||
) -> Self {
|
|
||||||
Self::OutputText {
|
|
||||||
text,
|
|
||||||
annotations,
|
|
||||||
logprobs,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResponseReasoningContent {
|
|
||||||
/// Create a new reasoning text content
|
|
||||||
pub fn new_reasoning_text(text: String) -> Self {
|
|
||||||
Self::ReasoningText { text }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl UsageInfo {
|
|
||||||
/// Create a new usage info with token counts
|
|
||||||
pub fn new(prompt_tokens: u32, completion_tokens: u32, reasoning_tokens: Option<u32>) -> Self {
|
|
||||||
Self {
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens: prompt_tokens + completion_tokens,
|
|
||||||
reasoning_tokens,
|
|
||||||
prompt_tokens_details: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Create usage info with cached token details
|
|
||||||
pub fn new_with_cached(
|
|
||||||
prompt_tokens: u32,
|
|
||||||
completion_tokens: u32,
|
|
||||||
reasoning_tokens: Option<u32>,
|
|
||||||
cached_tokens: u32,
|
|
||||||
) -> Self {
|
|
||||||
Self {
|
|
||||||
prompt_tokens,
|
|
||||||
completion_tokens,
|
|
||||||
total_tokens: prompt_tokens + completion_tokens,
|
|
||||||
reasoning_tokens,
|
|
||||||
prompt_tokens_details: Some(PromptTokenUsageInfo { cached_tokens }),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,296 +0,0 @@
|
|||||||
// Supporting types for Responses API
|
|
||||||
|
|
||||||
use crate::protocols::openai::common::ChatLogProbs;
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
// ============= Tool Definitions =============
|
|
||||||
|
|
||||||
/// A tool made available to the model, identified only by its type.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseTool {
    /// Kind of tool; serialized under the JSON key `"type"`.
    #[serde(rename = "type")]
    pub r#type: ResponseToolType,
}
|
|
||||||
|
|
||||||
/// Supported tool kinds, serialized in `snake_case`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseToolType {
    /// Serialized as `"web_search_preview"`.
    WebSearchPreview,
    /// Serialized as `"code_interpreter"`.
    CodeInterpreter,
}
|
|
||||||
|
|
||||||
// ============= Reasoning Configuration =============
|
|
||||||
|
|
||||||
/// Reasoning configuration attached to a request.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseReasoningParam {
    /// Requested reasoning effort; when the key is absent from the JSON it
    /// defaults to `Some(ReasoningEffort::Medium)`.
    #[serde(default = "default_reasoning_effort")]
    pub effort: Option<ReasoningEffort>,
}
|
|
||||||
|
|
||||||
/// Serde default for `ResponseReasoningParam::effort`: medium effort.
fn default_reasoning_effort() -> Option<ReasoningEffort> {
    Some(ReasoningEffort::Medium)
}
|
|
||||||
|
|
||||||
/// Reasoning effort levels, serialized as `"low"`, `"medium"` and `"high"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ReasoningEffort {
    Low,
    Medium,
    High,
}
|
|
||||||
|
|
||||||
// ============= Input/Output Items =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ResponseInputOutputItem {
|
|
||||||
#[serde(rename = "message")]
|
|
||||||
Message {
|
|
||||||
id: String,
|
|
||||||
role: String,
|
|
||||||
content: Vec<ResponseContentPart>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
status: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "reasoning")]
|
|
||||||
Reasoning {
|
|
||||||
id: String,
|
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
|
||||||
summary: Vec<String>,
|
|
||||||
content: Vec<ResponseReasoningContent>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
status: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "function_tool_call")]
|
|
||||||
FunctionToolCall {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
arguments: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
output: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
status: Option<String>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ResponseContentPart {
|
|
||||||
#[serde(rename = "output_text")]
|
|
||||||
OutputText {
|
|
||||||
text: String,
|
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
|
||||||
annotations: Vec<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
logprobs: Option<ChatLogProbs>,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ResponseReasoningContent {
|
|
||||||
#[serde(rename = "reasoning_text")]
|
|
||||||
ReasoningText { text: String },
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Output Items for Response =============
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
|
||||||
#[serde(tag = "type")]
|
|
||||||
#[serde(rename_all = "snake_case")]
|
|
||||||
pub enum ResponseOutputItem {
|
|
||||||
#[serde(rename = "message")]
|
|
||||||
Message {
|
|
||||||
id: String,
|
|
||||||
role: String,
|
|
||||||
content: Vec<ResponseContentPart>,
|
|
||||||
status: String,
|
|
||||||
},
|
|
||||||
#[serde(rename = "reasoning")]
|
|
||||||
Reasoning {
|
|
||||||
id: String,
|
|
||||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
|
||||||
summary: Vec<String>,
|
|
||||||
content: Vec<ResponseReasoningContent>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
status: Option<String>,
|
|
||||||
},
|
|
||||||
#[serde(rename = "function_tool_call")]
|
|
||||||
FunctionToolCall {
|
|
||||||
id: String,
|
|
||||||
name: String,
|
|
||||||
arguments: String,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
output: Option<String>,
|
|
||||||
status: String,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Service Tier =============
|
|
||||||
|
|
||||||
/// Service tier requested for a response, serialized in `snake_case`
/// (`"auto"`, `"default"`, `"flex"`, `"scale"`, `"priority"`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ServiceTier {
    Auto,
    Default,
    Flex,
    Scale,
    Priority,
}
|
|
||||||
|
|
||||||
impl Default for ServiceTier {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::Auto
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Tool Choice =============
|
|
||||||
|
|
||||||
/// Tool choice behavior, serialized as `"auto"`, `"required"` or `"none"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolChoice {
    Auto,
    Required,
    None,
}
|
|
||||||
|
|
||||||
impl Default for ToolChoice {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::Auto
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Truncation =============
|
|
||||||
|
|
||||||
/// Truncation behavior, serialized as `"auto"` or `"disabled"`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum Truncation {
    Auto,
    Disabled,
}
|
|
||||||
|
|
||||||
impl Default for Truncation {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::Disabled
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ============= Response Status =============
|
|
||||||
|
|
||||||
/// Lifecycle status of a response, serialized in `snake_case`
/// (`"queued"`, `"in_progress"`, `"completed"`, `"failed"`, `"cancelled"`).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    Queued,
    InProgress,
    Completed,
    Failed,
    Cancelled,
}
|
|
||||||
|
|
||||||
// ============= Include Fields =============
|
|
||||||
|
|
||||||
/// Optional extra fields a caller can ask to have included.
///
/// Every variant carries an explicit `rename`, so the wire value is the
/// dotted path shown on the variant rather than a `snake_case` name.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeField {
    #[serde(rename = "code_interpreter_call.outputs")]
    CodeInterpreterCallOutputs,
    #[serde(rename = "computer_call_output.output.image_url")]
    ComputerCallOutputImageUrl,
    #[serde(rename = "file_search_call.results")]
    FileSearchCallResults,
    #[serde(rename = "message.input_image.image_url")]
    MessageInputImageUrl,
    #[serde(rename = "message.output_text.logprobs")]
    MessageOutputTextLogprobs,
    #[serde(rename = "reasoning.encrypted_content")]
    ReasoningEncryptedContent,
}
|
|
||||||
|
|
||||||
// ============= Usage Info =============
|
|
||||||
|
|
||||||
/// Token usage in the standard (prompt/completion) layout.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct UsageInfo {
    /// Number of prompt tokens.
    pub prompt_tokens: u32,
    /// Number of completion tokens.
    pub completion_tokens: u32,
    /// Total number of tokens.
    pub total_tokens: u32,
    /// Number of reasoning tokens; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_tokens: Option<u32>,
    /// Prompt-token breakdown; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens_details: Option<PromptTokenUsageInfo>,
}
|
|
||||||
|
|
||||||
/// Breakdown of prompt-token usage.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PromptTokenUsageInfo {
    /// Number of cached prompt tokens.
    pub cached_tokens: u32,
}
|
|
||||||
|
|
||||||
// ============= Response Usage Format =============
|
|
||||||
|
|
||||||
/// OpenAI Responses API usage format (different from standard UsageInfo).
///
/// Uses `input_tokens` / `output_tokens` where `UsageInfo` uses
/// `prompt_tokens` / `completion_tokens`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ResponseUsage {
    /// Number of input tokens.
    pub input_tokens: u32,
    /// Number of output tokens.
    pub output_tokens: u32,
    /// Total number of tokens.
    pub total_tokens: u32,
    /// Input-token breakdown; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_tokens_details: Option<InputTokensDetails>,
    /// Output-token breakdown; omitted from the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_tokens_details: Option<OutputTokensDetails>,
}
|
|
||||||
|
|
||||||
/// Breakdown of input-token usage in the Responses API layout.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct InputTokensDetails {
    /// Number of cached input tokens.
    pub cached_tokens: u32,
}
|
|
||||||
|
|
||||||
/// Breakdown of output-token usage in the Responses API layout.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct OutputTokensDetails {
    /// Number of reasoning tokens.
    pub reasoning_tokens: u32,
}
|
|
||||||
|
|
||||||
impl UsageInfo {
|
|
||||||
/// Convert to OpenAI Responses API format
|
|
||||||
pub fn to_response_usage(&self) -> ResponseUsage {
|
|
||||||
ResponseUsage {
|
|
||||||
input_tokens: self.prompt_tokens,
|
|
||||||
output_tokens: self.completion_tokens,
|
|
||||||
total_tokens: self.total_tokens,
|
|
||||||
input_tokens_details: self.prompt_tokens_details.as_ref().map(|details| {
|
|
||||||
InputTokensDetails {
|
|
||||||
cached_tokens: details.cached_tokens,
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
output_tokens_details: self.reasoning_tokens.map(|tokens| OutputTokensDetails {
|
|
||||||
reasoning_tokens: tokens,
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<UsageInfo> for ResponseUsage {
|
|
||||||
fn from(usage: UsageInfo) -> Self {
|
|
||||||
usage.to_response_usage()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ResponseUsage {
|
|
||||||
/// Convert back to standard UsageInfo format
|
|
||||||
pub fn to_usage_info(&self) -> UsageInfo {
|
|
||||||
UsageInfo {
|
|
||||||
prompt_tokens: self.input_tokens,
|
|
||||||
completion_tokens: self.output_tokens,
|
|
||||||
total_tokens: self.total_tokens,
|
|
||||||
reasoning_tokens: self
|
|
||||||
.output_tokens_details
|
|
||||||
.as_ref()
|
|
||||||
.map(|details| details.reasoning_tokens),
|
|
||||||
prompt_tokens_details: self.input_tokens_details.as_ref().map(|details| {
|
|
||||||
PromptTokenUsageInfo {
|
|
||||||
cached_tokens: details.cached_tokens,
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
1867
sgl-router/src/protocols/spec.rs
Normal file
1867
sgl-router/src/protocols/spec.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,11 @@ use anyhow::Result;
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
// Import types from spec module
|
||||||
|
use crate::protocols::spec::{
|
||||||
|
ChatCompletionRequest, ChatMessage, ResponseFormat, StringOrArray, UserMessageContent,
|
||||||
|
};
|
||||||
|
|
||||||
/// Validation constants for OpenAI API parameters
|
/// Validation constants for OpenAI API parameters
|
||||||
pub mod constants {
|
pub mod constants {
|
||||||
/// Temperature range: 0.0 to 2.0 (OpenAI spec)
|
/// Temperature range: 0.0 to 2.0 (OpenAI spec)
|
||||||
@@ -257,7 +262,7 @@ pub mod utils {
|
|||||||
) -> Result<(), ValidationError> {
|
) -> Result<(), ValidationError> {
|
||||||
if let Some(stop) = request.get_stop_sequences() {
|
if let Some(stop) = request.get_stop_sequences() {
|
||||||
match stop {
|
match stop {
|
||||||
crate::protocols::common::StringOrArray::String(s) => {
|
StringOrArray::String(s) => {
|
||||||
if s.is_empty() {
|
if s.is_empty() {
|
||||||
return Err(ValidationError::InvalidValue {
|
return Err(ValidationError::InvalidValue {
|
||||||
parameter: "stop".to_string(),
|
parameter: "stop".to_string(),
|
||||||
@@ -266,7 +271,7 @@ pub mod utils {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
crate::protocols::common::StringOrArray::Array(arr) => {
|
StringOrArray::Array(arr) => {
|
||||||
validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?;
|
validate_max_items(arr, constants::MAX_STOP_SEQUENCES, "stop")?;
|
||||||
for (i, s) in arr.iter().enumerate() {
|
for (i, s) in arr.iter().enumerate() {
|
||||||
if s.is_empty() {
|
if s.is_empty() {
|
||||||
@@ -469,7 +474,7 @@ pub trait SamplingOptionsProvider {
|
|||||||
/// Trait for validating stop conditions
|
/// Trait for validating stop conditions
|
||||||
pub trait StopConditionsProvider {
|
pub trait StopConditionsProvider {
|
||||||
/// Get stop sequences
|
/// Get stop sequences
|
||||||
fn get_stop_sequences(&self) -> Option<&crate::protocols::common::StringOrArray>;
|
fn get_stop_sequences(&self) -> Option<&StringOrArray>;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait for validating token limits
|
/// Trait for validating token limits
|
||||||
@@ -532,28 +537,11 @@ pub trait ValidatableRequest:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
// ==================================================================
|
||||||
mod tests {
|
// = OPENAI CHAT COMPLETION VALIDATION =
|
||||||
use super::constants::*;
|
// ==================================================================
|
||||||
use super::utils::*;
|
|
||||||
use super::*;
|
|
||||||
use crate::protocols::common::StringOrArray;
|
|
||||||
|
|
||||||
// Mock request type for testing validation traits
|
impl SamplingOptionsProvider for ChatCompletionRequest {
|
||||||
#[derive(Debug, Default)]
|
|
||||||
struct MockRequest {
|
|
||||||
temperature: Option<f32>,
|
|
||||||
top_p: Option<f32>,
|
|
||||||
frequency_penalty: Option<f32>,
|
|
||||||
presence_penalty: Option<f32>,
|
|
||||||
stop: Option<StringOrArray>,
|
|
||||||
max_tokens: Option<u32>,
|
|
||||||
min_tokens: Option<u32>,
|
|
||||||
logprobs: Option<u32>,
|
|
||||||
top_logprobs: Option<u32>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SamplingOptionsProvider for MockRequest {
|
|
||||||
fn get_temperature(&self) -> Option<f32> {
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
self.temperature
|
self.temperature
|
||||||
}
|
}
|
||||||
@@ -568,6 +556,235 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Exposes the chat request's `stop` field to the stop-condition validators.
impl StopConditionsProvider for ChatCompletionRequest {
    /// Returns the configured stop sequences, if any.
    fn get_stop_sequences(&self) -> Option<&StringOrArray> {
        self.stop.as_ref()
    }
}
|
||||||
|
|
||||||
|
/// Exposes the chat request's token limits to the token-limit validators.
impl TokenLimitsProvider for ChatCompletionRequest {
    /// Returns the effective maximum token count.
    fn get_max_tokens(&self) -> Option<u32> {
        // Prefer max_completion_tokens over max_tokens if both are set
        self.max_completion_tokens.or(self.max_tokens)
    }

    /// Returns the requested minimum token count, if any.
    fn get_min_tokens(&self) -> Option<u32> {
        self.min_tokens
    }
}
|
||||||
|
|
||||||
|
/// Exposes the chat request's log-probability settings to the validators.
impl LogProbsProvider for ChatCompletionRequest {
    /// Maps the chat API's boolean `logprobs` flag onto the numeric form the
    /// validators expect: `Some(1)` when enabled, `None` otherwise.
    fn get_logprobs(&self) -> Option<u32> {
        self.logprobs.then_some(1)
    }

    /// Returns the requested number of top log probabilities, if any.
    fn get_top_logprobs(&self) -> Option<u32> {
        self.top_logprobs
    }
}
|
||||||
|
|
||||||
|
/// Exposes the SGLang-specific sampling extensions of a chat request.
impl SGLangExtensionsProvider for ChatCompletionRequest {
    /// Returns the requested `top_k`, if any.
    fn get_top_k(&self) -> Option<i32> {
        self.top_k
    }

    /// Returns the requested `min_p`, if any.
    fn get_min_p(&self) -> Option<f32> {
        self.min_p
    }

    /// Returns the requested `repetition_penalty`, if any.
    fn get_repetition_penalty(&self) -> Option<f32> {
        self.repetition_penalty
    }
}
|
||||||
|
|
||||||
|
/// Exposes the chat request's completion count (`n`) to the validators.
impl CompletionCountProvider for ChatCompletionRequest {
    /// Returns the requested number of completions, if any.
    fn get_n(&self) -> Option<u32> {
        self.n
    }
}
|
||||||
|
|
||||||
|
impl ChatCompletionRequest {
|
||||||
|
/// Validate message-specific requirements
|
||||||
|
pub fn validate_messages(&self) -> Result<(), ValidationError> {
|
||||||
|
// Ensure messages array is not empty
|
||||||
|
utils::validate_non_empty_array(&self.messages, "messages")?;
|
||||||
|
|
||||||
|
// Validate message content is not empty
|
||||||
|
for (i, msg) in self.messages.iter().enumerate() {
|
||||||
|
if let ChatMessage::User { content, .. } = msg {
|
||||||
|
match content {
|
||||||
|
UserMessageContent::Text(text) if text.is_empty() => {
|
||||||
|
return Err(ValidationError::InvalidValue {
|
||||||
|
parameter: format!("messages[{}].content", i),
|
||||||
|
value: "empty".to_string(),
|
||||||
|
reason: "message content cannot be empty".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
UserMessageContent::Parts(parts) if parts.is_empty() => {
|
||||||
|
return Err(ValidationError::InvalidValue {
|
||||||
|
parameter: format!("messages[{}].content", i),
|
||||||
|
value: "empty array".to_string(),
|
||||||
|
reason: "message content parts cannot be empty".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate response format if specified
|
||||||
|
pub fn validate_response_format(&self) -> Result<(), ValidationError> {
|
||||||
|
if let Some(ResponseFormat::JsonSchema { json_schema }) = &self.response_format {
|
||||||
|
if json_schema.name.is_empty() {
|
||||||
|
return Err(ValidationError::InvalidValue {
|
||||||
|
parameter: "response_format.json_schema.name".to_string(),
|
||||||
|
value: "empty".to_string(),
|
||||||
|
reason: "JSON schema name cannot be empty".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate chat API specific logprobs requirements
|
||||||
|
pub fn validate_chat_logprobs(&self) -> Result<(), ValidationError> {
|
||||||
|
// In chat API, if logprobs=true, top_logprobs must be specified
|
||||||
|
if self.logprobs && self.top_logprobs.is_none() {
|
||||||
|
return Err(ValidationError::MissingRequired {
|
||||||
|
parameter: "top_logprobs".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// If top_logprobs is specified, logprobs should be true
|
||||||
|
if self.top_logprobs.is_some() && !self.logprobs {
|
||||||
|
return Err(ValidationError::InvalidValue {
|
||||||
|
parameter: "logprobs".to_string(),
|
||||||
|
value: "false".to_string(),
|
||||||
|
reason: "must be true when top_logprobs is specified".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate cross-parameter relationships specific to chat completions
|
||||||
|
pub fn validate_chat_cross_parameters(&self) -> Result<(), ValidationError> {
|
||||||
|
// Validate that both max_tokens and max_completion_tokens aren't set
|
||||||
|
utils::validate_conflicting_parameters(
|
||||||
|
"max_tokens",
|
||||||
|
self.max_tokens.is_some(),
|
||||||
|
"max_completion_tokens",
|
||||||
|
self.max_completion_tokens.is_some(),
|
||||||
|
"cannot specify both max_tokens and max_completion_tokens",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Validate that tools and functions aren't both specified (deprecated)
|
||||||
|
utils::validate_conflicting_parameters(
|
||||||
|
"tools",
|
||||||
|
self.tools.is_some(),
|
||||||
|
"functions",
|
||||||
|
self.functions.is_some(),
|
||||||
|
"functions is deprecated, use tools instead",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Validate structured output constraints don't conflict with JSON response format
|
||||||
|
let has_json_format = matches!(
|
||||||
|
self.response_format,
|
||||||
|
Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. })
|
||||||
|
);
|
||||||
|
|
||||||
|
utils::validate_conflicting_parameters(
|
||||||
|
"response_format",
|
||||||
|
has_json_format,
|
||||||
|
"regex",
|
||||||
|
self.regex.is_some(),
|
||||||
|
"cannot use regex constraint with JSON response format",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
utils::validate_conflicting_parameters(
|
||||||
|
"response_format",
|
||||||
|
has_json_format,
|
||||||
|
"ebnf",
|
||||||
|
self.ebnf.is_some(),
|
||||||
|
"cannot use EBNF constraint with JSON response format",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Only one structured output constraint should be active
|
||||||
|
let structured_constraints = [
|
||||||
|
("regex", self.regex.is_some()),
|
||||||
|
("ebnf", self.ebnf.is_some()),
|
||||||
|
(
|
||||||
|
"json_schema",
|
||||||
|
matches!(
|
||||||
|
self.response_format,
|
||||||
|
Some(ResponseFormat::JsonSchema { .. })
|
||||||
|
),
|
||||||
|
),
|
||||||
|
];
|
||||||
|
|
||||||
|
utils::validate_mutually_exclusive_options(
|
||||||
|
&structured_constraints,
|
||||||
|
"Only one structured output constraint (regex, ebnf, or json_schema) can be active at a time",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ValidatableRequest for ChatCompletionRequest {
|
||||||
|
fn validate(&self) -> Result<(), ValidationError> {
|
||||||
|
// Call the common validation function from the validation module
|
||||||
|
utils::validate_common_request_params(self)?;
|
||||||
|
|
||||||
|
// Then validate chat-specific parameters
|
||||||
|
self.validate_messages()?;
|
||||||
|
self.validate_response_format()?;
|
||||||
|
self.validate_chat_logprobs()?;
|
||||||
|
self.validate_chat_cross_parameters()?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::constants::*;
|
||||||
|
use super::utils::*;
|
||||||
|
use super::*;
|
||||||
|
use crate::protocols::spec::StringOrArray;
|
||||||
|
|
||||||
|
// Mock request type for testing validation traits
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct MockRequest {
|
||||||
|
temperature: Option<f32>,
|
||||||
|
stop: Option<StringOrArray>,
|
||||||
|
max_tokens: Option<u32>,
|
||||||
|
min_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SamplingOptionsProvider for MockRequest {
|
||||||
|
fn get_temperature(&self) -> Option<f32> {
|
||||||
|
self.temperature
|
||||||
|
}
|
||||||
|
fn get_top_p(&self) -> Option<f32> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
fn get_frequency_penalty(&self) -> Option<f32> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
fn get_presence_penalty(&self) -> Option<f32> {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl StopConditionsProvider for MockRequest {
|
impl StopConditionsProvider for MockRequest {
|
||||||
fn get_stop_sequences(&self) -> Option<&StringOrArray> {
|
fn get_stop_sequences(&self) -> Option<&StringOrArray> {
|
||||||
self.stop.as_ref()
|
self.stop.as_ref()
|
||||||
@@ -585,97 +802,36 @@ mod tests {
|
|||||||
|
|
||||||
impl LogProbsProvider for MockRequest {
|
impl LogProbsProvider for MockRequest {
|
||||||
fn get_logprobs(&self) -> Option<u32> {
|
fn get_logprobs(&self) -> Option<u32> {
|
||||||
self.logprobs
|
None
|
||||||
}
|
}
|
||||||
fn get_top_logprobs(&self) -> Option<u32> {
|
fn get_top_logprobs(&self) -> Option<u32> {
|
||||||
self.top_logprobs
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SGLangExtensionsProvider for MockRequest {
|
impl SGLangExtensionsProvider for MockRequest {}
|
||||||
// Default implementations return None, so no custom logic needed
|
impl CompletionCountProvider for MockRequest {}
|
||||||
}
|
|
||||||
|
|
||||||
impl CompletionCountProvider for MockRequest {
|
|
||||||
// Default implementation returns None, so no custom logic needed
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ValidatableRequest for MockRequest {}
|
impl ValidatableRequest for MockRequest {}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_range_valid() {
|
fn test_range_validation() {
|
||||||
let result = validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature");
|
// Valid range
|
||||||
assert!(result.is_ok());
|
assert!(validate_range(1.5f32, &TEMPERATURE_RANGE, "temperature").is_ok());
|
||||||
assert_eq!(result.unwrap(), 1.5f32);
|
// Invalid range
|
||||||
|
assert!(validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature").is_err());
|
||||||
|
assert!(validate_range(3.0f32, &TEMPERATURE_RANGE, "temperature").is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validate_range_too_low() {
|
fn test_sglang_top_k_validation() {
|
||||||
let result = validate_range(-0.1f32, &TEMPERATURE_RANGE, "temperature");
|
|
||||||
assert!(result.is_err());
|
|
||||||
match result.unwrap_err() {
|
|
||||||
ValidationError::OutOfRange { parameter, .. } => {
|
|
||||||
assert_eq!(parameter, "temperature");
|
|
||||||
}
|
|
||||||
_ => panic!("Expected OutOfRange error"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_positive_valid() {
|
|
||||||
let result = validate_positive(5i32, "max_tokens");
|
|
||||||
assert!(result.is_ok());
|
|
||||||
assert_eq!(result.unwrap(), 5i32);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_max_items_valid() {
|
|
||||||
let items = vec!["stop1", "stop2"];
|
|
||||||
let result = validate_max_items(&items, MAX_STOP_SEQUENCES, "stop");
|
|
||||||
assert!(result.is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_validate_top_k() {
|
|
||||||
assert!(validate_top_k(-1).is_ok()); // Disabled
|
assert!(validate_top_k(-1).is_ok()); // Disabled
|
||||||
assert!(validate_top_k(50).is_ok()); // Positive
|
assert!(validate_top_k(50).is_ok()); // Valid positive
|
||||||
assert!(validate_top_k(0).is_err()); // Invalid
|
assert!(validate_top_k(0).is_err()); // Invalid
|
||||||
assert!(validate_top_k(-5).is_err()); // Invalid
|
assert!(validate_top_k(-5).is_err()); // Invalid
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_valid_request() {
|
fn test_stop_sequences_limits() {
|
||||||
let request = MockRequest {
|
|
||||||
temperature: Some(1.0),
|
|
||||||
top_p: Some(0.9),
|
|
||||||
frequency_penalty: Some(0.5),
|
|
||||||
presence_penalty: Some(-0.5),
|
|
||||||
stop: Some(StringOrArray::Array(vec![
|
|
||||||
"stop1".to_string(),
|
|
||||||
"stop2".to_string(),
|
|
||||||
])),
|
|
||||||
max_tokens: Some(100),
|
|
||||||
min_tokens: Some(10),
|
|
||||||
logprobs: Some(3),
|
|
||||||
top_logprobs: Some(15),
|
|
||||||
};
|
|
||||||
|
|
||||||
assert!(request.validate().is_ok());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_invalid_temperature() {
|
|
||||||
let request = MockRequest {
|
|
||||||
temperature: Some(3.0), // Invalid: too high
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
let result = request.validate();
|
|
||||||
assert!(result.is_err());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_too_many_stop_sequences() {
|
|
||||||
let request = MockRequest {
|
let request = MockRequest {
|
||||||
stop: Some(StringOrArray::Array(vec![
|
stop: Some(StringOrArray::Array(vec![
|
||||||
"stop1".to_string(),
|
"stop1".to_string(),
|
||||||
@@ -686,72 +842,322 @@ mod tests {
|
|||||||
])),
|
])),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
assert!(request.validate().is_err());
|
||||||
let result = request.validate();
|
|
||||||
assert!(result.is_err());
|
|
||||||
match result.unwrap_err() {
|
|
||||||
ValidationError::TooManyItems {
|
|
||||||
parameter,
|
|
||||||
count,
|
|
||||||
max,
|
|
||||||
} => {
|
|
||||||
assert_eq!(parameter, "stop");
|
|
||||||
assert_eq!(count, 5);
|
|
||||||
assert_eq!(max, MAX_STOP_SEQUENCES);
|
|
||||||
}
|
|
||||||
_ => panic!("Expected TooManyItems error"),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_conflicting_token_limits() {
|
fn test_token_limits_conflict() {
|
||||||
let request = MockRequest {
|
let request = MockRequest {
|
||||||
min_tokens: Some(100),
|
min_tokens: Some(100),
|
||||||
max_tokens: Some(50), // Invalid: min > max
|
max_tokens: Some(50), // min > max
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
assert!(request.validate().is_err());
|
||||||
let result = request.validate();
|
|
||||||
assert!(result.is_err());
|
|
||||||
match result.unwrap_err() {
|
|
||||||
ValidationError::ConflictingParameters {
|
|
||||||
parameter1,
|
|
||||||
parameter2,
|
|
||||||
..
|
|
||||||
} => {
|
|
||||||
assert_eq!(parameter1, "min_tokens");
|
|
||||||
assert_eq!(parameter2, "max_tokens");
|
|
||||||
}
|
}
|
||||||
_ => panic!("Expected ConflictingParameters error"),
|
|
||||||
|
#[test]
|
||||||
|
fn test_valid_request() {
|
||||||
|
let request = MockRequest {
|
||||||
|
temperature: Some(1.0),
|
||||||
|
stop: Some(StringOrArray::Array(vec!["stop".to_string()])),
|
||||||
|
max_tokens: Some(100),
|
||||||
|
min_tokens: Some(10),
|
||||||
|
};
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chat completion specific tests
|
||||||
|
#[cfg(test)]
|
||||||
|
mod chat_tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn create_valid_chat_request() -> ChatCompletionRequest {
|
||||||
|
ChatCompletionRequest {
|
||||||
|
model: "gpt-4".to_string(),
|
||||||
|
messages: vec![ChatMessage::User {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: UserMessageContent::Text("Hello".to_string()),
|
||||||
|
name: None,
|
||||||
|
}],
|
||||||
|
temperature: Some(1.0),
|
||||||
|
top_p: Some(0.9),
|
||||||
|
n: Some(1),
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
stop: None,
|
||||||
|
max_tokens: Some(100),
|
||||||
|
max_completion_tokens: None,
|
||||||
|
presence_penalty: Some(0.0),
|
||||||
|
frequency_penalty: Some(0.0),
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
logprobs: false,
|
||||||
|
top_logprobs: None,
|
||||||
|
response_format: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
parallel_tool_calls: None,
|
||||||
|
functions: None,
|
||||||
|
function_call: None,
|
||||||
|
// SGLang extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
continue_final_message: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
separate_reasoning: true,
|
||||||
|
stream_reasoning: true,
|
||||||
|
return_hidden_states: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_boundary_values() {
|
fn test_chat_validation_basics() {
|
||||||
let request = MockRequest {
|
// Valid request
|
||||||
temperature: Some(0.0), // Boundary: minimum
|
assert!(create_valid_chat_request().validate().is_ok());
|
||||||
top_p: Some(1.0), // Boundary: maximum
|
|
||||||
frequency_penalty: Some(-2.0), // Boundary: minimum
|
|
||||||
presence_penalty: Some(2.0), // Boundary: maximum
|
|
||||||
logprobs: Some(0), // Boundary: minimum
|
|
||||||
top_logprobs: Some(20), // Boundary: maximum
|
|
||||||
..Default::default()
|
|
||||||
};
|
|
||||||
|
|
||||||
|
// Empty messages
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
request.messages = vec![];
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Invalid temperature
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
request.temperature = Some(3.0);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_conflicts() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// Conflicting max_tokens
|
||||||
|
request.max_tokens = Some(100);
|
||||||
|
request.max_completion_tokens = Some(200);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Logprobs without top_logprobs
|
||||||
|
request.max_tokens = None;
|
||||||
|
request.logprobs = true;
|
||||||
|
request.top_logprobs = None;
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sglang_extensions() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// Valid SGLang parameters
|
||||||
|
request.top_k = Some(-1);
|
||||||
|
request.min_p = Some(0.1);
|
||||||
|
request.repetition_penalty = Some(1.2);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
// Invalid parameters
|
||||||
|
request.top_k = Some(0); // Invalid
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parameter_ranges() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// Test temperature range (0.0 to 2.0)
|
||||||
|
request.temperature = Some(1.5);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.temperature = Some(-0.1);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.temperature = Some(3.0);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Test top_p range (0.0 to 1.0)
|
||||||
|
request.temperature = Some(1.0); // Reset
|
||||||
|
request.top_p = Some(0.9);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.top_p = Some(-0.1);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.top_p = Some(1.5);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Test frequency_penalty range (-2.0 to 2.0)
|
||||||
|
request.top_p = Some(0.9); // Reset
|
||||||
|
request.frequency_penalty = Some(1.5);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.frequency_penalty = Some(-2.5);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.frequency_penalty = Some(3.0);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Test presence_penalty range (-2.0 to 2.0)
|
||||||
|
request.frequency_penalty = Some(0.0); // Reset
|
||||||
|
request.presence_penalty = Some(-1.5);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.presence_penalty = Some(-3.0);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.presence_penalty = Some(2.5);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Test repetition_penalty range (0.0 to 2.0)
|
||||||
|
request.presence_penalty = Some(0.0); // Reset
|
||||||
|
request.repetition_penalty = Some(1.2);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.repetition_penalty = Some(-0.1);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.repetition_penalty = Some(2.1);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Test min_p range (0.0 to 1.0)
|
||||||
|
request.repetition_penalty = Some(1.0); // Reset
|
||||||
|
request.min_p = Some(0.5);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.min_p = Some(-0.1);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.min_p = Some(1.5);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_structured_output_conflicts() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// JSON response format with regex should conflict
|
||||||
|
request.response_format = Some(ResponseFormat::JsonObject);
|
||||||
|
request.regex = Some(".*".to_string());
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// JSON response format with EBNF should conflict
|
||||||
|
request.regex = None;
|
||||||
|
request.ebnf = Some("grammar".to_string());
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Multiple structured constraints should conflict
|
||||||
|
request.response_format = None;
|
||||||
|
request.regex = Some(".*".to_string());
|
||||||
|
request.ebnf = Some("grammar".to_string());
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Only one constraint should work
|
||||||
|
request.ebnf = None;
|
||||||
|
request.regex = Some(".*".to_string());
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
request.regex = None;
|
||||||
|
request.ebnf = Some("grammar".to_string());
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
request.ebnf = None;
|
||||||
|
request.response_format = Some(ResponseFormat::JsonObject);
|
||||||
assert!(request.validate().is_ok());
|
assert!(request.validate().is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_validation_error_display() {
|
fn test_stop_sequences_validation() {
|
||||||
let error = ValidationError::OutOfRange {
|
let mut request = create_valid_chat_request();
|
||||||
parameter: "temperature".to_string(),
|
|
||||||
value: "3.0".to_string(),
|
|
||||||
min: "0.0".to_string(),
|
|
||||||
max: "2.0".to_string(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let message = format!("{}", error);
|
// Valid stop sequences
|
||||||
assert!(message.contains("temperature"));
|
request.stop = Some(StringOrArray::Array(vec![
|
||||||
assert!(message.contains("3.0"));
|
"stop1".to_string(),
|
||||||
|
"stop2".to_string(),
|
||||||
|
]));
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
// Too many stop sequences (max 4)
|
||||||
|
request.stop = Some(StringOrArray::Array(vec![
|
||||||
|
"stop1".to_string(),
|
||||||
|
"stop2".to_string(),
|
||||||
|
"stop3".to_string(),
|
||||||
|
"stop4".to_string(),
|
||||||
|
"stop5".to_string(),
|
||||||
|
]));
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Empty stop sequence should fail
|
||||||
|
request.stop = Some(StringOrArray::String("".to_string()));
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Empty string in array should fail
|
||||||
|
request.stop = Some(StringOrArray::Array(vec![
|
||||||
|
"stop1".to_string(),
|
||||||
|
"".to_string(),
|
||||||
|
]));
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_logprobs_validation() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// Valid logprobs configuration
|
||||||
|
request.logprobs = true;
|
||||||
|
request.top_logprobs = Some(10);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
// logprobs=true without top_logprobs should fail
|
||||||
|
request.top_logprobs = None;
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// top_logprobs without logprobs=true should fail
|
||||||
|
request.logprobs = false;
|
||||||
|
request.top_logprobs = Some(10);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// top_logprobs out of range (0-20)
|
||||||
|
request.logprobs = true;
|
||||||
|
request.top_logprobs = Some(25);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_n_parameter_validation() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// Valid n values (1-10)
|
||||||
|
request.n = Some(1);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.n = Some(5);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
request.n = Some(10);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
// Invalid n values
|
||||||
|
request.n = Some(0);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
request.n = Some(15);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_min_max_tokens_validation() {
|
||||||
|
let mut request = create_valid_chat_request();
|
||||||
|
|
||||||
|
// Valid token limits
|
||||||
|
request.min_tokens = Some(10);
|
||||||
|
request.max_tokens = Some(100);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
// min_tokens > max_tokens should fail
|
||||||
|
request.min_tokens = Some(150);
|
||||||
|
request.max_tokens = Some(100);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
|
||||||
|
// Should work with max_completion_tokens instead
|
||||||
|
request.max_tokens = None;
|
||||||
|
request.max_completion_tokens = Some(200);
|
||||||
|
request.min_tokens = Some(50);
|
||||||
|
assert!(request.validate().is_ok());
|
||||||
|
|
||||||
|
// min_tokens > max_completion_tokens should fail
|
||||||
|
request.min_tokens = Some(250);
|
||||||
|
assert!(request.validate().is_err());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,10 +9,7 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::protocols::{
|
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
generate::GenerateRequest,
|
|
||||||
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
|
|
||||||
};
|
|
||||||
|
|
||||||
pub mod factory;
|
pub mod factory;
|
||||||
pub mod header_utils;
|
pub mod header_utils;
|
||||||
|
|||||||
@@ -12,13 +12,9 @@ use crate::core::{
|
|||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
use crate::protocols::{
|
use crate::protocols::spec::{
|
||||||
common::StringOrArray,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, StringOrArray,
|
||||||
generate::GenerateRequest,
|
UserMessageContent,
|
||||||
openai::{
|
|
||||||
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
|
||||||
completions::CompletionRequest,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
|||||||
@@ -9,10 +9,8 @@ use crate::core::{
|
|||||||
};
|
};
|
||||||
use crate::metrics::RouterMetrics;
|
use crate::metrics::RouterMetrics;
|
||||||
use crate::policies::LoadBalancingPolicy;
|
use crate::policies::LoadBalancingPolicy;
|
||||||
use crate::protocols::{
|
use crate::protocols::spec::{
|
||||||
common::GenerationRequest,
|
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest,
|
||||||
generate::GenerateRequest,
|
|
||||||
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
|
|
||||||
};
|
};
|
||||||
use crate::routers::{RouterTrait, WorkerManagement};
|
use crate::routers::{RouterTrait, WorkerManagement};
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
use crate::config::RouterConfig;
|
use crate::config::RouterConfig;
|
||||||
use crate::logging::{self, LoggingConfig};
|
use crate::logging::{self, LoggingConfig};
|
||||||
use crate::metrics::{self, PrometheusConfig};
|
use crate::metrics::{self, PrometheusConfig};
|
||||||
use crate::protocols::{
|
use crate::protocols::spec::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
generate::GenerateRequest,
|
|
||||||
openai::{chat::ChatCompletionRequest, completions::CompletionRequest},
|
|
||||||
};
|
|
||||||
use crate::routers::{RouterFactory, RouterTrait};
|
use crate::routers::{RouterFactory, RouterTrait};
|
||||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|||||||
@@ -5,13 +5,9 @@
|
|||||||
|
|
||||||
use serde_json::{from_str, to_string, to_value};
|
use serde_json::{from_str, to_string, to_value};
|
||||||
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
use sglang_router_rs::core::{BasicWorker, WorkerType};
|
||||||
use sglang_router_rs::protocols::{
|
use sglang_router_rs::protocols::spec::{
|
||||||
common::StringOrArray,
|
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest,
|
||||||
generate::{GenerateParameters, GenerateRequest, SamplingParams},
|
SamplingParams, StringOrArray, UserMessageContent,
|
||||||
openai::{
|
|
||||||
chat::{ChatCompletionRequest, ChatMessage, UserMessageContent},
|
|
||||||
completions::CompletionRequest,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
// Integration test for Responses API
|
// Integration test for Responses API
|
||||||
|
|
||||||
use sglang_router_rs::protocols::common::GenerationRequest;
|
use sglang_router_rs::protocols::spec::{
|
||||||
use sglang_router_rs::protocols::openai::responses::request::ResponseInput;
|
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus,
|
||||||
use sglang_router_rs::protocols::openai::responses::*;
|
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice,
|
||||||
|
ToolChoiceValue, Truncation, UsageInfo,
|
||||||
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_responses_request_creation() {
|
fn test_responses_request_creation() {
|
||||||
@@ -24,7 +26,7 @@ fn test_responses_request_creation() {
|
|||||||
store: true,
|
store: true,
|
||||||
stream: false,
|
stream: false,
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
tool_choice: ToolChoice::Auto,
|
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||||
tools: vec![ResponseTool {
|
tools: vec![ResponseTool {
|
||||||
r#type: ResponseToolType::WebSearchPreview,
|
r#type: ResponseToolType::WebSearchPreview,
|
||||||
}],
|
}],
|
||||||
@@ -67,7 +69,7 @@ fn test_sampling_params_conversion() {
|
|||||||
store: true, // Use default true
|
store: true, // Use default true
|
||||||
stream: false,
|
stream: false,
|
||||||
temperature: Some(0.8),
|
temperature: Some(0.8),
|
||||||
tool_choice: ToolChoice::Auto,
|
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
|
||||||
tools: vec![],
|
tools: vec![],
|
||||||
top_logprobs: 0, // Use default 0
|
top_logprobs: 0, // Use default 0
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
@@ -177,7 +179,7 @@ fn test_json_serialization() {
|
|||||||
store: false,
|
store: false,
|
||||||
stream: true,
|
stream: true,
|
||||||
temperature: Some(0.9),
|
temperature: Some(0.9),
|
||||||
tool_choice: ToolChoice::Required,
|
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
|
||||||
tools: vec![ResponseTool {
|
tools: vec![ResponseTool {
|
||||||
r#type: ResponseToolType::CodeInterpreter,
|
r#type: ResponseToolType::CodeInterpreter,
|
||||||
}],
|
}],
|
||||||
|
|||||||
Reference in New Issue
Block a user