[router] complete router oai spec (#8828)
This commit is contained in:
@@ -8,12 +8,116 @@ use sglang_router_rs::openai_api_types::{
|
|||||||
};
|
};
|
||||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
||||||
|
|
||||||
|
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||||
|
fn default_generate_request() -> GenerateRequest {
|
||||||
|
GenerateRequest {
|
||||||
|
text: None,
|
||||||
|
prompt: None,
|
||||||
|
input_ids: None,
|
||||||
|
stream: false,
|
||||||
|
parameters: None,
|
||||||
|
sampling_params: None,
|
||||||
|
return_logprob: false,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
rid: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
|
||||||
|
fn default_chat_completion_request() -> ChatCompletionRequest {
|
||||||
|
ChatCompletionRequest {
|
||||||
|
model: String::new(),
|
||||||
|
messages: vec![],
|
||||||
|
max_tokens: None,
|
||||||
|
max_completion_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
logit_bias: None,
|
||||||
|
logprobs: false,
|
||||||
|
top_logprobs: None,
|
||||||
|
user: None,
|
||||||
|
response_format: None,
|
||||||
|
seed: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
parallel_tool_calls: None,
|
||||||
|
function_call: None,
|
||||||
|
functions: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
continue_final_message: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
separate_reasoning: true,
|
||||||
|
stream_reasoning: true,
|
||||||
|
return_hidden_states: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a default CompletionRequest for benchmarks with minimal fields set
|
||||||
|
fn default_completion_request() -> CompletionRequest {
|
||||||
|
CompletionRequest {
|
||||||
|
model: String::new(),
|
||||||
|
prompt: StringOrArray::String(String::new()),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
json_schema: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Sample request data for benchmarks
|
// Sample request data for benchmarks
|
||||||
fn create_sample_generate_request() -> GenerateRequest {
|
fn create_sample_generate_request() -> GenerateRequest {
|
||||||
GenerateRequest {
|
GenerateRequest {
|
||||||
text: Some("Write a story about artificial intelligence".to_string()),
|
text: Some("Write a story about artificial intelligence".to_string()),
|
||||||
input_ids: None,
|
|
||||||
prompt: None,
|
|
||||||
parameters: Some(GenerateParameters {
|
parameters: Some(GenerateParameters {
|
||||||
max_new_tokens: Some(100),
|
max_new_tokens: Some(100),
|
||||||
temperature: Some(0.8),
|
temperature: Some(0.8),
|
||||||
@@ -31,8 +135,7 @@ fn create_sample_generate_request() -> GenerateRequest {
|
|||||||
repetition_penalty: Some(1.0),
|
repetition_penalty: Some(1.0),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
stream: false,
|
..default_generate_request()
|
||||||
return_logprob: false,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,22 +161,10 @@ fn create_sample_chat_completion_request() -> ChatCompletionRequest {
|
|||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.0),
|
presence_penalty: Some(0.0),
|
||||||
frequency_penalty: Some(0.0),
|
frequency_penalty: Some(0.0),
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: None,
|
|
||||||
user: None,
|
|
||||||
response_format: None,
|
|
||||||
seed: None,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
function_call: None,
|
..default_chat_completion_request()
|
||||||
functions: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -81,23 +172,14 @@ fn create_sample_completion_request() -> CompletionRequest {
|
|||||||
CompletionRequest {
|
CompletionRequest {
|
||||||
model: "text-davinci-003".to_string(),
|
model: "text-davinci-003".to_string(),
|
||||||
prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
|
prompt: StringOrArray::String("Complete this sentence: The future of AI is".to_string()),
|
||||||
suffix: None,
|
|
||||||
max_tokens: Some(50),
|
max_tokens: Some(50),
|
||||||
temperature: Some(0.8),
|
temperature: Some(0.8),
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.0),
|
presence_penalty: Some(0.0),
|
||||||
frequency_penalty: Some(0.0),
|
frequency_penalty: Some(0.0),
|
||||||
best_of: Some(1),
|
best_of: Some(1),
|
||||||
logit_bias: None,
|
..default_completion_request()
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
other: serde_json::Map::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -121,6 +203,7 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
|
|||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
function_call: None,
|
||||||
|
reasoning_content: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -132,22 +215,13 @@ fn create_large_chat_completion_request() -> ChatCompletionRequest {
|
|||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
top_p: Some(0.95),
|
top_p: Some(0.95),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.1),
|
presence_penalty: Some(0.1),
|
||||||
frequency_penalty: Some(0.1),
|
frequency_penalty: Some(0.1),
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: Some(5),
|
top_logprobs: Some(5),
|
||||||
user: Some("benchmark_user".to_string()),
|
user: Some("benchmark_user".to_string()),
|
||||||
response_format: None,
|
|
||||||
seed: Some(42),
|
seed: Some(42),
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
function_call: None,
|
..default_chat_completion_request()
|
||||||
functions: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -331,32 +405,17 @@ fn bench_throughput_by_size(c: &mut Criterion) {
|
|||||||
// Create requests of different sizes
|
// Create requests of different sizes
|
||||||
let small_generate = GenerateRequest {
|
let small_generate = GenerateRequest {
|
||||||
text: Some("Hi".to_string()),
|
text: Some("Hi".to_string()),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let medium_generate = GenerateRequest {
|
let medium_generate = GenerateRequest {
|
||||||
text: Some("Write a medium length story about AI".repeat(10)),
|
text: Some("Write a medium length story about AI".repeat(10)),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let large_generate = GenerateRequest {
|
let large_generate = GenerateRequest {
|
||||||
text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
|
text: Some("Write a very long and detailed story about artificial intelligence and its impact on society".repeat(100)),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
for (name, req) in [
|
for (name, req) in [
|
||||||
|
|||||||
@@ -6,6 +6,21 @@ use serde::{Deserialize, Serialize};
|
|||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Helper function for serde default value
|
||||||
|
fn default_true() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= SGLang-Specific Types =============
|
||||||
|
|
||||||
|
/// LoRA adapter path - can be single path or batch of paths
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum LoRAPath {
|
||||||
|
Single(Option<String>),
|
||||||
|
Batch(Vec<Option<String>>),
|
||||||
|
}
|
||||||
|
|
||||||
/// Common trait for all generation requests
|
/// Common trait for all generation requests
|
||||||
pub trait GenerationRequest: Send + Sync {
|
pub trait GenerationRequest: Send + Sync {
|
||||||
/// Check if the request is for streaming
|
/// Check if the request is for streaming
|
||||||
@@ -92,6 +107,64 @@ pub struct CompletionRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub seed: Option<i64>,
|
pub seed: Option<i64>,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Top-k sampling parameter (-1 to disable)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
|
||||||
|
/// Min-p nucleus sampling parameter
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Minimum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Repetition penalty for reducing repetitive text
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Regex constraint for output generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub regex: Option<String>,
|
||||||
|
|
||||||
|
/// EBNF grammar constraint for structured output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ebnf: Option<String>,
|
||||||
|
|
||||||
|
/// JSON schema constraint for structured output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<String>,
|
||||||
|
|
||||||
|
/// Specific token IDs to use as stop conditions
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
|
|
||||||
|
/// Skip trimming stop tokens from output
|
||||||
|
#[serde(default)]
|
||||||
|
pub no_stop_trim: bool,
|
||||||
|
|
||||||
|
/// Ignore end-of-sequence tokens during generation
|
||||||
|
#[serde(default)]
|
||||||
|
pub ignore_eos: bool,
|
||||||
|
|
||||||
|
/// Skip special tokens during detokenization
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub skip_special_tokens: bool,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Path to LoRA adapter(s) for model customization
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub lora_path: Option<LoRAPath>,
|
||||||
|
|
||||||
|
/// Session parameters for continual prompting
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||||
|
|
||||||
|
/// Return model hidden states
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_hidden_states: bool,
|
||||||
|
|
||||||
/// Additional fields including bootstrap info for PD routing
|
/// Additional fields including bootstrap info for PD routing
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub other: serde_json::Map<String, serde_json::Value>,
|
pub other: serde_json::Map<String, serde_json::Value>,
|
||||||
@@ -166,7 +239,7 @@ pub struct ChatCompletionRequest {
|
|||||||
|
|
||||||
/// Modify the likelihood of specified tokens appearing in the completion
|
/// Modify the likelihood of specified tokens appearing in the completion
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub logit_bias: Option<HashMap<String, i32>>,
|
pub logit_bias: Option<HashMap<String, f32>>,
|
||||||
|
|
||||||
/// A unique identifier representing your end-user
|
/// A unique identifier representing your end-user
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
@@ -207,6 +280,72 @@ pub struct ChatCompletionRequest {
|
|||||||
/// Deprecated: use tool_choice instead
|
/// Deprecated: use tool_choice instead
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub function_call: Option<FunctionCall>,
|
pub function_call: Option<FunctionCall>,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Top-k sampling parameter (-1 to disable)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
|
||||||
|
/// Min-p nucleus sampling parameter
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
|
||||||
|
/// Minimum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Repetition penalty for reducing repetitive text
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Regex constraint for output generation
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub regex: Option<String>,
|
||||||
|
|
||||||
|
/// EBNF grammar constraint for structured output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ebnf: Option<String>,
|
||||||
|
|
||||||
|
/// Specific token IDs to use as stop conditions
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
|
|
||||||
|
/// Skip trimming stop tokens from output
|
||||||
|
#[serde(default)]
|
||||||
|
pub no_stop_trim: bool,
|
||||||
|
|
||||||
|
/// Ignore end-of-sequence tokens during generation
|
||||||
|
#[serde(default)]
|
||||||
|
pub ignore_eos: bool,
|
||||||
|
|
||||||
|
/// Continue generating from final assistant message
|
||||||
|
#[serde(default)]
|
||||||
|
pub continue_final_message: bool,
|
||||||
|
|
||||||
|
/// Skip special tokens during detokenization
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub skip_special_tokens: bool,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Path to LoRA adapter(s) for model customization
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub lora_path: Option<LoRAPath>,
|
||||||
|
|
||||||
|
/// Session parameters for continual prompting
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||||
|
|
||||||
|
/// Separate reasoning content from final answer (O1-style models)
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub separate_reasoning: bool,
|
||||||
|
|
||||||
|
/// Stream reasoning tokens during generation
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub stream_reasoning: bool,
|
||||||
|
|
||||||
|
/// Return model hidden states
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_hidden_states: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -234,6 +373,9 @@ pub enum ChatMessage {
|
|||||||
tool_calls: Option<Vec<ToolCall>>,
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
function_call: Option<FunctionCallResponse>,
|
function_call: Option<FunctionCallResponse>,
|
||||||
|
/// Reasoning content for O1-style models (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
},
|
},
|
||||||
Tool {
|
Tool {
|
||||||
role: String, // "tool"
|
role: String, // "tool"
|
||||||
@@ -378,7 +520,20 @@ impl GenerationRequest for ChatCompletionRequest {
|
|||||||
Some(texts.join(" "))
|
Some(texts.join(" "))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
ChatMessage::Assistant { content, .. } => content.clone(),
|
ChatMessage::Assistant {
|
||||||
|
content,
|
||||||
|
reasoning_content,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
// Combine content and reasoning content for routing decisions
|
||||||
|
let main_content = content.clone().unwrap_or_default();
|
||||||
|
let reasoning = reasoning_content.clone().unwrap_or_default();
|
||||||
|
if main_content.is_empty() && reasoning.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(format!("{} {}", main_content, reasoning).trim().to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
||||||
ChatMessage::Function { content, .. } => Some(content.clone()),
|
ChatMessage::Function { content, .. } => Some(content.clone()),
|
||||||
})
|
})
|
||||||
@@ -418,6 +573,23 @@ pub struct GenerateRequest {
|
|||||||
/// Whether to return logprobs
|
/// Whether to return logprobs
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub return_logprob: bool,
|
pub return_logprob: bool,
|
||||||
|
|
||||||
|
// ============= SGLang Extensions =============
|
||||||
|
/// Path to LoRA adapter(s) for model customization
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub lora_path: Option<LoRAPath>,
|
||||||
|
|
||||||
|
/// Session parameters for continual prompting
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub session_params: Option<HashMap<String, serde_json::Value>>,
|
||||||
|
|
||||||
|
/// Return model hidden states
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_hidden_states: bool,
|
||||||
|
|
||||||
|
/// Request ID for tracking
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub rid: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -485,6 +657,18 @@ pub struct SamplingParams {
|
|||||||
pub skip_special_tokens: Option<bool>,
|
pub skip_special_tokens: Option<bool>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub json_schema: Option<String>,
|
pub json_schema: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub regex: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ebnf: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop_token_ids: Option<Vec<i32>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub no_stop_trim: Option<bool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl GenerationRequest for GenerateRequest {
|
impl GenerationRequest for GenerateRequest {
|
||||||
@@ -561,6 +745,12 @@ pub struct CompletionChoice {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub logprobs: Option<LogProbs>,
|
pub logprobs: Option<LogProbs>,
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
||||||
|
/// Information about which stop condition was matched
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
||||||
|
/// Hidden states from the model (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hidden_states: Option<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -591,6 +781,12 @@ pub struct ChatChoice {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub logprobs: Option<ChatLogProbs>,
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
||||||
|
/// Information about which stop condition was matched
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub matched_stop: Option<serde_json::Value>, // Can be string or integer
|
||||||
|
/// Hidden states from the model (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hidden_states: Option<Vec<f32>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
@@ -681,6 +877,9 @@ pub struct ChatMessageDelta {
|
|||||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub function_call: Option<FunctionCallDelta>,
|
pub function_call: Option<FunctionCallDelta>,
|
||||||
|
/// Reasoning content delta for O1-style models (SGLang extension)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_content: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
|||||||
@@ -278,11 +278,11 @@ mod bootstrap_tests {
|
|||||||
use crate::core::BasicWorker;
|
use crate::core::BasicWorker;
|
||||||
use crate::openai_api_types::StringOrArray;
|
use crate::openai_api_types::StringOrArray;
|
||||||
|
|
||||||
#[test]
|
/// Create a default CompletionRequest for testing with minimal fields set
|
||||||
fn test_completion_batch_size_with_array_prompt() {
|
fn default_completion_request() -> CompletionRequest {
|
||||||
let req = CompletionRequest {
|
CompletionRequest {
|
||||||
model: "test".to_string(),
|
model: String::new(),
|
||||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
prompt: StringOrArray::String(String::new()),
|
||||||
n: None,
|
n: None,
|
||||||
other: serde_json::Map::new(),
|
other: serde_json::Map::new(),
|
||||||
suffix: None,
|
suffix: None,
|
||||||
@@ -300,6 +300,31 @@ mod bootstrap_tests {
|
|||||||
logit_bias: None,
|
logit_bias: None,
|
||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
json_schema: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_batch_size_with_array_prompt() {
|
||||||
|
let req = CompletionRequest {
|
||||||
|
model: "test".to_string(),
|
||||||
|
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||||
|
..default_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Should return batch size for array prompt
|
// Should return batch size for array prompt
|
||||||
@@ -311,23 +336,7 @@ mod bootstrap_tests {
|
|||||||
let req = CompletionRequest {
|
let req = CompletionRequest {
|
||||||
model: "test".to_string(),
|
model: "test".to_string(),
|
||||||
prompt: StringOrArray::String("single prompt".to_string()),
|
prompt: StringOrArray::String("single prompt".to_string()),
|
||||||
n: None,
|
..default_completion_request()
|
||||||
other: serde_json::Map::new(),
|
|
||||||
suffix: None,
|
|
||||||
max_tokens: None,
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Should return None for single prompt
|
// Should return None for single prompt
|
||||||
@@ -340,22 +349,7 @@ mod bootstrap_tests {
|
|||||||
model: "test".to_string(),
|
model: "test".to_string(),
|
||||||
prompt: StringOrArray::String("single prompt".to_string()),
|
prompt: StringOrArray::String("single prompt".to_string()),
|
||||||
n: Some(3),
|
n: Some(3),
|
||||||
other: serde_json::Map::new(),
|
..default_completion_request()
|
||||||
suffix: None,
|
|
||||||
max_tokens: None,
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Should return None for single string prompt, even with n > 1
|
// Should return None for single string prompt, even with n > 1
|
||||||
@@ -368,23 +362,7 @@ mod bootstrap_tests {
|
|||||||
let mut req = CompletionRequest {
|
let mut req = CompletionRequest {
|
||||||
model: "test".to_string(),
|
model: "test".to_string(),
|
||||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||||
n: None,
|
..default_completion_request()
|
||||||
other: serde_json::Map::new(),
|
|
||||||
suffix: None,
|
|
||||||
max_tokens: None,
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Set bootstrap info - should always use single values
|
// Set bootstrap info - should always use single values
|
||||||
@@ -418,23 +396,7 @@ mod bootstrap_tests {
|
|||||||
let mut req = CompletionRequest {
|
let mut req = CompletionRequest {
|
||||||
model: "test".to_string(),
|
model: "test".to_string(),
|
||||||
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]),
|
||||||
n: None,
|
..default_completion_request()
|
||||||
other: serde_json::Map::new(),
|
|
||||||
suffix: None,
|
|
||||||
max_tokens: None,
|
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Set bootstrap info with arrays
|
// Set bootstrap info with arrays
|
||||||
|
|||||||
@@ -176,6 +176,33 @@ impl ToPdRequest for CompletionRequest {
|
|||||||
self.stream => "stream"
|
self.stream => "stream"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Add SGLang extension fields
|
||||||
|
insert_if_some!(other,
|
||||||
|
// SGLang Extensions - Priority 1
|
||||||
|
self.top_k => "top_k",
|
||||||
|
self.min_p => "min_p",
|
||||||
|
self.min_tokens => "min_tokens",
|
||||||
|
self.repetition_penalty => "repetition_penalty",
|
||||||
|
self.regex => "regex",
|
||||||
|
self.ebnf => "ebnf",
|
||||||
|
self.stop_token_ids => "stop_token_ids",
|
||||||
|
// SGLang Extensions - Priority 2
|
||||||
|
self.lora_path => "lora_path",
|
||||||
|
self.session_params => "session_params"
|
||||||
|
);
|
||||||
|
|
||||||
|
// SGLang boolean extensions (CompletionRequest has these as bool, not Option<bool>)
|
||||||
|
other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
|
||||||
|
other.insert("ignore_eos".to_string(), self.ignore_eos.into());
|
||||||
|
other.insert(
|
||||||
|
"skip_special_tokens".to_string(),
|
||||||
|
self.skip_special_tokens.into(),
|
||||||
|
);
|
||||||
|
other.insert(
|
||||||
|
"return_hidden_states".to_string(),
|
||||||
|
self.return_hidden_states.into(),
|
||||||
|
);
|
||||||
|
|
||||||
GenerateReqInput {
|
GenerateReqInput {
|
||||||
text,
|
text,
|
||||||
input_ids: None,
|
input_ids: None,
|
||||||
@@ -226,14 +253,46 @@ impl ToPdRequest for ChatCompletionRequest {
|
|||||||
self.tool_choice => "tool_choice",
|
self.tool_choice => "tool_choice",
|
||||||
self.parallel_tool_calls => "parallel_tool_calls",
|
self.parallel_tool_calls => "parallel_tool_calls",
|
||||||
self.functions => "functions",
|
self.functions => "functions",
|
||||||
self.function_call => "function_call"
|
self.function_call => "function_call",
|
||||||
|
// SGLang Extensions - Priority 1
|
||||||
|
self.top_k => "top_k",
|
||||||
|
self.min_p => "min_p",
|
||||||
|
self.min_tokens => "min_tokens",
|
||||||
|
self.repetition_penalty => "repetition_penalty",
|
||||||
|
self.regex => "regex",
|
||||||
|
self.ebnf => "ebnf",
|
||||||
|
self.stop_token_ids => "stop_token_ids",
|
||||||
|
// SGLang Extensions - Priority 2
|
||||||
|
self.lora_path => "lora_path",
|
||||||
|
self.session_params => "session_params"
|
||||||
);
|
);
|
||||||
|
|
||||||
// Handle boolean logprobs flag
|
// Handle boolean flags
|
||||||
if self.logprobs {
|
if self.logprobs {
|
||||||
other.insert("logprobs".to_string(), true.into());
|
other.insert("logprobs".to_string(), true.into());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
|
||||||
|
other.insert("no_stop_trim".to_string(), self.no_stop_trim.into());
|
||||||
|
other.insert("ignore_eos".to_string(), self.ignore_eos.into());
|
||||||
|
other.insert(
|
||||||
|
"continue_final_message".to_string(),
|
||||||
|
self.continue_final_message.into(),
|
||||||
|
);
|
||||||
|
other.insert(
|
||||||
|
"skip_special_tokens".to_string(),
|
||||||
|
self.skip_special_tokens.into(),
|
||||||
|
);
|
||||||
|
other.insert(
|
||||||
|
"separate_reasoning".to_string(),
|
||||||
|
self.separate_reasoning.into(),
|
||||||
|
);
|
||||||
|
other.insert("stream_reasoning".to_string(), self.stream_reasoning.into());
|
||||||
|
other.insert(
|
||||||
|
"return_hidden_states".to_string(),
|
||||||
|
self.return_hidden_states.into(),
|
||||||
|
);
|
||||||
|
|
||||||
ChatReqInput {
|
ChatReqInput {
|
||||||
stream: self.stream,
|
stream: self.stream,
|
||||||
bootstrap_host: None,
|
bootstrap_host: None,
|
||||||
@@ -271,18 +330,136 @@ mod tests {
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
// ============= GenerateRequest to_pd_request Tests =============
|
// ============= Test Helper Functions =============
|
||||||
|
//
|
||||||
|
// These helper functions create default request instances with all required SGLang extension fields
|
||||||
|
// properly initialized. Use the struct spread operator `..default_*_request()` to override only
|
||||||
|
// the fields you need for specific tests, avoiding repetitive boilerplate code.
|
||||||
|
//
|
||||||
|
// Example usage:
|
||||||
|
// let req = GenerateRequest {
|
||||||
|
// text: Some("Custom text".to_string()),
|
||||||
|
// stream: true,
|
||||||
|
// ..default_generate_request()
|
||||||
|
// };
|
||||||
|
|
||||||
#[test]
|
/// Create a default GenerateRequest with minimal fields set
|
||||||
fn test_generate_to_pd_request_with_text_only() {
|
fn default_generate_request() -> GenerateRequest {
|
||||||
let req = GenerateRequest {
|
GenerateRequest {
|
||||||
text: Some("Hello world".to_string()),
|
text: None,
|
||||||
prompt: None,
|
prompt: None,
|
||||||
input_ids: None,
|
input_ids: None,
|
||||||
stream: false,
|
stream: false,
|
||||||
parameters: None,
|
parameters: None,
|
||||||
sampling_params: None,
|
sampling_params: None,
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
rid: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a default CompletionRequest with minimal fields set
|
||||||
|
fn default_completion_request() -> CompletionRequest {
|
||||||
|
CompletionRequest {
|
||||||
|
model: "test-model".to_string(),
|
||||||
|
prompt: StringOrArray::String("test prompt".to_string()),
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
suffix: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
json_schema: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a default ChatCompletionRequest with minimal fields set
|
||||||
|
fn default_chat_completion_request() -> ChatCompletionRequest {
|
||||||
|
ChatCompletionRequest {
|
||||||
|
model: "test-model".to_string(),
|
||||||
|
messages: vec![ChatMessage::User {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: UserMessageContent::Text("test message".to_string()),
|
||||||
|
name: None,
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
stop: None,
|
||||||
|
max_tokens: None,
|
||||||
|
max_completion_tokens: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
logit_bias: None,
|
||||||
|
logprobs: false,
|
||||||
|
top_logprobs: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
response_format: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
parallel_tool_calls: None,
|
||||||
|
functions: None,
|
||||||
|
function_call: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
continue_final_message: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
separate_reasoning: true,
|
||||||
|
stream_reasoning: true,
|
||||||
|
return_hidden_states: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= GenerateRequest to_pd_request Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_to_pd_request_with_text_only() {
|
||||||
|
let req = GenerateRequest {
|
||||||
|
text: Some("Hello world".to_string()),
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -308,13 +485,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_generate_to_pd_request_with_prompt_string() {
|
fn test_generate_to_pd_request_with_prompt_string() {
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: None,
|
|
||||||
prompt: Some(StringOrArray::String("Test prompt".to_string())),
|
prompt: Some(StringOrArray::String("Test prompt".to_string())),
|
||||||
input_ids: None,
|
|
||||||
stream: true,
|
stream: true,
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: true,
|
return_logprob: true,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -342,6 +516,7 @@ mod tests {
|
|||||||
parameters: None,
|
parameters: None,
|
||||||
sampling_params: None,
|
sampling_params: None,
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -360,13 +535,8 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_generate_to_pd_request_with_single_input_ids() {
|
fn test_generate_to_pd_request_with_single_input_ids() {
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: None,
|
|
||||||
prompt: None,
|
|
||||||
input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])),
|
input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])),
|
||||||
stream: false,
|
..default_generate_request()
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -381,17 +551,12 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_generate_to_pd_request_with_batch_input_ids() {
|
fn test_generate_to_pd_request_with_batch_input_ids() {
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: None,
|
|
||||||
prompt: None,
|
|
||||||
input_ids: Some(InputIds::Batch(vec![
|
input_ids: Some(InputIds::Batch(vec![
|
||||||
vec![1, 2, 3],
|
vec![1, 2, 3],
|
||||||
vec![4, 5, 6, 7],
|
vec![4, 5, 6, 7],
|
||||||
vec![8, 9],
|
vec![8, 9],
|
||||||
])),
|
])),
|
||||||
stream: false,
|
..default_generate_request()
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -413,10 +578,7 @@ mod tests {
|
|||||||
text: Some("SGLang text".to_string()),
|
text: Some("SGLang text".to_string()),
|
||||||
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
|
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
|
||||||
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
|
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
|
||||||
stream: false,
|
..default_generate_request()
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -429,13 +591,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
|
fn test_generate_to_pd_request_priority_prompt_over_input_ids() {
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: None,
|
|
||||||
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
|
prompt: Some(StringOrArray::String("OpenAI prompt".to_string())),
|
||||||
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
|
input_ids: Some(InputIds::Single(vec![1, 2, 3])),
|
||||||
stream: false,
|
..default_generate_request()
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -459,12 +617,8 @@ mod tests {
|
|||||||
|
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: Some("test".to_string()),
|
text: Some("test".to_string()),
|
||||||
prompt: None,
|
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
parameters: Some(params),
|
parameters: Some(params),
|
||||||
sampling_params: None,
|
..default_generate_request()
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -497,12 +651,8 @@ mod tests {
|
|||||||
|
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: Some("test".to_string()),
|
text: Some("test".to_string()),
|
||||||
prompt: None,
|
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: Some(sampling),
|
sampling_params: Some(sampling),
|
||||||
return_logprob: false,
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -546,6 +696,7 @@ mod tests {
|
|||||||
parameters: Some(params),
|
parameters: Some(params),
|
||||||
sampling_params: Some(sampling),
|
sampling_params: Some(sampling),
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -568,6 +719,7 @@ mod tests {
|
|||||||
parameters: Some(params),
|
parameters: Some(params),
|
||||||
sampling_params: None,
|
sampling_params: None,
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -603,6 +755,7 @@ mod tests {
|
|||||||
parameters: Some(params),
|
parameters: Some(params),
|
||||||
sampling_params: Some(sampling),
|
sampling_params: Some(sampling),
|
||||||
return_logprob: true,
|
return_logprob: true,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -632,23 +785,7 @@ mod tests {
|
|||||||
let req = CompletionRequest {
|
let req = CompletionRequest {
|
||||||
model: "gpt-3.5-turbo".to_string(),
|
model: "gpt-3.5-turbo".to_string(),
|
||||||
prompt: StringOrArray::String("Complete this sentence".to_string()),
|
prompt: StringOrArray::String("Complete this sentence".to_string()),
|
||||||
max_tokens: None,
|
..default_completion_request()
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
suffix: None,
|
|
||||||
other: serde_json::Map::new(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -672,23 +809,7 @@ mod tests {
|
|||||||
"First prompt".to_string(),
|
"First prompt".to_string(),
|
||||||
"Second prompt".to_string(),
|
"Second prompt".to_string(),
|
||||||
]),
|
]),
|
||||||
max_tokens: None,
|
..default_completion_request()
|
||||||
temperature: None,
|
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
best_of: None,
|
|
||||||
logit_bias: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
suffix: None,
|
|
||||||
other: serde_json::Map::new(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -727,7 +848,7 @@ mod tests {
|
|||||||
user: Some("user123".to_string()),
|
user: Some("user123".to_string()),
|
||||||
seed: Some(42),
|
seed: Some(42),
|
||||||
suffix: Some("...".to_string()),
|
suffix: Some("...".to_string()),
|
||||||
other: serde_json::Map::new(),
|
..default_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -771,7 +892,7 @@ mod tests {
|
|||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
other: serde_json::Map::new(),
|
..default_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -803,7 +924,7 @@ mod tests {
|
|||||||
user: None,
|
user: None,
|
||||||
seed: None,
|
seed: None,
|
||||||
suffix: None,
|
suffix: None,
|
||||||
other: serde_json::Map::new(),
|
..default_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -834,27 +955,7 @@ mod tests {
|
|||||||
let req = ChatCompletionRequest {
|
let req = ChatCompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
model: "gpt-4".to_string(),
|
model: "gpt-4".to_string(),
|
||||||
temperature: None,
|
..default_chat_completion_request()
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
max_tokens: None,
|
|
||||||
max_completion_tokens: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
response_format: None,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: None,
|
|
||||||
functions: None,
|
|
||||||
function_call: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -883,7 +984,7 @@ mod tests {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let mut logit_bias = HashMap::new();
|
let mut logit_bias = HashMap::new();
|
||||||
logit_bias.insert("50256".to_string(), -100);
|
logit_bias.insert("50256".to_string(), -100.0f32);
|
||||||
|
|
||||||
let tool = Tool {
|
let tool = Tool {
|
||||||
tool_type: "function".to_string(),
|
tool_type: "function".to_string(),
|
||||||
@@ -920,6 +1021,7 @@ mod tests {
|
|||||||
parallel_tool_calls: Some(false),
|
parallel_tool_calls: Some(false),
|
||||||
functions: None,
|
functions: None,
|
||||||
function_call: None,
|
function_call: None,
|
||||||
|
..default_chat_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -968,27 +1070,7 @@ mod tests {
|
|||||||
let req = ChatCompletionRequest {
|
let req = ChatCompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
model: "gpt-4-vision".to_string(),
|
model: "gpt-4-vision".to_string(),
|
||||||
temperature: None,
|
..default_chat_completion_request()
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
max_tokens: None,
|
|
||||||
max_completion_tokens: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
response_format: None,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: None,
|
|
||||||
functions: None,
|
|
||||||
function_call: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1037,6 +1119,7 @@ mod tests {
|
|||||||
parallel_tool_calls: None,
|
parallel_tool_calls: None,
|
||||||
functions: None,
|
functions: None,
|
||||||
function_call: None,
|
function_call: None,
|
||||||
|
..default_chat_completion_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1054,32 +1137,13 @@ mod tests {
|
|||||||
name: None,
|
name: None,
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
function_call: None,
|
function_call: None,
|
||||||
|
reasoning_content: None,
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let req = ChatCompletionRequest {
|
let req = ChatCompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
model: "gpt-3.5-turbo".to_string(),
|
model: "gpt-3.5-turbo".to_string(),
|
||||||
temperature: None,
|
..default_chat_completion_request()
|
||||||
top_p: None,
|
|
||||||
n: None,
|
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
max_tokens: None,
|
|
||||||
max_completion_tokens: None,
|
|
||||||
presence_penalty: None,
|
|
||||||
frequency_penalty: None,
|
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: None,
|
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
response_format: None,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: None,
|
|
||||||
functions: None,
|
|
||||||
function_call: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1101,12 +1165,7 @@ mod tests {
|
|||||||
fn test_routeable_request_to_json() {
|
fn test_routeable_request_to_json() {
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: Some("test".to_string()),
|
text: Some("test".to_string()),
|
||||||
prompt: None,
|
..default_generate_request()
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let json = req.to_json().unwrap();
|
let json = req.to_json().unwrap();
|
||||||
@@ -1166,6 +1225,7 @@ mod tests {
|
|||||||
parameters: Some(params),
|
parameters: Some(params),
|
||||||
sampling_params: None,
|
sampling_params: None,
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1187,6 +1247,7 @@ mod tests {
|
|||||||
parameters: None,
|
parameters: None,
|
||||||
sampling_params: None,
|
sampling_params: None,
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1206,12 +1267,7 @@ mod tests {
|
|||||||
|
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: Some(unicode_text.clone()),
|
text: Some(unicode_text.clone()),
|
||||||
prompt: None,
|
..default_generate_request()
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1250,6 +1306,7 @@ mod tests {
|
|||||||
parameters: Some(params),
|
parameters: Some(params),
|
||||||
sampling_params: None,
|
sampling_params: None,
|
||||||
return_logprob: false,
|
return_logprob: false,
|
||||||
|
..default_generate_request()
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1265,12 +1322,7 @@ mod tests {
|
|||||||
fn test_bootstrap_fields_none() {
|
fn test_bootstrap_fields_none() {
|
||||||
let req = GenerateRequest {
|
let req = GenerateRequest {
|
||||||
text: Some("test".to_string()),
|
text: Some("test".to_string()),
|
||||||
prompt: None,
|
..default_generate_request()
|
||||||
input_ids: None,
|
|
||||||
stream: false,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pd_req = req.to_pd_request();
|
let pd_req = req.to_pd_request();
|
||||||
@@ -1279,4 +1331,182 @@ mod tests {
|
|||||||
assert_eq!(pd_req.bootstrap_port, None);
|
assert_eq!(pd_req.bootstrap_port, None);
|
||||||
assert_eq!(pd_req.bootstrap_room, None);
|
assert_eq!(pd_req.bootstrap_room, None);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============= SGLang Extension Field Pass-Through Tests =============
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_completion_sglang_extensions_passed_through() {
|
||||||
|
let messages = vec![ChatMessage::User {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: UserMessageContent::Text("Test".to_string()),
|
||||||
|
name: None,
|
||||||
|
}];
|
||||||
|
|
||||||
|
let mut session_params = std::collections::HashMap::new();
|
||||||
|
session_params.insert(
|
||||||
|
"key".to_string(),
|
||||||
|
serde_json::Value::String("value".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let req = ChatCompletionRequest {
|
||||||
|
messages,
|
||||||
|
model: "test-model".to_string(),
|
||||||
|
// SGLang Extensions - Priority 1
|
||||||
|
top_k: Some(40),
|
||||||
|
min_p: Some(0.05),
|
||||||
|
min_tokens: Some(10),
|
||||||
|
repetition_penalty: Some(1.1),
|
||||||
|
regex: Some("test_regex".to_string()),
|
||||||
|
ebnf: Some("test_ebnf".to_string()),
|
||||||
|
stop_token_ids: Some(vec![1, 2, 3]),
|
||||||
|
// SGLang Extensions - Priority 2
|
||||||
|
lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
|
||||||
|
session_params: Some(session_params.clone()),
|
||||||
|
// Boolean extensions (ChatCompletionRequest has these as bool, not Option<bool>)
|
||||||
|
no_stop_trim: true,
|
||||||
|
ignore_eos: false,
|
||||||
|
continue_final_message: true,
|
||||||
|
skip_special_tokens: false,
|
||||||
|
separate_reasoning: true,
|
||||||
|
stream_reasoning: false,
|
||||||
|
return_hidden_states: true,
|
||||||
|
..default_chat_completion_request()
|
||||||
|
};
|
||||||
|
|
||||||
|
let pd_req = req.to_pd_request();
|
||||||
|
let other = pd_req.other.as_object().unwrap();
|
||||||
|
|
||||||
|
// Verify SGLang extensions are passed through
|
||||||
|
assert_eq!(other.get("top_k"), Some(&json!(40)));
|
||||||
|
assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
|
||||||
|
assert_eq!(other.get("min_tokens"), Some(&json!(10)));
|
||||||
|
assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
|
||||||
|
assert_eq!(other.get("regex"), Some(&json!("test_regex")));
|
||||||
|
assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
|
||||||
|
assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
|
||||||
|
assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
|
||||||
|
assert_eq!(
|
||||||
|
other.get("session_params"),
|
||||||
|
Some(&serde_json::to_value(&session_params).unwrap())
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify boolean extensions
|
||||||
|
assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
|
||||||
|
assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("continue_final_message"), Some(&json!(true)));
|
||||||
|
assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("separate_reasoning"), Some(&json!(true)));
|
||||||
|
assert_eq!(other.get("stream_reasoning"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_completion_request_sglang_extensions_passed_through() {
|
||||||
|
let mut session_params = std::collections::HashMap::new();
|
||||||
|
session_params.insert(
|
||||||
|
"key".to_string(),
|
||||||
|
serde_json::Value::String("value".to_string()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let req = CompletionRequest {
|
||||||
|
prompt: StringOrArray::String("Test prompt".to_string()),
|
||||||
|
model: "test-model".to_string(),
|
||||||
|
// SGLang Extensions - Priority 1
|
||||||
|
top_k: Some(40),
|
||||||
|
min_p: Some(0.05),
|
||||||
|
min_tokens: Some(10),
|
||||||
|
repetition_penalty: Some(1.1),
|
||||||
|
regex: Some("test_regex".to_string()),
|
||||||
|
ebnf: Some("test_ebnf".to_string()),
|
||||||
|
stop_token_ids: Some(vec![1, 2, 3]),
|
||||||
|
// SGLang Extensions - Priority 2
|
||||||
|
lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))),
|
||||||
|
session_params: Some(session_params.clone()),
|
||||||
|
// Boolean extensions (CompletionRequest only has these 4 boolean fields)
|
||||||
|
no_stop_trim: true,
|
||||||
|
ignore_eos: false,
|
||||||
|
skip_special_tokens: false,
|
||||||
|
return_hidden_states: true,
|
||||||
|
..default_completion_request()
|
||||||
|
};
|
||||||
|
|
||||||
|
let pd_req = req.to_pd_request();
|
||||||
|
let other = pd_req.other.as_object().unwrap();
|
||||||
|
|
||||||
|
// Verify SGLang extensions are passed through
|
||||||
|
assert_eq!(other.get("top_k"), Some(&json!(40)));
|
||||||
|
assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001);
|
||||||
|
assert_eq!(other.get("min_tokens"), Some(&json!(10)));
|
||||||
|
assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001);
|
||||||
|
assert_eq!(other.get("regex"), Some(&json!("test_regex")));
|
||||||
|
assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf")));
|
||||||
|
assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3])));
|
||||||
|
assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin")));
|
||||||
|
assert_eq!(
|
||||||
|
other.get("session_params"),
|
||||||
|
Some(&serde_json::to_value(&session_params).unwrap())
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify boolean extensions (only the ones CompletionRequest has)
|
||||||
|
assert_eq!(other.get("no_stop_trim"), Some(&json!(true)));
|
||||||
|
assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("skip_special_tokens"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("return_hidden_states"), Some(&json!(true)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sglang_extensions_none_values_not_passed_through() {
|
||||||
|
let messages = vec![ChatMessage::User {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: UserMessageContent::Text("Test".to_string()),
|
||||||
|
name: None,
|
||||||
|
}];
|
||||||
|
|
||||||
|
let req = ChatCompletionRequest {
|
||||||
|
messages,
|
||||||
|
model: "test-model".to_string(),
|
||||||
|
// All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
// Boolean fields use defaults (false for most, true for some with default_true)
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
continue_final_message: false,
|
||||||
|
skip_special_tokens: true, // This has default_true
|
||||||
|
separate_reasoning: true, // This has default_true
|
||||||
|
stream_reasoning: true, // This has default_true
|
||||||
|
return_hidden_states: false,
|
||||||
|
..default_chat_completion_request()
|
||||||
|
};
|
||||||
|
|
||||||
|
let pd_req = req.to_pd_request();
|
||||||
|
let other = pd_req.other.as_object().unwrap();
|
||||||
|
|
||||||
|
// Verify None values are not included
|
||||||
|
assert!(!other.contains_key("top_k"));
|
||||||
|
assert!(!other.contains_key("min_p"));
|
||||||
|
assert!(!other.contains_key("min_tokens"));
|
||||||
|
assert!(!other.contains_key("repetition_penalty"));
|
||||||
|
assert!(!other.contains_key("regex"));
|
||||||
|
assert!(!other.contains_key("ebnf"));
|
||||||
|
assert!(!other.contains_key("stop_token_ids"));
|
||||||
|
assert!(!other.contains_key("lora_path"));
|
||||||
|
assert!(!other.contains_key("session_params"));
|
||||||
|
|
||||||
|
// Boolean fields are always present with their values (can't be None)
|
||||||
|
assert_eq!(other.get("no_stop_trim"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("ignore_eos"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("continue_final_message"), Some(&json!(false)));
|
||||||
|
assert_eq!(other.get("skip_special_tokens"), Some(&json!(true))); // default_true
|
||||||
|
assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); // default_true
|
||||||
|
assert_eq!(other.get("stream_reasoning"), Some(&json!(true))); // default_true
|
||||||
|
assert_eq!(other.get("return_hidden_states"), Some(&json!(false)));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,14 +8,118 @@ use sglang_router_rs::openai_api_types::{
|
|||||||
};
|
};
|
||||||
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest};
|
||||||
|
|
||||||
|
/// Create a default GenerateRequest for benchmarks with minimal fields set
|
||||||
|
fn default_generate_request() -> GenerateRequest {
|
||||||
|
GenerateRequest {
|
||||||
|
text: None,
|
||||||
|
prompt: None,
|
||||||
|
input_ids: None,
|
||||||
|
stream: false,
|
||||||
|
parameters: None,
|
||||||
|
sampling_params: None,
|
||||||
|
return_logprob: false,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
rid: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a default ChatCompletionRequest for benchmarks with minimal fields set
|
||||||
|
fn default_chat_completion_request() -> ChatCompletionRequest {
|
||||||
|
ChatCompletionRequest {
|
||||||
|
model: String::new(),
|
||||||
|
messages: vec![],
|
||||||
|
max_tokens: None,
|
||||||
|
max_completion_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
logit_bias: None,
|
||||||
|
logprobs: false,
|
||||||
|
top_logprobs: None,
|
||||||
|
user: None,
|
||||||
|
response_format: None,
|
||||||
|
seed: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
parallel_tool_calls: None,
|
||||||
|
function_call: None,
|
||||||
|
functions: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
continue_final_message: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
separate_reasoning: true,
|
||||||
|
stream_reasoning: true,
|
||||||
|
return_hidden_states: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a default CompletionRequest for benchmarks with minimal fields set
|
||||||
|
fn default_completion_request() -> CompletionRequest {
|
||||||
|
CompletionRequest {
|
||||||
|
model: String::new(),
|
||||||
|
prompt: StringOrArray::String(String::new()),
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
n: None,
|
||||||
|
stream: false,
|
||||||
|
stream_options: None,
|
||||||
|
logprobs: None,
|
||||||
|
echo: false,
|
||||||
|
stop: None,
|
||||||
|
presence_penalty: None,
|
||||||
|
frequency_penalty: None,
|
||||||
|
best_of: None,
|
||||||
|
logit_bias: None,
|
||||||
|
user: None,
|
||||||
|
seed: None,
|
||||||
|
// SGLang Extensions
|
||||||
|
top_k: None,
|
||||||
|
min_p: None,
|
||||||
|
min_tokens: None,
|
||||||
|
repetition_penalty: None,
|
||||||
|
regex: None,
|
||||||
|
ebnf: None,
|
||||||
|
json_schema: None,
|
||||||
|
stop_token_ids: None,
|
||||||
|
no_stop_trim: false,
|
||||||
|
ignore_eos: false,
|
||||||
|
skip_special_tokens: true,
|
||||||
|
// SGLang Extensions
|
||||||
|
lora_path: None,
|
||||||
|
session_params: None,
|
||||||
|
return_hidden_states: false,
|
||||||
|
other: serde_json::Map::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_benchmark_request_creation() {
|
fn test_benchmark_request_creation() {
|
||||||
// Ensure all benchmark request types can be created without panicking
|
// Ensure all benchmark request types can be created without panicking
|
||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Test prompt".to_string()),
|
text: Some("Test prompt".to_string()),
|
||||||
input_ids: None,
|
|
||||||
prompt: None,
|
|
||||||
parameters: Some(GenerateParameters {
|
parameters: Some(GenerateParameters {
|
||||||
max_new_tokens: Some(100),
|
max_new_tokens: Some(100),
|
||||||
temperature: Some(0.8),
|
temperature: Some(0.8),
|
||||||
@@ -33,8 +137,7 @@ fn test_benchmark_request_creation() {
|
|||||||
repetition_penalty: Some(1.0),
|
repetition_penalty: Some(1.0),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
}),
|
}),
|
||||||
stream: false,
|
..default_generate_request()
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_req = ChatCompletionRequest {
|
let chat_req = ChatCompletionRequest {
|
||||||
@@ -49,44 +152,23 @@ fn test_benchmark_request_creation() {
|
|||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.0),
|
presence_penalty: Some(0.0),
|
||||||
frequency_penalty: Some(0.0),
|
frequency_penalty: Some(0.0),
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: None,
|
|
||||||
user: None,
|
|
||||||
response_format: None,
|
|
||||||
seed: None,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
function_call: None,
|
..default_chat_completion_request()
|
||||||
functions: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let completion_req = CompletionRequest {
|
let completion_req = CompletionRequest {
|
||||||
model: "test-model".to_string(),
|
model: "test-model".to_string(),
|
||||||
prompt: StringOrArray::String("Test prompt".to_string()),
|
prompt: StringOrArray::String("Test prompt".to_string()),
|
||||||
suffix: None,
|
|
||||||
max_tokens: Some(50),
|
max_tokens: Some(50),
|
||||||
temperature: Some(0.8),
|
temperature: Some(0.8),
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.0),
|
presence_penalty: Some(0.0),
|
||||||
frequency_penalty: Some(0.0),
|
frequency_penalty: Some(0.0),
|
||||||
best_of: Some(1),
|
best_of: Some(1),
|
||||||
logit_bias: None,
|
..default_completion_request()
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
other: serde_json::Map::new(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test serialization works
|
// Test serialization works
|
||||||
@@ -101,12 +183,7 @@ fn test_benchmark_serialization_roundtrip() {
|
|||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Test prompt".to_string()),
|
text: Some("Test prompt".to_string()),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Serialize and deserialize
|
// Serialize and deserialize
|
||||||
@@ -125,12 +202,7 @@ fn test_benchmark_request_adaptation() {
|
|||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Test prompt".to_string()),
|
text: Some("Test prompt".to_string()),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let chat_req = ChatCompletionRequest {
|
let chat_req = ChatCompletionRequest {
|
||||||
@@ -145,44 +217,23 @@ fn test_benchmark_request_adaptation() {
|
|||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.0),
|
presence_penalty: Some(0.0),
|
||||||
frequency_penalty: Some(0.0),
|
frequency_penalty: Some(0.0),
|
||||||
logit_bias: None,
|
|
||||||
logprobs: false,
|
|
||||||
top_logprobs: None,
|
|
||||||
user: None,
|
|
||||||
response_format: None,
|
|
||||||
seed: None,
|
|
||||||
tools: None,
|
|
||||||
tool_choice: None,
|
|
||||||
parallel_tool_calls: Some(true),
|
parallel_tool_calls: Some(true),
|
||||||
function_call: None,
|
..default_chat_completion_request()
|
||||||
functions: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let completion_req = CompletionRequest {
|
let completion_req = CompletionRequest {
|
||||||
model: "test-model".to_string(),
|
model: "test-model".to_string(),
|
||||||
prompt: StringOrArray::String("Test prompt".to_string()),
|
prompt: StringOrArray::String("Test prompt".to_string()),
|
||||||
suffix: None,
|
|
||||||
max_tokens: Some(50),
|
max_tokens: Some(50),
|
||||||
temperature: Some(0.8),
|
temperature: Some(0.8),
|
||||||
top_p: Some(1.0),
|
top_p: Some(1.0),
|
||||||
n: Some(1),
|
n: Some(1),
|
||||||
stream: false,
|
|
||||||
stream_options: None,
|
|
||||||
logprobs: None,
|
|
||||||
echo: false,
|
|
||||||
stop: None,
|
|
||||||
presence_penalty: Some(0.0),
|
presence_penalty: Some(0.0),
|
||||||
frequency_penalty: Some(0.0),
|
frequency_penalty: Some(0.0),
|
||||||
best_of: Some(1),
|
best_of: Some(1),
|
||||||
logit_bias: None,
|
..default_completion_request()
|
||||||
user: None,
|
|
||||||
seed: None,
|
|
||||||
other: serde_json::Map::new(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test PD adaptation (should not panic)
|
// Test PD adaptation (should not panic)
|
||||||
@@ -197,12 +248,7 @@ fn test_benchmark_regular_routing() {
|
|||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Test prompt".to_string()),
|
text: Some("Test prompt".to_string()),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test regular routing methods (should not panic)
|
// Test regular routing methods (should not panic)
|
||||||
@@ -217,12 +263,7 @@ fn test_benchmark_performance_baseline() {
|
|||||||
|
|
||||||
let generate_req = GenerateRequest {
|
let generate_req = GenerateRequest {
|
||||||
text: Some("Short test prompt".to_string()),
|
text: Some("Short test prompt".to_string()),
|
||||||
input_ids: None,
|
..default_generate_request()
|
||||||
prompt: None,
|
|
||||||
parameters: None,
|
|
||||||
sampling_params: None,
|
|
||||||
stream: false,
|
|
||||||
return_logprob: false,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Serialization should be fast (< 1ms for simple requests)
|
// Serialization should be fast (< 1ms for simple requests)
|
||||||
|
|||||||
Reference in New Issue
Block a user