Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
use pyo3::prelude::*;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
pub mod openai_api_types;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod prometheus;
|
||||
pub mod request_adapter;
|
||||
pub mod router;
|
||||
pub mod server;
|
||||
pub mod service_discovery;
|
||||
@@ -14,6 +18,7 @@ pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
CacheAware,
|
||||
PowerOfTwo, // Moved from PD-specific, now shared
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -39,6 +44,12 @@ struct Router {
|
||||
service_discovery_namespace: Option<String>,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
// PD mode flag
|
||||
pd_disaggregated: bool,
|
||||
// PD-specific fields (only used when pd_disaggregated is true)
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -56,7 +67,7 @@ impl Router {
|
||||
balance_rel_threshold = 1.0001,
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24),
|
||||
max_payload_size = 4 * 1024 * 1024,
|
||||
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||
verbose = false,
|
||||
log_dir = None,
|
||||
service_discovery = false,
|
||||
@@ -64,7 +75,11 @@ impl Router {
|
||||
service_discovery_port = 80,
|
||||
service_discovery_namespace = None,
|
||||
prometheus_port = None,
|
||||
prometheus_host = None
|
||||
prometheus_host = None,
|
||||
request_timeout_secs = 600, // Add configurable request timeout
|
||||
pd_disaggregated = false, // New flag for PD mode
|
||||
prefill_urls = None,
|
||||
decode_urls = None
|
||||
))]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
@@ -87,6 +102,10 @@ impl Router {
|
||||
service_discovery_namespace: Option<String>,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
pd_disaggregated: bool,
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Router {
|
||||
host,
|
||||
@@ -109,28 +128,75 @@ impl Router {
|
||||
service_discovery_namespace,
|
||||
prometheus_port,
|
||||
prometheus_host,
|
||||
request_timeout_secs,
|
||||
pd_disaggregated,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
})
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
let policy_config = match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||
let policy_config = if self.pd_disaggregated {
|
||||
// PD mode - map PolicyType to PDSelectionPolicy
|
||||
let pd_selection_policy = match &self.policy {
|
||||
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
||||
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
|
||||
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
},
|
||||
PolicyType::RoundRobin => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
|
||||
pyo3::exceptions::PyValueError::new_err(
|
||||
"PD disaggregated mode requires prefill_urls",
|
||||
)
|
||||
})?;
|
||||
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
|
||||
pyo3::exceptions::PyValueError::new_err(
|
||||
"PD disaggregated mode requires decode_urls",
|
||||
)
|
||||
})?;
|
||||
|
||||
router::PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: pd_selection_policy,
|
||||
prefill_urls: prefill_urls.clone(),
|
||||
decode_urls: decode_urls.clone(),
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Regular mode
|
||||
match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
PolicyType::PowerOfTwo => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create service discovery config if enabled
|
||||
@@ -166,6 +232,7 @@ impl Router {
|
||||
log_dir: self.log_dir.clone(),
|
||||
service_discovery_config,
|
||||
prometheus_config,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||
|
||||
704
sgl-router/src/openai_api_types.rs
Normal file
704
sgl-router/src/openai_api_types.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
// OpenAI-compatible API types for text generation
|
||||
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
|
||||
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Common trait for all generation requests
|
||||
pub trait GenerationRequest: Send + Sync {
|
||||
/// Check if the request is for streaming
|
||||
fn is_stream(&self) -> bool;
|
||||
|
||||
/// Get the model name if specified
|
||||
fn get_model(&self) -> Option<&str>;
|
||||
|
||||
/// Extract text content for routing decisions
|
||||
fn extract_text_for_routing(&self) -> String;
|
||||
}
|
||||
|
||||
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionRequest {
|
||||
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
|
||||
pub model: String,
|
||||
|
||||
/// The prompt(s) to generate completions for
|
||||
pub prompt: StringOrArray,
|
||||
|
||||
/// The suffix that comes after a completion of inserted text
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suffix: Option<String>,
|
||||
|
||||
/// The maximum number of tokens to generate
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// An alternative to sampling with temperature (nucleus sampling)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// How many completions to generate for each prompt
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// Whether to stream back partial progress
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Include the log probabilities on the logprobs most likely tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<u32>,
|
||||
|
||||
/// Echo back the prompt in addition to the completion
|
||||
#[serde(default)]
|
||||
pub echo: bool,
|
||||
|
||||
/// Up to 4 sequences where the API will stop generating further tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// Generates best_of completions server-side and returns the "best"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub best_of: Option<u32>,
|
||||
|
||||
/// Modify the likelihood of specified tokens appearing in the completion
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logit_bias: Option<HashMap<String, f32>>,
|
||||
|
||||
/// A unique identifier representing your end-user
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// If specified, our system will make a best effort to sample deterministically
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for CompletionRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
match &self.prompt {
|
||||
StringOrArray::String(s) => s.clone(),
|
||||
StringOrArray::Array(v) => v.join(" "),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Chat Completions API (v1/chat/completions) =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
/// ID of the model to use
|
||||
pub model: String,
|
||||
|
||||
/// A list of messages comprising the conversation so far
|
||||
pub messages: Vec<ChatMessage>,
|
||||
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// An alternative to sampling with temperature
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// How many chat completion choices to generate for each input message
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// If set, partial message deltas will be sent
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Up to 4 sequences where the API will stop generating further tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
|
||||
/// The maximum number of tokens to generate
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// An upper bound for the number of tokens that can be generated for a completion
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_completion_tokens: Option<u32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// Modify the likelihood of specified tokens appearing in the completion
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logit_bias: Option<HashMap<String, i32>>,
|
||||
|
||||
/// A unique identifier representing your end-user
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// If specified, our system will make a best effort to sample deterministically
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
|
||||
/// Whether to return log probabilities of the output tokens
|
||||
#[serde(default)]
|
||||
pub logprobs: bool,
|
||||
|
||||
/// An integer between 0 and 20 specifying the number of most likely tokens to return
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_logprobs: Option<u32>,
|
||||
|
||||
/// An object specifying the format that the model must output
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
|
||||
/// A list of tools the model may call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// Controls which (if any) tool is called by the model
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
|
||||
/// Whether to enable parallel function calling during tool use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
|
||||
/// Deprecated: use tools instead
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub functions: Option<Vec<Function>>,
|
||||
|
||||
/// Deprecated: use tool_choice instead
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<FunctionCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ChatMessage {
|
||||
System {
|
||||
role: String, // "system"
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
User {
|
||||
role: String, // "user"
|
||||
content: UserMessageContent,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
},
|
||||
Assistant {
|
||||
role: String, // "assistant"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
function_call: Option<FunctionCallResponse>,
|
||||
},
|
||||
Tool {
|
||||
role: String, // "tool"
|
||||
content: String,
|
||||
tool_call_id: String,
|
||||
},
|
||||
Function {
|
||||
role: String, // "function"
|
||||
content: String,
|
||||
name: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum UserMessageContent {
|
||||
Text(String),
|
||||
Parts(Vec<ContentPart>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentPart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl { image_url: ImageUrl },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ImageUrl {
|
||||
pub url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub detail: Option<String>, // "auto", "low", or "high"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseFormat {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
#[serde(rename = "json_object")]
|
||||
JsonObject,
|
||||
#[serde(rename = "json_schema")]
|
||||
JsonSchema { json_schema: JsonSchemaFormat },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct JsonSchemaFormat {
|
||||
pub name: String,
|
||||
pub schema: Value,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub strict: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub function: Function,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Function {
|
||||
pub name: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
pub parameters: Value, // JSON Schema
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ToolChoice {
|
||||
None,
|
||||
Auto,
|
||||
Required,
|
||||
Function {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String, // "function"
|
||||
function: FunctionChoice,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FunctionChoice {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String, // "function"
|
||||
pub function: FunctionCallResponse,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum FunctionCall {
|
||||
None,
|
||||
Auto,
|
||||
Function { name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FunctionCallResponse {
|
||||
pub name: String,
|
||||
pub arguments: String, // JSON string
|
||||
}
|
||||
|
||||
impl GenerationRequest for ChatCompletionRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Extract text from messages for routing decisions
|
||||
self.messages
|
||||
.iter()
|
||||
.filter_map(|msg| match msg {
|
||||
ChatMessage::System { content, .. } => Some(content.clone()),
|
||||
ChatMessage::User { content, .. } => match content {
|
||||
UserMessageContent::Text(text) => Some(text.clone()),
|
||||
UserMessageContent::Parts(parts) => {
|
||||
let texts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
Some(texts.join(" "))
|
||||
}
|
||||
},
|
||||
ChatMessage::Assistant { content, .. } => content.clone(),
|
||||
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
||||
ChatMessage::Function { content, .. } => Some(content.clone()),
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Generate API (/generate) =============
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateRequest {
|
||||
/// The prompt to generate from (OpenAI style)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<StringOrArray>,
|
||||
|
||||
/// Text input - SGLang native format
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub text: Option<String>,
|
||||
|
||||
/// Input IDs for tokenized input
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input_ids: Option<InputIds>,
|
||||
|
||||
/// Generation parameters
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub parameters: Option<GenerateParameters>,
|
||||
|
||||
/// Sampling parameters (sglang style)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sampling_params: Option<SamplingParams>,
|
||||
|
||||
/// Whether to stream the response
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Whether to return logprobs
|
||||
#[serde(default)]
|
||||
pub return_logprob: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum InputIds {
|
||||
Single(Vec<i32>),
|
||||
Batch(Vec<Vec<i32>>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||
pub struct GenerateParameters {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub best_of: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub decoder_input_details: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub details: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub do_sample: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_new_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repetition_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub return_full_text: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub truncate: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub typical_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub watermark: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||
pub struct SamplingParams {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_new_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub repetition_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ignore_eos: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub skip_special_tokens: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub json_schema: Option<String>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for GenerateRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
// Generate requests typically don't have a model field
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Check fields in priority order: text, prompt, inputs
|
||||
if let Some(ref text) = self.text {
|
||||
return text.clone();
|
||||
}
|
||||
|
||||
if let Some(ref prompt) = self.prompt {
|
||||
return match prompt {
|
||||
StringOrArray::String(s) => s.clone(),
|
||||
StringOrArray::Array(v) => v.join(" "),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(ref input_ids) = self.input_ids {
|
||||
return match input_ids {
|
||||
InputIds::Single(ids) => ids
|
||||
.iter()
|
||||
.map(|&id| id.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
InputIds::Batch(batches) => batches
|
||||
.iter()
|
||||
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
};
|
||||
}
|
||||
|
||||
// No text input found
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Helper Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum StringOrArray {
|
||||
String(String),
|
||||
Array(Vec<String>),
|
||||
}
|
||||
|
||||
// ============= Response Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String, // "text_completion"
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<CompletionChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionChoice {
|
||||
pub text: String,
|
||||
pub index: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<LogProbs>,
|
||||
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct LogProbs {
|
||||
pub tokens: Vec<String>,
|
||||
pub token_logprobs: Vec<Option<f32>>,
|
||||
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
||||
pub text_offset: Vec<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String, // "chat.completion"
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChatChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<ChatLogProbs>,
|
||||
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatLogProbs {
|
||||
pub content: Option<Vec<ChatLogProbsContent>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatLogProbsContent {
|
||||
pub token: String,
|
||||
pub logprob: f32,
|
||||
pub bytes: Option<Vec<u8>>,
|
||||
pub top_logprobs: Vec<TopLogProb>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct TopLogProb {
|
||||
pub token: String,
|
||||
pub logprob: f32,
|
||||
pub bytes: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionTokensDetails {
|
||||
pub reasoning_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// ============= Streaming Response Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionStreamResponse {
|
||||
pub id: String,
|
||||
pub object: String, // "text_completion"
|
||||
pub created: u64,
|
||||
pub choices: Vec<CompletionStreamChoice>,
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionStreamChoice {
|
||||
pub text: String,
|
||||
pub index: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<LogProbs>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionStreamResponse {
|
||||
pub id: String,
|
||||
pub object: String, // "chat.completion.chunk"
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_fingerprint: Option<String>,
|
||||
pub choices: Vec<ChatStreamChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatStreamChoice {
|
||||
pub index: u32,
|
||||
pub delta: ChatMessageDelta,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<ChatLogProbs>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatMessageDelta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<FunctionCallDelta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolCallDelta {
|
||||
pub index: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function: Option<FunctionCallDelta>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct FunctionCallDelta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
// ============= Error Response Types =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: ErrorDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ErrorDetail {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub param: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub code: Option<String>,
|
||||
}
|
||||
1002
sgl-router/src/pd_router.rs
Normal file
1002
sgl-router/src/pd_router.rs
Normal file
File diff suppressed because it is too large
Load Diff
245
sgl-router/src/pd_types.rs
Normal file
245
sgl-router/src/pd_types.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
// Essential PDLB types extracted for PD routing
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EngineType {
|
||||
Prefill,
|
||||
Decode,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EngineInfo {
|
||||
pub engine_type: EngineType,
|
||||
pub url: String,
|
||||
pub bootstrap_port: Option<u16>,
|
||||
}
|
||||
|
||||
impl EngineInfo {
|
||||
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Prefill,
|
||||
url,
|
||||
bootstrap_port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_decode(url: String) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Decode,
|
||||
url,
|
||||
bootstrap_port: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_path(&self, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", self.url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", self.url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hostname(&self) -> String {
|
||||
// Simple hostname extraction without external dependencies
|
||||
let url = self
|
||||
.url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
}
|
||||
|
||||
// PD-specific routing policies
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum PDSelectionPolicy {
|
||||
Random,
|
||||
PowerOfTwo,
|
||||
CacheAware {
|
||||
cache_threshold: f32,
|
||||
balance_abs_threshold: usize,
|
||||
balance_rel_threshold: f32,
|
||||
},
|
||||
}
|
||||
// Bootstrap types from PDLB
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum SingleOrBatch<T> {
|
||||
Single(T),
|
||||
Batch(Vec<T>),
|
||||
}
|
||||
|
||||
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
||||
pub type InputText = SingleOrBatch<String>;
|
||||
pub type BootstrapHost = SingleOrBatch<String>;
|
||||
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
||||
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||
|
||||
// Bootstrap trait for request handling
|
||||
pub trait Bootstrap: Send + Sync {
|
||||
fn is_stream(&self) -> bool;
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String>;
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
);
|
||||
|
||||
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
|
||||
let batch_size = self.get_batch_size()?;
|
||||
if let Some(batch_size) = batch_size {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
||||
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
||||
// Use high-quality random numbers to minimize collision risk
|
||||
BootstrapRoom::Batch(
|
||||
(0..batch_size)
|
||||
.map(|_| {
|
||||
// Combine multiple sources of randomness for better distribution
|
||||
let r1 = rand::random::<u64>();
|
||||
let r2 = rand::random::<u64>();
|
||||
r1.wrapping_add(r2.rotate_left(32))
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
} else {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Single(prefill_info.get_hostname()),
|
||||
BootstrapPort::Single(prefill_info.bootstrap_port),
|
||||
BootstrapRoom::Single({
|
||||
// Use high-quality random number for single requests too
|
||||
let r1 = rand::random::<u64>();
|
||||
let r2 = rand::random::<u64>();
|
||||
r1.wrapping_add(r2.rotate_left(32))
|
||||
}),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct GenerateReqInput {
|
||||
pub text: Option<InputText>,
|
||||
pub input_ids: Option<InputIds>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub bootstrap_host: Option<BootstrapHost>,
|
||||
pub bootstrap_port: Option<BootstrapPort>,
|
||||
pub bootstrap_room: Option<BootstrapRoom>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub other: Value,
|
||||
}
|
||||
|
||||
impl GenerateReqInput {
|
||||
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
if self.text.is_some() && self.input_ids.is_some() {
|
||||
return Err("Both text and input_ids are present in the request".to_string());
|
||||
}
|
||||
|
||||
// Check text batch
|
||||
if let Some(InputText::Batch(texts)) = &self.text {
|
||||
if texts.is_empty() {
|
||||
return Err("Batch text array is empty".to_string());
|
||||
}
|
||||
if texts.len() > 10000 {
|
||||
// Reasonable limit for production
|
||||
return Err(format!(
|
||||
"Batch size {} exceeds maximum allowed (10000)",
|
||||
texts.len()
|
||||
));
|
||||
}
|
||||
return Ok(Some(texts.len()));
|
||||
}
|
||||
|
||||
// Check input_ids batch
|
||||
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
||||
if ids.is_empty() {
|
||||
return Err("Batch input_ids array is empty".to_string());
|
||||
}
|
||||
if ids.len() > 10000 {
|
||||
// Reasonable limit for production
|
||||
return Err(format!(
|
||||
"Batch size {} exceeds maximum allowed (10000)",
|
||||
ids.len()
|
||||
));
|
||||
}
|
||||
// Validate each sequence is not empty
|
||||
for (i, seq) in ids.iter().enumerate() {
|
||||
if seq.is_empty() {
|
||||
return Err(format!("Input sequence at index {} is empty", i));
|
||||
}
|
||||
}
|
||||
return Ok(Some(ids.len()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bootstrap for GenerateReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
self.get_batch_size()
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ChatReqInput {
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub bootstrap_host: Option<BootstrapHost>,
|
||||
pub bootstrap_port: Option<BootstrapPort>,
|
||||
pub bootstrap_room: Option<BootstrapRoom>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub other: Value,
|
||||
}
|
||||
|
||||
impl Bootstrap for ChatReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
// Check if 'n' parameter is present and > 1
|
||||
if let Some(n_value) = self.other.get("n") {
|
||||
if let Some(n) = n_value.as_u64() {
|
||||
if n > 1 {
|
||||
return Ok(Some(n as usize));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
264
sgl-router/src/request_adapter.rs
Normal file
264
sgl-router/src/request_adapter.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
// Request adapter to bridge OpenAI API types with PD routing requirements
|
||||
|
||||
use crate::openai_api_types::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
|
||||
};
|
||||
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Adapter trait to convert OpenAI requests to PD-compatible requests
|
||||
pub trait ToPdRequest {
|
||||
type Output: Bootstrap;
|
||||
fn to_pd_request(self) -> Self::Output;
|
||||
}
|
||||
|
||||
// Helper macro to insert optional fields into a map
|
||||
macro_rules! insert_if_some {
|
||||
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
|
||||
$(
|
||||
if let Some(value) = $field {
|
||||
$map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
|
||||
}
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
// Helper macro for simple value insertions
|
||||
macro_rules! insert_value {
|
||||
($map:expr, $($field:expr => $key:expr),* $(,)?) => {
|
||||
$(
|
||||
$map.insert($key.to_string(), $field.into());
|
||||
)*
|
||||
};
|
||||
}
|
||||
|
||||
// ============= Generate Request Adapter =============
|
||||
|
||||
impl ToPdRequest for GenerateRequest {
|
||||
type Output = GenerateReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
// Build the other fields first
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
|
||||
let (text, input_ids) = if let Some(text_str) = self.text {
|
||||
// SGLang native format
|
||||
(Some(SingleOrBatch::Single(text_str)), None)
|
||||
} else if let Some(prompt) = self.prompt {
|
||||
// OpenAI style prompt
|
||||
let text = match prompt {
|
||||
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
|
||||
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
|
||||
};
|
||||
(text, None)
|
||||
} else if let Some(ids) = self.input_ids {
|
||||
// Input IDs case
|
||||
let input_ids = match ids {
|
||||
crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
|
||||
crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
|
||||
};
|
||||
(None, input_ids)
|
||||
} else {
|
||||
// No input provided
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Add parameters to other - handle both old and new style
|
||||
if let Some(params) = self.parameters {
|
||||
// For generate endpoint, extract max_new_tokens to top level if present
|
||||
let mut params_value = serde_json::to_value(¶ms).unwrap_or(Value::Null);
|
||||
if let Value::Object(ref mut params_map) = params_value {
|
||||
// Move max_new_tokens to top level if it exists
|
||||
if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
|
||||
other.insert("max_new_tokens".to_string(), max_new_tokens);
|
||||
}
|
||||
// Move temperature to top level if it exists
|
||||
if let Some(temperature) = params_map.remove("temperature") {
|
||||
other.insert("temperature".to_string(), temperature);
|
||||
}
|
||||
}
|
||||
// Only add parameters if there are remaining fields
|
||||
if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
|
||||
{
|
||||
other.insert("parameters".to_string(), params_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Add sampling_params if present
|
||||
if let Some(sampling_params) = self.sampling_params {
|
||||
let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
|
||||
if !params_value.is_null() {
|
||||
// Extract commonly used fields to top level
|
||||
if let Value::Object(ref params_map) = params_value {
|
||||
if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
|
||||
other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
|
||||
}
|
||||
if let Some(temperature) = params_map.get("temperature") {
|
||||
other.insert("temperature".to_string(), temperature.clone());
|
||||
}
|
||||
}
|
||||
other.insert("sampling_params".to_string(), params_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Add other fields
|
||||
insert_value!(other,
|
||||
self.stream => "stream",
|
||||
self.return_logprob => "return_logprob"
|
||||
);
|
||||
|
||||
GenerateReqInput {
|
||||
text,
|
||||
input_ids,
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Completion Request Adapter =============
|
||||
|
||||
impl ToPdRequest for CompletionRequest {
|
||||
type Output = GenerateReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
// Convert CompletionRequest to GenerateReqInput
|
||||
let text = match self.prompt {
|
||||
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
|
||||
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
|
||||
};
|
||||
|
||||
// Map OpenAI parameters to generate parameters
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Create parameters object
|
||||
let mut params = serde_json::Map::new();
|
||||
|
||||
// Map OpenAI fields to internal parameter names
|
||||
insert_if_some!(params,
|
||||
self.max_tokens => "max_new_tokens",
|
||||
self.temperature => "temperature",
|
||||
self.top_p => "top_p",
|
||||
self.n => "best_of",
|
||||
self.logprobs => "top_n_tokens",
|
||||
self.seed => "seed"
|
||||
);
|
||||
|
||||
// Special handling for fields that need transformation
|
||||
if let Some(presence_penalty) = self.presence_penalty {
|
||||
params.insert(
|
||||
"repetition_penalty".to_string(),
|
||||
(1.0 + presence_penalty).into(),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(stop) = self.stop {
|
||||
let stop_sequences = match stop {
|
||||
StringOrArray::String(s) => vec![s],
|
||||
StringOrArray::Array(v) => v,
|
||||
};
|
||||
params.insert("stop".to_string(), stop_sequences.into());
|
||||
}
|
||||
|
||||
if self.echo {
|
||||
params.insert("return_full_text".to_string(), true.into());
|
||||
}
|
||||
|
||||
other.insert("parameters".to_string(), Value::Object(params));
|
||||
|
||||
// Store original model and stream flag
|
||||
insert_value!(other,
|
||||
self.model => "model",
|
||||
self.stream => "stream"
|
||||
);
|
||||
|
||||
GenerateReqInput {
|
||||
text,
|
||||
input_ids: None,
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Chat Completion Request Adapter =============
|
||||
|
||||
impl ToPdRequest for ChatCompletionRequest {
|
||||
type Output = ChatReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Add required fields
|
||||
insert_if_some!(other,
|
||||
Some(&self.messages) => "messages"
|
||||
);
|
||||
|
||||
insert_value!(other,
|
||||
self.model => "model",
|
||||
self.stream => "stream"
|
||||
);
|
||||
|
||||
// Add all optional fields
|
||||
insert_if_some!(other,
|
||||
self.temperature => "temperature",
|
||||
self.top_p => "top_p",
|
||||
self.n => "n",
|
||||
self.stop => "stop",
|
||||
self.max_tokens => "max_tokens",
|
||||
self.max_completion_tokens => "max_completion_tokens",
|
||||
self.presence_penalty => "presence_penalty",
|
||||
self.frequency_penalty => "frequency_penalty",
|
||||
self.logit_bias => "logit_bias",
|
||||
self.user => "user",
|
||||
self.seed => "seed",
|
||||
self.top_logprobs => "top_logprobs",
|
||||
self.response_format => "response_format",
|
||||
self.tools => "tools",
|
||||
self.tool_choice => "tool_choice",
|
||||
self.parallel_tool_calls => "parallel_tool_calls",
|
||||
self.functions => "functions",
|
||||
self.function_call => "function_call"
|
||||
);
|
||||
|
||||
// Handle boolean logprobs flag
|
||||
if self.logprobs {
|
||||
other.insert("logprobs".to_string(), true.into());
|
||||
}
|
||||
|
||||
ChatReqInput {
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Direct routing support for regular router =============
|
||||
|
||||
/// Extension trait for routing without PD conversion
|
||||
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
|
||||
/// Convert to JSON for sending to backend
|
||||
fn to_json(&self) -> Result<Value, serde_json::Error> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Convert to bytes for legacy routing
|
||||
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
|
||||
let json = serde_json::to_vec(self)?;
|
||||
Ok(bytes::Bytes::from(json))
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteableRequest for GenerateRequest {}
|
||||
impl RouteableRequest for CompletionRequest {}
|
||||
impl RouteableRequest for ChatCompletionRequest {}
|
||||
@@ -1,10 +1,10 @@
|
||||
use crate::pd_router::PDRouter;
|
||||
use crate::pd_types::PDSelectionPolicy;
|
||||
use crate::tree::Tree;
|
||||
use ::metrics::{counter, gauge, histogram};
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
@@ -15,7 +15,7 @@ use std::time::Instant;
|
||||
use tokio;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
@@ -40,6 +40,9 @@ pub enum Router {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
PrefillDecode {
|
||||
pd_router: Arc<PDRouter>,
|
||||
},
|
||||
CacheAware {
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy,
|
||||
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
|
||||
decode_urls: Vec<String>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -155,10 +165,24 @@ impl Router {
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
};
|
||||
|
||||
// Wait until all workers are healthy
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
// For PrefillDecode, we need to handle workers differently
|
||||
match &policy_config {
|
||||
PolicyConfig::PrefillDecodeConfig { .. } => {
|
||||
// PD mode doesn't use the worker_urls parameter
|
||||
// We'll validate PD workers separately
|
||||
}
|
||||
_ => {
|
||||
// Wait until all workers are healthy for regular modes
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create router based on policy...
|
||||
Ok(match policy_config {
|
||||
@@ -226,7 +250,7 @@ impl Router {
|
||||
});
|
||||
|
||||
for url in &worker_urls {
|
||||
tree.lock().unwrap().insert(&"".to_string(), url);
|
||||
tree.lock().unwrap().insert("", url);
|
||||
}
|
||||
|
||||
Router::CacheAware {
|
||||
@@ -242,6 +266,26 @@ impl Router {
|
||||
_eviction_thread: Some(eviction_thread),
|
||||
}
|
||||
}
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
} => {
|
||||
// Create PDRouter instance
|
||||
let pd_router = PDRouter::new(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
selection_policy,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
)?;
|
||||
|
||||
Router::PrefillDecode {
|
||||
pd_router: Arc::new(pd_router),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -251,16 +295,23 @@ impl Router {
|
||||
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, return empty list since we manage workers differently
|
||||
Arc::new(RwLock::new(Vec::new()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_for_healthy_workers(
|
||||
pub fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
let sync_client = reqwest::blocking::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
@@ -323,10 +374,14 @@ impl Router {
|
||||
Ok(worker_urls.read().unwrap()[0].clone())
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't need this method as routing is handled by PDRouter
|
||||
Err("PrefillDecode mode doesn't use select_first_worker".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
pub async fn send_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
@@ -339,7 +394,11 @@ impl Router {
|
||||
// Copy all headers from original request except for /health because it does not need authorization
|
||||
if route != "/health" {
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -433,50 +492,193 @@ impl Router {
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
|
||||
// Convert body to JSON
|
||||
let json: Value = match serde_json::from_slice(body) {
|
||||
Ok(j) => j,
|
||||
Err(_) => {
|
||||
warn!("Failed to parse JSON from request body.");
|
||||
return String::new();
|
||||
pub async fn route_to_all(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
route: &str,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
// Get all worker URLs based on router type
|
||||
let worker_urls = match self {
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, route_to_all is not supported directly
|
||||
// It should be handled by PDRouter if needed
|
||||
return HttpResponse::NotImplemented()
|
||||
.body("route_to_all not implemented for PrefillDecode mode");
|
||||
}
|
||||
_ => self.get_worker_urls().read().unwrap().clone(),
|
||||
};
|
||||
|
||||
match route {
|
||||
"/generate" => {
|
||||
// For /generate, always use the "text" field.
|
||||
match json.get("text").and_then(Value::as_str) {
|
||||
Some(text) => text.to_string(),
|
||||
None => {
|
||||
warn!("No 'text' field found in request body for route /generate.");
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
// Send requests to all workers concurrently
|
||||
let mut tasks = Vec::new();
|
||||
for worker_url in &worker_urls {
|
||||
let mut request_builder = client.post(format!("{}{}", worker_url, route));
|
||||
|
||||
// Copy headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
"/v1/chat/completions" | "/v1/completions" => {
|
||||
// For these routes, try "messages", then "prompt", then "text".
|
||||
if let Some(messages) = json.get("messages") {
|
||||
serde_json::to_string(messages).unwrap_or_default()
|
||||
} else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
warn!("Failed to find 'messages', 'prompt' in request body.");
|
||||
String::new()
|
||||
}
|
||||
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
|
||||
// Wait for all responses
|
||||
let results = futures_util::future::join_all(tasks).await;
|
||||
|
||||
// Check if all succeeded
|
||||
let all_success = results.iter().all(|r| {
|
||||
r.as_ref()
|
||||
.map(|res| res.status().is_success())
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
if all_success {
|
||||
HttpResponse::Ok().body("Operation completed on all servers")
|
||||
} else {
|
||||
HttpResponse::InternalServerError().body("Operation failed on one or more servers")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all_loads(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
// For PD mode, delegate to PDRouter
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
return pd_router.get_loads(client).await;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown route: {} - defaulting to fallback string", route);
|
||||
String::new()
|
||||
// For non-PD routers, handle normally
|
||||
}
|
||||
}
|
||||
|
||||
let urls = self.get_worker_urls().read().unwrap().clone();
|
||||
let prefill_urls: Vec<String> = Vec::new();
|
||||
let decode_urls = urls;
|
||||
|
||||
// Collect loads from all servers
|
||||
let mut prefill_loads = Vec::new();
|
||||
let mut decode_loads = Vec::new();
|
||||
|
||||
// Get prefill loads
|
||||
for url in &prefill_urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
prefill_loads.push(serde_json::json!({
|
||||
"engine": format!("(Prefill@{})", url),
|
||||
"load": load as i64
|
||||
}));
|
||||
}
|
||||
|
||||
// Get decode loads
|
||||
for url in &decode_urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
decode_loads.push(serde_json::json!({
|
||||
"engine": format!("(Decode@{})", url),
|
||||
"load": load as i64
|
||||
}));
|
||||
}
|
||||
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"prefill": prefill_loads,
|
||||
"decode": decode_loads
|
||||
}))
|
||||
}
|
||||
|
||||
// New method to route typed requests directly
|
||||
pub async fn route_typed_request<
|
||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||
>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
|
||||
.body("PD routing should use specialized typed handlers"),
|
||||
_ => {
|
||||
// Handle retries like the original implementation
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
// Extract routing text directly from typed request
|
||||
let text = typed_req.extract_text_for_routing();
|
||||
let is_stream = typed_req.is_stream();
|
||||
|
||||
// Select worker based on text
|
||||
let worker_url = self.select_generate_worker_from_text(&text);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
counter!("sgl_router_retries_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// Send typed request directly
|
||||
let response = self
|
||||
.send_typed_request(
|
||||
client,
|
||||
req,
|
||||
typed_req,
|
||||
route,
|
||||
&worker_url,
|
||||
is_stream,
|
||||
)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: return Result<String, String> instead of panicking
|
||||
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
|
||||
let text = self.get_text_from_request(&body, route);
|
||||
|
||||
let worker_url = match self {
|
||||
// Helper method to select worker from text
|
||||
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
||||
match self {
|
||||
Router::RoundRobin {
|
||||
worker_urls,
|
||||
current_index,
|
||||
@@ -506,8 +708,6 @@ impl Router {
|
||||
balance_rel_threshold,
|
||||
..
|
||||
} => {
|
||||
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
|
||||
|
||||
let tree = tree.lock().unwrap();
|
||||
let mut running_queue = running_queue.lock().unwrap();
|
||||
|
||||
@@ -572,35 +772,48 @@ impl Router {
|
||||
|
||||
selected_url
|
||||
}
|
||||
};
|
||||
|
||||
worker_url
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't use this method
|
||||
return "PD_MODE_ERROR".to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_generate_request(
|
||||
// Send typed request directly without conversion
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
is_stream: bool,
|
||||
) -> HttpResponse {
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
let start = Instant::now();
|
||||
|
||||
// Debug: Log what we're sending
|
||||
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
|
||||
debug!("Sending request to {}: {}", route, json_str);
|
||||
}
|
||||
|
||||
let mut request_builder = client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.body(body.to_vec());
|
||||
.json(typed_req); // Use json() directly with typed request
|
||||
|
||||
// Copy all headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
||||
request_builder = request_builder.header(&name, &value);
|
||||
}
|
||||
}
|
||||
|
||||
let res = match request_builder.send().await {
|
||||
Ok(res) => res,
|
||||
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||||
Err(e) => {
|
||||
error!("Failed to send request to {}: {}", worker_url, e);
|
||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
@@ -625,6 +838,12 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
||||
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
@@ -660,70 +879,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
let worker_url = self.select_generate_worker(body, route);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.send_generate_request(client, req, body, route, &worker_url)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64());
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
let (timeout_secs, interval_secs) = match self {
|
||||
Router::Random {
|
||||
@@ -741,10 +896,17 @@ impl Router {
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't support adding workers via this method
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
};
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::new();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
@@ -774,6 +936,9 @@ impl Router {
|
||||
urls.push(worker_url.to_string());
|
||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// If cache aware, initialize the queues for the new worker
|
||||
@@ -797,7 +962,7 @@ impl Router {
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// Add worker to tree
|
||||
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
|
||||
tree.lock().unwrap().insert("", worker_url);
|
||||
}
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
@@ -850,6 +1015,10 @@ impl Router {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// if cache aware, remove the worker from the tree
|
||||
@@ -875,4 +1044,133 @@ impl Router {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
.get("load")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as isize),
|
||||
Err(e) => {
|
||||
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
debug!(
|
||||
"Worker {} returned non-success status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PD-specific wrapper methods that delegate to PDRouter
|
||||
pub async fn route_pd_health_generate(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.health_generate(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_generate_typed(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: crate::pd_types::GenerateReqInput,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router
|
||||
.route_generate(&pd_router.http_client, req, typed_req, route)
|
||||
.await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_chat_typed(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: crate::pd_types::ChatReqInput,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router
|
||||
.route_chat(&pd_router.http_client, req, typed_req, route)
|
||||
.await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_server_info(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_server_info(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_models(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_models(&pd_router.http_client, req).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.flush_cache(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_model_info(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_model_info(&pd_router.http_client, req).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::prometheus::{self, PrometheusConfig};
|
||||
use crate::request_adapter::ToPdRequest;
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use actix_web::{
|
||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
|
||||
pub struct AppState {
|
||||
router: Arc<Router>,
|
||||
client: Client,
|
||||
is_pd_mode: bool, // Add flag to track PD mode
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -28,9 +30,16 @@ impl AppState {
|
||||
client: Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Result<Self, String> {
|
||||
// Check if this is PD mode from policy config
|
||||
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
|
||||
|
||||
// Create router based on policy
|
||||
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
||||
Ok(Self { router, client })
|
||||
Ok(Self {
|
||||
router,
|
||||
client,
|
||||
is_pd_mode,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
|
||||
}
|
||||
|
||||
// Custom error handler for JSON payload errors.
|
||||
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||
error::ErrorPayloadTooLarge("Payload too large")
|
||||
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||
error!("JSON payload error: {:?}", err);
|
||||
match &err {
|
||||
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
||||
error!(
|
||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||
length, limit
|
||||
);
|
||||
error::ErrorPayloadTooLarge(format!(
|
||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||
length, limit
|
||||
))
|
||||
}
|
||||
error::JsonPayloadError::Overflow { limit } => {
|
||||
error!("Payload overflow: exceeds limit of {} bytes", limit);
|
||||
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
||||
}
|
||||
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
// Check if we're in PD mode
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, check health on all servers
|
||||
data.router
|
||||
.route_pd_health_generate(&data.client, &req)
|
||||
.await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, aggregate info from both prefill and decode servers
|
||||
data.router.get_pd_server_info(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode - return first server's info
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, return models from the first prefill server
|
||||
data.router.get_pd_models(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, get model info from the first prefill server
|
||||
data.router.get_pd_model_info(&data.client, &req).await
|
||||
} else {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/generate")
|
||||
.await
|
||||
async fn generate(
|
||||
req: HttpRequest,
|
||||
body: web::Json<GenerateRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/generate")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
async fn v1_chat_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
|
||||
.await
|
||||
body: web::Json<ChatCompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
async fn v1_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/completions")
|
||||
.await
|
||||
body: web::Json<CompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/completions")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/add_worker")]
|
||||
@@ -153,6 +254,25 @@ async fn remove_worker(
|
||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, flush cache on both prefill and decode servers
|
||||
data.router.route_pd_flush_cache(&data.client).await
|
||||
} else {
|
||||
// Route to all workers for cache flushing
|
||||
data.router
|
||||
.route_to_all(&data.client, "/flush_cache", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Get loads from all workers
|
||||
data.router.get_all_loads(&data.client, &req).await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
@@ -163,6 +283,7 @@ pub struct ServerConfig {
|
||||
pub log_dir: Option<String>,
|
||||
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
||||
pub prometheus_config: Option<PrometheusConfig>,
|
||||
pub request_timeout_secs: u64,
|
||||
}
|
||||
|
||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
|
||||
let client = Client::builder()
|
||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.service(add_worker)
|
||||
.service(remove_worker)
|
||||
.service(list_workers)
|
||||
// Default handler for unmatched routes.
|
||||
.service(flush_cache)
|
||||
.service(get_loads)
|
||||
.default_service(web::route().to(sink_handler))
|
||||
})
|
||||
.bind_auto_h2c((config.host, config.port))?
|
||||
|
||||
Reference in New Issue
Block a user