Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

This commit is contained in:
Simo Lin
2025-06-18 11:28:15 -07:00
committed by GitHub
parent 712bf9ec9b
commit 09ae5b20f3
13 changed files with 4045 additions and 187 deletions

View File

@@ -1,7 +1,11 @@
use pyo3::prelude::*;
pub mod logging;
use std::collections::HashMap;
pub mod openai_api_types;
pub mod pd_router;
pub mod pd_types;
pub mod prometheus;
pub mod request_adapter;
pub mod router;
pub mod server;
pub mod service_discovery;
@@ -14,6 +18,7 @@ pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
}
#[pyclass]
@@ -39,6 +44,12 @@ struct Router {
service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
// PD mode flag
pd_disaggregated: bool,
// PD-specific fields (only used when pd_disaggregated is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
}
#[pymethods]
@@ -56,7 +67,7 @@ impl Router {
balance_rel_threshold = 1.0001,
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 4 * 1024 * 1024,
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
verbose = false,
log_dir = None,
service_discovery = false,
@@ -64,7 +75,11 @@ impl Router {
service_discovery_port = 80,
service_discovery_namespace = None,
prometheus_port = None,
prometheus_host = None
prometheus_host = None,
request_timeout_secs = 600, // Add configurable request timeout
pd_disaggregated = false, // New flag for PD mode
prefill_urls = None,
decode_urls = None
))]
fn new(
worker_urls: Vec<String>,
@@ -87,6 +102,10 @@ impl Router {
service_discovery_namespace: Option<String>,
prometheus_port: Option<u16>,
prometheus_host: Option<String>,
request_timeout_secs: u64,
pd_disaggregated: bool,
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
) -> PyResult<Self> {
Ok(Router {
host,
@@ -109,28 +128,75 @@ impl Router {
service_discovery_namespace,
prometheus_port,
prometheus_host,
request_timeout_secs,
pd_disaggregated,
prefill_urls,
decode_urls,
})
}
fn start(&self) -> PyResult<()> {
let policy_config = match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
let policy_config = if self.pd_disaggregated {
// PD mode - map PolicyType to PDSelectionPolicy
let pd_selection_policy = match &self.policy {
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
},
PolicyType::RoundRobin => {
return Err(pyo3::exceptions::PyValueError::new_err(
"RoundRobin policy is not supported in PD disaggregated mode",
));
}
};
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires prefill_urls",
)
})?;
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err(
"PD disaggregated mode requires decode_urls",
)
})?;
router::PolicyConfig::PrefillDecodeConfig {
selection_policy: pd_selection_policy,
prefill_urls: prefill_urls.clone(),
decode_urls: decode_urls.clone(),
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
}
} else {
// Regular mode
match &self.policy {
PolicyType::Random => router::PolicyConfig::RandomConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
},
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
timeout_secs: self.worker_startup_timeout_secs,
interval_secs: self.worker_startup_check_interval,
cache_threshold: self.cache_threshold,
balance_abs_threshold: self.balance_abs_threshold,
balance_rel_threshold: self.balance_rel_threshold,
eviction_interval_secs: self.eviction_interval_secs,
max_tree_size: self.max_tree_size,
},
PolicyType::PowerOfTwo => {
return Err(pyo3::exceptions::PyValueError::new_err(
"PowerOfTwo policy is only supported in PD disaggregated mode",
));
}
}
};
// Create service discovery config if enabled
@@ -166,6 +232,7 @@ impl Router {
log_dir: self.log_dir.clone(),
service_discovery_config,
prometheus_config,
request_timeout_secs: self.request_timeout_secs,
})
.await
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;

View File

@@ -0,0 +1,704 @@
// OpenAI-compatible API types for text generation
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Common trait for all generation requests
pub trait GenerationRequest: Send + Sync {
/// Check if the request is for streaming
fn is_stream(&self) -> bool;
/// Get the model name if specified
fn get_model(&self) -> Option<&str>;
/// Extract text content for routing decisions
fn extract_text_for_routing(&self) -> String;
}
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionRequest {
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
pub model: String,
/// The prompt(s) to generate completions for
pub prompt: StringOrArray,
/// The suffix that comes after a completion of inserted text
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature (nucleus sampling)
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many completions to generate for each prompt
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// Whether to stream back partial progress
#[serde(default)]
pub stream: bool,
/// Include the log probabilities on the logprobs most likely tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u32>,
/// Echo back the prompt in addition to the completion
#[serde(default)]
pub echo: bool,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Generates best_of completions server-side and returns the "best"
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, f32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
}
impl GenerationRequest for CompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
match &self.prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
}
}
}
// ============= Chat Completions API (v1/chat/completions) =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionRequest {
/// ID of the model to use
pub model: String,
/// A list of messages comprising the conversation so far
pub messages: Vec<ChatMessage>,
/// What sampling temperature to use, between 0 and 2
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
/// An alternative to sampling with temperature
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
/// How many chat completion choices to generate for each input message
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
/// If set, partial message deltas will be sent
#[serde(default)]
pub stream: bool,
/// Up to 4 sequences where the API will stop generating further tokens
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
/// The maximum number of tokens to generate
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// An upper bound for the number of tokens that can be generated for a completion
#[serde(skip_serializing_if = "Option::is_none")]
pub max_completion_tokens: Option<u32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
/// Modify the likelihood of specified tokens appearing in the completion
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, i32>>,
/// A unique identifier representing your end-user
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
/// If specified, our system will make a best effort to sample deterministically
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<i64>,
/// Whether to return log probabilities of the output tokens
#[serde(default)]
pub logprobs: bool,
/// An integer between 0 and 20 specifying the number of most likely tokens to return
#[serde(skip_serializing_if = "Option::is_none")]
pub top_logprobs: Option<u32>,
/// An object specifying the format that the model must output
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// A list of tools the model may call
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
/// Controls which (if any) tool is called by the model
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<ToolChoice>,
/// Whether to enable parallel function calling during tool use
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
/// Deprecated: use tools instead
#[serde(skip_serializing_if = "Option::is_none")]
pub functions: Option<Vec<Function>>,
/// Deprecated: use tool_choice instead
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatMessage {
System {
role: String, // "system"
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
User {
role: String, // "user"
content: UserMessageContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
},
Assistant {
role: String, // "assistant"
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
function_call: Option<FunctionCallResponse>,
},
Tool {
role: String, // "tool"
content: String,
tool_call_id: String,
},
Function {
role: String, // "function"
content: String,
name: String,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum UserMessageContent {
Text(String),
Parts(Vec<ContentPart>),
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrl },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub detail: Option<String>, // "auto", "low", or "high"
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
#[serde(rename = "text")]
Text,
#[serde(rename = "json_object")]
JsonObject,
#[serde(rename = "json_schema")]
JsonSchema { json_schema: JsonSchemaFormat },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
pub name: String,
pub schema: Value,
#[serde(skip_serializing_if = "Option::is_none")]
pub strict: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: Function,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: Value, // JSON Schema
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
None,
Auto,
Required,
Function {
#[serde(rename = "type")]
tool_type: String, // "function"
function: FunctionChoice,
},
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
pub name: String,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
pub id: String,
#[serde(rename = "type")]
pub tool_type: String, // "function"
pub function: FunctionCallResponse,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
None,
Auto,
Function { name: String },
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
pub name: String,
pub arguments: String, // JSON string
}
impl GenerationRequest for ChatCompletionRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
Some(&self.model)
}
fn extract_text_for_routing(&self) -> String {
// Extract text from messages for routing decisions
self.messages
.iter()
.filter_map(|msg| match msg {
ChatMessage::System { content, .. } => Some(content.clone()),
ChatMessage::User { content, .. } => match content {
UserMessageContent::Text(text) => Some(text.clone()),
UserMessageContent::Parts(parts) => {
let texts: Vec<String> = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect();
Some(texts.join(" "))
}
},
ChatMessage::Assistant { content, .. } => content.clone(),
ChatMessage::Tool { content, .. } => Some(content.clone()),
ChatMessage::Function { content, .. } => Some(content.clone()),
})
.collect::<Vec<String>>()
.join(" ")
}
}
// ============= Generate API (/generate) =============
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
/// The prompt to generate from (OpenAI style)
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<StringOrArray>,
/// Text input - SGLang native format
#[serde(skip_serializing_if = "Option::is_none")]
pub text: Option<String>,
/// Input IDs for tokenized input
#[serde(skip_serializing_if = "Option::is_none")]
pub input_ids: Option<InputIds>,
/// Generation parameters
#[serde(default, skip_serializing_if = "Option::is_none")]
pub parameters: Option<GenerateParameters>,
/// Sampling parameters (sglang style)
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling_params: Option<SamplingParams>,
/// Whether to stream the response
#[serde(default)]
pub stream: bool,
/// Whether to return logprobs
#[serde(default)]
pub return_logprob: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
Single(Vec<i32>),
Batch(Vec<Vec<i32>>),
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub best_of: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub decoder_input_details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub do_sample: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub return_full_text: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub truncate: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub typical_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub watermark: Option<bool>,
}
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_new_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<i32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub repetition_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<StringOrArray>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ignore_eos: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub skip_special_tokens: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<String>,
}
impl GenerationRequest for GenerateRequest {
fn is_stream(&self) -> bool {
self.stream
}
fn get_model(&self) -> Option<&str> {
// Generate requests typically don't have a model field
None
}
fn extract_text_for_routing(&self) -> String {
// Check fields in priority order: text, prompt, inputs
if let Some(ref text) = self.text {
return text.clone();
}
if let Some(ref prompt) = self.prompt {
return match prompt {
StringOrArray::String(s) => s.clone(),
StringOrArray::Array(v) => v.join(" "),
};
}
if let Some(ref input_ids) = self.input_ids {
return match input_ids {
InputIds::Single(ids) => ids
.iter()
.map(|&id| id.to_string())
.collect::<Vec<String>>()
.join(" "),
InputIds::Batch(batches) => batches
.iter()
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
.collect::<Vec<String>>()
.join(" "),
};
}
// No text input found
String::new()
}
}
// ============= Helper Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
String(String),
Array(Vec<String>),
}
// ============= Response Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
pub tokens: Vec<String>,
pub token_logprobs: Vec<Option<f32>>,
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
pub text_offset: Vec<u32>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String, // "chat.completion"
pub created: u64,
pub model: String,
pub choices: Vec<ChatChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
pub index: u32,
pub message: ChatMessage,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
pub content: Option<Vec<ChatLogProbsContent>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
pub top_logprobs: Vec<TopLogProb>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
pub token: String,
pub logprob: f32,
pub bytes: Option<Vec<u8>>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub completion_tokens_details: Option<CompletionTokensDetails>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
pub reasoning_tokens: Option<u32>,
}
// ============= Streaming Response Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
pub id: String,
pub object: String, // "text_completion"
pub created: u64,
pub choices: Vec<CompletionStreamChoice>,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
pub text: String,
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<LogProbs>,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
pub id: String,
pub object: String, // "chat.completion.chunk"
pub created: u64,
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_fingerprint: Option<String>,
pub choices: Vec<ChatStreamChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<Usage>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
pub index: u32,
pub delta: ChatMessageDelta,
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<ChatLogProbs>,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCallDelta>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
pub index: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "type")]
pub tool_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<FunctionCallDelta>,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
// ============= Error Response Types =============
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
pub error: ErrorDetail,
}
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<String>,
}

1002
sgl-router/src/pd_router.rs Normal file

File diff suppressed because it is too large Load Diff

245
sgl-router/src/pd_types.rs Normal file
View File

@@ -0,0 +1,245 @@
// Essential PDLB types extracted for PD routing
use serde::{Deserialize, Serialize};
use serde_json::Value;
#[derive(Debug, Clone)]
pub enum EngineType {
Prefill,
Decode,
}
#[derive(Debug, Clone)]
pub struct EngineInfo {
pub engine_type: EngineType,
pub url: String,
pub bootstrap_port: Option<u16>,
}
impl EngineInfo {
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
EngineInfo {
engine_type: EngineType::Prefill,
url,
bootstrap_port,
}
}
pub fn new_decode(url: String) -> Self {
EngineInfo {
engine_type: EngineType::Decode,
url,
bootstrap_port: None,
}
}
pub fn api_path(&self, api_path: &str) -> String {
if api_path.starts_with("/") {
format!("{}{}", self.url, api_path)
} else {
format!("{}/{}", self.url, api_path)
}
}
pub fn get_hostname(&self) -> String {
// Simple hostname extraction without external dependencies
let url = self
.url
.trim_start_matches("http://")
.trim_start_matches("https://");
url.split(':').next().unwrap_or("localhost").to_string()
}
}
// PD-specific routing policies
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
Random,
PowerOfTwo,
CacheAware {
cache_threshold: f32,
balance_abs_threshold: usize,
balance_rel_threshold: f32,
},
}
// Bootstrap types from PDLB
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
Single(T),
Batch(Vec<T>),
}
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
// Bootstrap trait for request handling
pub trait Bootstrap: Send + Sync {
fn is_stream(&self) -> bool;
fn get_batch_size(&self) -> Result<Option<usize>, String>;
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
);
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
let batch_size = self.get_batch_size()?;
if let Some(batch_size) = batch_size {
self.set_bootstrap_info(
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
// Use high-quality random numbers to minimize collision risk
BootstrapRoom::Batch(
(0..batch_size)
.map(|_| {
// Combine multiple sources of randomness for better distribution
let r1 = rand::random::<u64>();
let r2 = rand::random::<u64>();
r1.wrapping_add(r2.rotate_left(32))
})
.collect(),
),
);
} else {
self.set_bootstrap_info(
BootstrapHost::Single(prefill_info.get_hostname()),
BootstrapPort::Single(prefill_info.bootstrap_port),
BootstrapRoom::Single({
// Use high-quality random number for single requests too
let r1 = rand::random::<u64>();
let r2 = rand::random::<u64>();
r1.wrapping_add(r2.rotate_left(32))
}),
);
}
Ok(())
}
}
// Request types
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
pub text: Option<InputText>,
pub input_ids: Option<InputIds>,
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl GenerateReqInput {
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
if self.text.is_some() && self.input_ids.is_some() {
return Err("Both text and input_ids are present in the request".to_string());
}
// Check text batch
if let Some(InputText::Batch(texts)) = &self.text {
if texts.is_empty() {
return Err("Batch text array is empty".to_string());
}
if texts.len() > 10000 {
// Reasonable limit for production
return Err(format!(
"Batch size {} exceeds maximum allowed (10000)",
texts.len()
));
}
return Ok(Some(texts.len()));
}
// Check input_ids batch
if let Some(InputIds::Batch(ids)) = &self.input_ids {
if ids.is_empty() {
return Err("Batch input_ids array is empty".to_string());
}
if ids.len() > 10000 {
// Reasonable limit for production
return Err(format!(
"Batch size {} exceeds maximum allowed (10000)",
ids.len()
));
}
// Validate each sequence is not empty
for (i, seq) in ids.iter().enumerate() {
if seq.is_empty() {
return Err(format!("Input sequence at index {} is empty", i));
}
}
return Ok(Some(ids.len()));
}
Ok(None)
}
}
impl Bootstrap for GenerateReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
self.get_batch_size()
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
#[serde(default)]
pub stream: bool,
pub bootstrap_host: Option<BootstrapHost>,
pub bootstrap_port: Option<BootstrapPort>,
pub bootstrap_room: Option<BootstrapRoom>,
#[serde(flatten)]
pub other: Value,
}
impl Bootstrap for ChatReqInput {
fn is_stream(&self) -> bool {
self.stream
}
fn get_batch_size(&self) -> Result<Option<usize>, String> {
// Check if 'n' parameter is present and > 1
if let Some(n_value) = self.other.get("n") {
if let Some(n) = n_value.as_u64() {
if n > 1 {
return Ok(Some(n as usize));
}
}
}
Ok(None)
}
fn set_bootstrap_info(
&mut self,
bootstrap_host: BootstrapHost,
bootstrap_port: BootstrapPort,
bootstrap_room: BootstrapRoom,
) {
self.bootstrap_host = Some(bootstrap_host);
self.bootstrap_port = Some(bootstrap_port);
self.bootstrap_room = Some(bootstrap_room);
}
}

View File

@@ -0,0 +1,264 @@
// Request adapter to bridge OpenAI API types with PD routing requirements
use crate::openai_api_types::{
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
};
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
use serde_json::Value;
/// Adapter trait to convert OpenAI requests to PD-compatible requests
pub trait ToPdRequest {
type Output: Bootstrap;
fn to_pd_request(self) -> Self::Output;
}
// Helper macro to insert optional fields into a map
macro_rules! insert_if_some {
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
$(
if let Some(value) = $field {
$map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
}
)*
};
}
// Helper macro for simple value insertions
macro_rules! insert_value {
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
$(
$map.insert($key.to_string(), $field.into());
)*
};
}
// ============= Generate Request Adapter =============
impl ToPdRequest for GenerateRequest {
type Output = GenerateReqInput;
fn to_pd_request(self) -> Self::Output {
// Build the other fields first
let mut other = serde_json::Map::new();
// Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
let (text, input_ids) = if let Some(text_str) = self.text {
// SGLang native format
(Some(SingleOrBatch::Single(text_str)), None)
} else if let Some(prompt) = self.prompt {
// OpenAI style prompt
let text = match prompt {
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
};
(text, None)
} else if let Some(ids) = self.input_ids {
// Input IDs case
let input_ids = match ids {
crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
};
(None, input_ids)
} else {
// No input provided
(None, None)
};
// Add parameters to other - handle both old and new style
if let Some(params) = self.parameters {
// For generate endpoint, extract max_new_tokens to top level if present
let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
if let Value::Object(ref mut params_map) = params_value {
// Move max_new_tokens to top level if it exists
if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
other.insert("max_new_tokens".to_string(), max_new_tokens);
}
// Move temperature to top level if it exists
if let Some(temperature) = params_map.remove("temperature") {
other.insert("temperature".to_string(), temperature);
}
}
// Only add parameters if there are remaining fields
if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
{
other.insert("parameters".to_string(), params_value);
}
}
// Add sampling_params if present
if let Some(sampling_params) = self.sampling_params {
let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
if !params_value.is_null() {
// Extract commonly used fields to top level
if let Value::Object(ref params_map) = params_value {
if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
}
if let Some(temperature) = params_map.get("temperature") {
other.insert("temperature".to_string(), temperature.clone());
}
}
other.insert("sampling_params".to_string(), params_value);
}
}
// Add other fields
insert_value!(other,
self.stream => "stream",
self.return_logprob => "return_logprob"
);
GenerateReqInput {
text,
input_ids,
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Completion Request Adapter =============
impl ToPdRequest for CompletionRequest {
type Output = GenerateReqInput;
fn to_pd_request(self) -> Self::Output {
// Convert CompletionRequest to GenerateReqInput
let text = match self.prompt {
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
};
// Map OpenAI parameters to generate parameters
let mut other = serde_json::Map::new();
// Create parameters object
let mut params = serde_json::Map::new();
// Map OpenAI fields to internal parameter names
insert_if_some!(params,
self.max_tokens => "max_new_tokens",
self.temperature => "temperature",
self.top_p => "top_p",
self.n => "best_of",
self.logprobs => "top_n_tokens",
self.seed => "seed"
);
// Special handling for fields that need transformation
if let Some(presence_penalty) = self.presence_penalty {
params.insert(
"repetition_penalty".to_string(),
(1.0 + presence_penalty).into(),
);
}
if let Some(stop) = self.stop {
let stop_sequences = match stop {
StringOrArray::String(s) => vec![s],
StringOrArray::Array(v) => v,
};
params.insert("stop".to_string(), stop_sequences.into());
}
if self.echo {
params.insert("return_full_text".to_string(), true.into());
}
other.insert("parameters".to_string(), Value::Object(params));
// Store original model and stream flag
insert_value!(other,
self.model => "model",
self.stream => "stream"
);
GenerateReqInput {
text,
input_ids: None,
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Chat Completion Request Adapter =============
impl ToPdRequest for ChatCompletionRequest {
type Output = ChatReqInput;
fn to_pd_request(self) -> Self::Output {
let mut other = serde_json::Map::new();
// Add required fields
insert_if_some!(other,
Some(&self.messages) => "messages"
);
insert_value!(other,
self.model => "model",
self.stream => "stream"
);
// Add all optional fields
insert_if_some!(other,
self.temperature => "temperature",
self.top_p => "top_p",
self.n => "n",
self.stop => "stop",
self.max_tokens => "max_tokens",
self.max_completion_tokens => "max_completion_tokens",
self.presence_penalty => "presence_penalty",
self.frequency_penalty => "frequency_penalty",
self.logit_bias => "logit_bias",
self.user => "user",
self.seed => "seed",
self.top_logprobs => "top_logprobs",
self.response_format => "response_format",
self.tools => "tools",
self.tool_choice => "tool_choice",
self.parallel_tool_calls => "parallel_tool_calls",
self.functions => "functions",
self.function_call => "function_call"
);
// Handle boolean logprobs flag
if self.logprobs {
other.insert("logprobs".to_string(), true.into());
}
ChatReqInput {
stream: self.stream,
bootstrap_host: None,
bootstrap_port: None,
bootstrap_room: None,
other: Value::Object(other),
}
}
}
// ============= Direct routing support for regular router =============
/// Extension trait for routing without PD conversion
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
/// Convert to JSON for sending to backend
fn to_json(&self) -> Result<Value, serde_json::Error> {
serde_json::to_value(self)
}
/// Convert to bytes for legacy routing
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
let json = serde_json::to_vec(self)?;
Ok(bytes::Bytes::from(json))
}
}
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}

View File

@@ -1,10 +1,10 @@
use crate::pd_router::PDRouter;
use crate::pd_types::PDSelectionPolicy;
use crate::tree::Tree;
use ::metrics::{counter, gauge, histogram};
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
use actix_web::{HttpRequest, HttpResponse};
use bytes::Bytes;
use futures_util::{StreamExt, TryStreamExt};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::atomic::AtomicUsize;
@@ -15,7 +15,7 @@ use std::time::Instant;
use tokio;
use tracing::{debug, error, info, warn};
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
req.headers()
.iter()
.filter_map(|(name, value)| {
@@ -40,6 +40,9 @@ pub enum Router {
timeout_secs: u64,
interval_secs: u64,
},
PrefillDecode {
pd_router: Arc<PDRouter>,
},
CacheAware {
/*
Cache-Aware Load Balancing Router
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
timeout_secs: u64,
interval_secs: u64,
},
PrefillDecodeConfig {
selection_policy: PDSelectionPolicy,
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
decode_urls: Vec<String>,
timeout_secs: u64,
interval_secs: u64,
},
}
impl Router {
@@ -155,10 +165,24 @@ impl Router {
interval_secs,
..
} => (*timeout_secs, *interval_secs),
PolicyConfig::PrefillDecodeConfig {
timeout_secs,
interval_secs,
..
} => (*timeout_secs, *interval_secs),
};
// Wait until all workers are healthy
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
// For PrefillDecode, we need to handle workers differently
match &policy_config {
PolicyConfig::PrefillDecodeConfig { .. } => {
// PD mode doesn't use the worker_urls parameter
// We'll validate PD workers separately
}
_ => {
// Wait until all workers are healthy for regular modes
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
}
}
// Create router based on policy...
Ok(match policy_config {
@@ -226,7 +250,7 @@ impl Router {
});
for url in &worker_urls {
tree.lock().unwrap().insert(&"".to_string(), url);
tree.lock().unwrap().insert("", url);
}
Router::CacheAware {
@@ -242,6 +266,26 @@ impl Router {
_eviction_thread: Some(eviction_thread),
}
}
PolicyConfig::PrefillDecodeConfig {
selection_policy,
prefill_urls,
decode_urls,
timeout_secs,
interval_secs,
} => {
// Create PDRouter instance
let pd_router = PDRouter::new(
prefill_urls,
decode_urls,
selection_policy,
timeout_secs,
interval_secs,
)?;
Router::PrefillDecode {
pd_router: Arc::new(pd_router),
}
}
})
}
@@ -251,16 +295,23 @@ impl Router {
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
Router::PrefillDecode { .. } => {
// For PD mode, return empty list since we manage workers differently
Arc::new(RwLock::new(Vec::new()))
}
}
}
fn wait_for_healthy_workers(
pub fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::new();
let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
@@ -323,10 +374,14 @@ impl Router {
Ok(worker_urls.read().unwrap()[0].clone())
}
}
Router::PrefillDecode { .. } => {
// For PD mode, we don't need this method as routing is handled by PDRouter
Err("PrefillDecode mode doesn't use select_first_worker".to_string())
}
}
}
async fn send_request(
pub async fn send_request(
&self,
client: &reqwest::Client,
worker_url: &str,
@@ -339,7 +394,11 @@ impl Router {
// Copy all headers from original request except for /health because it does not need authorization
if route != "/health" {
for (name, value) in copy_request_headers(req) {
request_builder = request_builder.header(name, value);
// Skip Content-Type and Content-Length as .json() sets them
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
request_builder = request_builder.header(name, value);
}
}
}
@@ -433,50 +492,193 @@ impl Router {
HttpResponse::InternalServerError().body("All retry attempts failed")
}
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
// Convert body to JSON
let json: Value = match serde_json::from_slice(body) {
Ok(j) => j,
Err(_) => {
warn!("Failed to parse JSON from request body.");
return String::new();
pub async fn route_to_all(
&self,
client: &reqwest::Client,
route: &str,
req: &HttpRequest,
) -> HttpResponse {
// Get all worker URLs based on router type
let worker_urls = match self {
Router::PrefillDecode { .. } => {
// For PD mode, route_to_all is not supported directly
// It should be handled by PDRouter if needed
return HttpResponse::NotImplemented()
.body("route_to_all not implemented for PrefillDecode mode");
}
_ => self.get_worker_urls().read().unwrap().clone(),
};
match route {
"/generate" => {
// For /generate, always use the "text" field.
match json.get("text").and_then(Value::as_str) {
Some(text) => text.to_string(),
None => {
warn!("No 'text' field found in request body for route /generate.");
String::new()
}
}
// Send requests to all workers concurrently
let mut tasks = Vec::new();
for worker_url in &worker_urls {
let mut request_builder = client.post(format!("{}{}", worker_url, route));
// Copy headers from original request
for (name, value) in copy_request_headers(req) {
request_builder = request_builder.header(name, value);
}
"/v1/chat/completions" | "/v1/completions" => {
// For these routes, try "messages", then "prompt", then "text".
if let Some(messages) = json.get("messages") {
serde_json::to_string(messages).unwrap_or_default()
} else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
prompt.to_string()
} else {
warn!("Failed to find 'messages', 'prompt' in request body.");
String::new()
}
tasks.push(request_builder.send());
}
// Wait for all responses
let results = futures_util::future::join_all(tasks).await;
// Check if all succeeded
let all_success = results.iter().all(|r| {
r.as_ref()
.map(|res| res.status().is_success())
.unwrap_or(false)
});
if all_success {
HttpResponse::Ok().body("Operation completed on all servers")
} else {
HttpResponse::InternalServerError().body("Operation failed on one or more servers")
}
}
pub async fn get_all_loads(
&self,
client: &reqwest::Client,
_req: &HttpRequest,
) -> HttpResponse {
// For PD mode, delegate to PDRouter
match self {
Router::PrefillDecode { pd_router } => {
return pd_router.get_loads(client).await;
}
_ => {
warn!("Unknown route: {} - defaulting to fallback string", route);
String::new()
// For non-PD routers, handle normally
}
}
let urls = self.get_worker_urls().read().unwrap().clone();
let prefill_urls: Vec<String> = Vec::new();
let decode_urls = urls;
// Collect loads from all servers
let mut prefill_loads = Vec::new();
let mut decode_loads = Vec::new();
// Get prefill loads
for url in &prefill_urls {
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
prefill_loads.push(serde_json::json!({
"engine": format!("(Prefill@{})", url),
"load": load as i64
}));
}
// Get decode loads
for url in &decode_urls {
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
decode_loads.push(serde_json::json!({
"engine": format!("(Decode@{})", url),
"load": load as i64
}));
}
HttpResponse::Ok().json(serde_json::json!({
"prefill": prefill_loads,
"decode": decode_loads
}))
}
// New method to route typed requests directly
pub async fn route_typed_request<
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
typed_req: &T,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
.body("PD routing should use specialized typed handlers"),
_ => {
// Handle retries like the original implementation
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
// Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream();
// Select worker based on text
let worker_url = self.select_generate_worker_from_text(&text);
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
counter!("sgl_router_retries_total", "route" => route.to_string())
.increment(1);
}
// Send typed request directly
let response = self
.send_typed_request(
client,
req,
typed_req,
route,
&worker_url,
is_stream,
)
.await;
if response.status().is_success() {
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
return response;
}
}
warn!(
"Generate request to {} failed (attempt {}/{})",
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
HttpResponse::InternalServerError().body("All retry attempts failed")
}
}
}
// TODO: return Result<String, String> instead of panicking
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
let text = self.get_text_from_request(&body, route);
let worker_url = match self {
// Helper method to select worker from text
fn select_generate_worker_from_text(&self, text: &str) -> String {
match self {
Router::RoundRobin {
worker_urls,
current_index,
@@ -506,8 +708,6 @@ impl Router {
balance_rel_threshold,
..
} => {
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
let tree = tree.lock().unwrap();
let mut running_queue = running_queue.lock().unwrap();
@@ -572,35 +772,48 @@ impl Router {
selected_url
}
};
worker_url
Router::PrefillDecode { .. } => {
// For PD mode, we don't use this method
return "PD_MODE_ERROR".to_string();
}
}
}
async fn send_generate_request(
// Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>(
&self,
client: &reqwest::Client,
req: &HttpRequest,
body: &Bytes,
typed_req: &T,
route: &str,
worker_url: &str,
is_stream: bool,
) -> HttpResponse {
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
.unwrap_or(false);
let start = Instant::now();
// Debug: Log what we're sending
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
debug!("Sending request to {}: {}", route, json_str);
}
let mut request_builder = client
.post(format!("{}{}", worker_url, route))
.body(body.to_vec());
.json(typed_req); // Use json() directly with typed request
// Copy all headers from original request
for (name, value) in copy_request_headers(req) {
request_builder = request_builder.header(name, value);
// Skip Content-Type and Content-Length as .json() sets them
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
request_builder = request_builder.header(&name, &value);
}
}
let res = match request_builder.send().await {
Ok(res) => res,
Err(_) => return HttpResponse::InternalServerError().finish(),
Err(e) => {
error!("Failed to send request to {}: {}", worker_url, e);
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
}
};
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
@@ -625,6 +838,12 @@ impl Router {
}
}
// Record metrics
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
.record(duration.as_secs_f64());
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
response
} else if let Router::CacheAware { running_queue, .. } = self {
let running_queue = Arc::clone(running_queue);
@@ -660,70 +879,6 @@ impl Router {
}
}
pub async fn route_generate_request(
&self,
client: &reqwest::Client,
req: &HttpRequest,
body: &Bytes,
route: &str,
) -> HttpResponse {
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
let mut total_retries = 0;
while total_retries < MAX_TOTAL_RETRIES {
let worker_url = self.select_generate_worker(body, route);
let mut request_retries = 0;
// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1);
}
let response = self
.send_generate_request(client, req, body, route, &worker_url)
.await;
if response.status().is_success() {
let duration = start.elapsed();
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64());
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response =
self.send_request(client, &worker_url, "/health", req).await;
if health_response.status().is_success() {
counter!("sgl_router_request_errors_total", "route" => route.to_string())
.increment(1);
return response;
}
}
warn!(
"Generate request to {} failed (attempt {}/{})",
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
);
request_retries += 1;
total_retries += 1;
if request_retries == MAX_REQUEST_RETRIES {
warn!("Removing failed worker: {}", worker_url);
self.remove_worker(&worker_url);
break;
}
}
}
counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
HttpResponse::InternalServerError().body("All retry attempts failed")
}
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
let (timeout_secs, interval_secs) = match self {
Router::Random {
@@ -741,10 +896,17 @@ impl Router {
interval_secs,
..
} => (*timeout_secs, *interval_secs),
Router::PrefillDecode { .. } => {
// For PD mode, we don't support adding workers via this method
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
}
};
let start_time = std::time::Instant::now();
let client = reqwest::Client::new();
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
loop {
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
@@ -774,6 +936,9 @@ impl Router {
urls.push(worker_url.to_string());
gauge!("sgl_router_active_workers").set(urls.len() as f64);
}
Router::PrefillDecode { .. } => {
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
}
}
// If cache aware, initialize the queues for the new worker
@@ -797,7 +962,7 @@ impl Router {
.insert(worker_url.to_string(), 0);
// Add worker to tree
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
tree.lock().unwrap().insert("", worker_url);
}
return Ok(format!("Successfully added worker: {}", worker_url));
@@ -850,6 +1015,10 @@ impl Router {
return;
}
}
Router::PrefillDecode { .. } => {
warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
return;
}
}
// if cache aware, remove the worker from the tree
@@ -875,4 +1044,133 @@ impl Router {
);
}
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
.get("load")
.and_then(|v| v.as_i64())
.map(|v| v as isize),
Err(e) => {
debug!("Failed to parse load response from {}: {}", worker_url, e);
None
}
},
Err(e) => {
debug!("Failed to read load response from {}: {}", worker_url, e);
None
}
},
Ok(res) => {
debug!(
"Worker {} returned non-success status: {}",
worker_url,
res.status()
);
None
}
Err(e) => {
debug!("Failed to get load from {}: {}", worker_url, e);
None
}
}
}
// PD-specific wrapper methods that delegate to PDRouter
pub async fn route_pd_health_generate(
&self,
_client: &reqwest::Client,
_req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.health_generate(&pd_router.http_client).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn route_pd_generate_typed(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
typed_req: crate::pd_types::GenerateReqInput,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router
.route_generate(&pd_router.http_client, req, typed_req, route)
.await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn route_pd_chat_typed(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
typed_req: crate::pd_types::ChatReqInput,
route: &str,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router
.route_chat(&pd_router.http_client, req, typed_req, route)
.await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn get_pd_server_info(
&self,
_client: &reqwest::Client,
_req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.get_server_info(&pd_router.http_client).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn get_pd_models(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.get_models(&pd_router.http_client, req).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.flush_cache(&pd_router.http_client).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
pub async fn get_pd_model_info(
&self,
_client: &reqwest::Client,
req: &HttpRequest,
) -> HttpResponse {
match self {
Router::PrefillDecode { pd_router } => {
pd_router.get_model_info(&pd_router.http_client, req).await
}
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
}
}
}

View File

@@ -1,12 +1,13 @@
use crate::logging::{self, LoggingConfig};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::prometheus::{self, PrometheusConfig};
use crate::request_adapter::ToPdRequest;
use crate::router::PolicyConfig;
use crate::router::Router;
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
pub struct AppState {
router: Arc<Router>,
client: Client,
is_pd_mode: bool, // Add flag to track PD mode
}
impl AppState {
@@ -28,9 +30,16 @@ impl AppState {
client: Client,
policy_config: PolicyConfig,
) -> Result<Self, String> {
// Check if this is PD mode from policy config
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
// Create router based on policy
let router = Arc::new(Router::new(worker_urls, policy_config)?);
Ok(Self { router, client })
Ok(Self {
router,
client,
is_pd_mode,
})
}
}
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
}
// Custom error handler for JSON payload errors.
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error::ErrorPayloadTooLarge("Payload too large")
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error!("JSON payload error: {:?}", err);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!("Payload overflow: exceeds limit of {} bytes", limit);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
}
}
#[get("/health")]
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
// Check if we're in PD mode
if data.is_pd_mode {
// For PD mode, check health on all servers
data.router
.route_pd_health_generate(&data.client, &req)
.await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
}
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
if data.is_pd_mode {
// For PD mode, aggregate info from both prefill and decode servers
data.router.get_pd_server_info(&data.client, &req).await
} else {
// Regular mode - return first server's info
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
}
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
if data.is_pd_mode {
// For PD mode, return models from the first prefill server
data.router.get_pd_models(&data.client, &req).await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
}
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
if data.is_pd_mode {
// For PD mode, get model info from the first prefill server
data.router.get_pd_model_info(&data.client, &req).await
} else {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
}
}
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/generate")
.await
async fn generate(
req: HttpRequest,
body: web::Json<GenerateRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/generate")
.await)
}
}
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
.await
body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
.await)
}
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/completions")
.await
body: web::Json<CompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/completions")
.await)
}
}
#[post("/add_worker")]
@@ -153,6 +254,25 @@ async fn remove_worker(
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
}
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, flush cache on both prefill and decode servers
data.router.route_pd_flush_cache(&data.client).await
} else {
// Route to all workers for cache flushing
data.router
.route_to_all(&data.client, "/flush_cache", &req)
.await
}
}
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Get loads from all workers
data.router.get_all_loads(&data.client, &req).await
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
@@ -163,6 +283,7 @@ pub struct ServerConfig {
pub log_dir: Option<String>,
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
pub prometheus_config: Option<PrometheusConfig>,
pub request_timeout_secs: u64,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
.build()
.expect("Failed to create HTTP client");
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(add_worker)
.service(remove_worker)
.service(list_workers)
// Default handler for unmatched routes.
.service(flush_cache)
.service(get_loads)
.default_service(web::route().to(sink_handler))
})
.bind_auto_h2c((config.host, config.port))?