[router] Refactor router and policy traits with dependency injection (#7987)
Co-authored-by: Jin Pan <jpan236@wisc.edu> Co-authored-by: Keru Yang <rukeyang@gmail.com> Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
66
sgl-router/src/routers/factory.rs
Normal file
66
sgl-router/src/routers/factory.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
//! Factory for creating router instances
|
||||
|
||||
use super::{pd_router::PDRouter, router::Router, RouterTrait};
|
||||
use crate::config::{PolicyConfig, RouterConfig, RoutingMode};
|
||||
use crate::policies::PolicyFactory;
|
||||
|
||||
/// Factory for creating router instances based on configuration
|
||||
pub struct RouterFactory;
|
||||
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from configuration
|
||||
pub fn create_router(config: &RouterConfig) -> Result<Box<dyn RouterTrait>, String> {
|
||||
match &config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &config.policy, config)
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
} => Self::create_pd_router(prefill_urls, decode_urls, &config.policy, config),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a regular router with injected policy
|
||||
fn create_regular_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
router_config: &RouterConfig,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create regular router with injected policy
|
||||
let router = Router::new(
|
||||
worker_urls.to_vec(),
|
||||
policy,
|
||||
router_config.worker_startup_timeout_secs,
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a PD router with injected policy
|
||||
fn create_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
router_config: &RouterConfig,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy directly from PolicyConfig
|
||||
// All policies now support PD mode through the select_worker_pair method
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create PD router with injected policy
|
||||
let router = PDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
policy,
|
||||
router_config.worker_startup_timeout_secs,
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
)?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
}
|
||||
101
sgl-router/src/routers/mod.rs
Normal file
101
sgl-router/src/routers/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
//! Router implementations
|
||||
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use std::fmt::Debug;
|
||||
|
||||
pub mod factory;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod request_adapter;
|
||||
pub mod router;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
|
||||
/// Worker management trait for administrative operations
|
||||
///
|
||||
/// This trait is separate from RouterTrait to allow Send futures
|
||||
/// for use in service discovery and other background tasks
|
||||
#[async_trait]
|
||||
pub trait WorkerManagement: Send + Sync {
|
||||
/// Add a worker to the router
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String>;
|
||||
|
||||
/// Remove a worker from the router
|
||||
fn remove_worker(&self, worker_url: &str);
|
||||
|
||||
/// Get all worker URLs
|
||||
fn get_worker_urls(&self) -> Vec<String>;
|
||||
}
|
||||
|
||||
/// Core trait for all router implementations
|
||||
///
|
||||
/// This trait provides a unified interface for routing requests,
|
||||
/// regardless of whether it's a regular router or PD router.
|
||||
#[async_trait(?Send)]
|
||||
pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
/// Get a reference to self as Any for downcasting
|
||||
fn as_any(&self) -> &dyn std::any::Any;
|
||||
/// Route a health check request
|
||||
async fn health(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
|
||||
/// Route a health generate request
|
||||
async fn health_generate(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
|
||||
/// Get server information
|
||||
async fn get_server_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
|
||||
/// Get available models
|
||||
async fn get_models(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
|
||||
/// Get model information
|
||||
async fn get_model_info(&self, client: &Client, req: &HttpRequest) -> HttpResponse;
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse;
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse;
|
||||
|
||||
/// Route a completion request
|
||||
async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: &HttpRequest,
|
||||
body: serde_json::Value,
|
||||
) -> HttpResponse;
|
||||
|
||||
/// Flush cache on all workers
|
||||
async fn flush_cache(&self, client: &Client) -> HttpResponse;
|
||||
|
||||
/// Get worker loads (for monitoring)
|
||||
async fn get_worker_loads(&self, client: &Client) -> HttpResponse;
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str;
|
||||
|
||||
/// Check if this is a PD router
|
||||
fn is_pd_mode(&self) -> bool {
|
||||
self.router_type() == "pd"
|
||||
}
|
||||
|
||||
/// Server liveness check - is the server process running
|
||||
fn liveness(&self) -> HttpResponse {
|
||||
// Simple liveness check - if we can respond, we're alive
|
||||
HttpResponse::Ok().body("OK")
|
||||
}
|
||||
|
||||
/// Server readiness check - is the server ready to handle requests
|
||||
fn readiness(&self) -> HttpResponse;
|
||||
}
|
||||
1393
sgl-router/src/routers/pd_router.rs
Normal file
1393
sgl-router/src/routers/pd_router.rs
Normal file
File diff suppressed because it is too large
Load Diff
249
sgl-router/src/routers/pd_types.rs
Normal file
249
sgl-router/src/routers/pd_types.rs
Normal file
@@ -0,0 +1,249 @@
|
||||
// Essential PDLB types extracted for PD routing
|
||||
|
||||
use crate::core::{Worker, WorkerType};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// Custom error type for PD router operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum PDRouterError {
|
||||
#[error("Worker already exists: {url}")]
|
||||
WorkerAlreadyExists { url: String },
|
||||
|
||||
#[error("Worker not found: {url}")]
|
||||
WorkerNotFound { url: String },
|
||||
|
||||
#[error("Lock acquisition failed: {operation}")]
|
||||
LockError { operation: String },
|
||||
|
||||
#[error("Health check failed for worker: {url}")]
|
||||
HealthCheckFailed { url: String },
|
||||
|
||||
#[error("Invalid worker configuration: {reason}")]
|
||||
InvalidConfiguration { reason: String },
|
||||
|
||||
#[error("Network error: {message}")]
|
||||
NetworkError { message: String },
|
||||
|
||||
#[error("Timeout waiting for worker: {url}")]
|
||||
Timeout { url: String },
|
||||
}
|
||||
|
||||
// Helper functions for workers
|
||||
/// Join a worker base `url` and an endpoint `api_path` with exactly one `/`
/// between them.
///
/// Trailing slashes on `url` and a single leading slash on `api_path` are
/// absorbed, so `("http://h/", "/health")`, `("http://h", "health")` and every
/// other combination all produce `http://h/health` rather than a URL with a
/// doubled `//` in its path.
pub fn api_path(url: &str, api_path: &str) -> String {
    let base = url.trim_end_matches('/');
    let endpoint = api_path.strip_prefix('/').unwrap_or(api_path);
    format!("{}/{}", base, endpoint)
}
|
||||
|
||||
/// Extract the host part of a worker URL without pulling in a URL-parsing crate.
///
/// Steps:
/// 1. strip one leading `http://` or `https://` scheme, if present;
/// 2. keep only the authority, i.e. everything before the first `/`, `?` or `#`
///    (so a path on a port-less URL no longer leaks into the result);
/// 3. if the authority starts with a bracketed IPv6 literal (`[::1]:8000`),
///    return the literal including its brackets;
/// 4. otherwise return everything before the first `:` (the port separator).
///
/// Userinfo (`user:pass@host`) is not handled. An input with an empty
/// authority yields an empty string.
pub fn get_hostname(url: &str) -> String {
    let without_scheme = url
        .strip_prefix("http://")
        .or_else(|| url.strip_prefix("https://"))
        .unwrap_or(url);

    // `split` always yields at least one item, so the fallback is never taken;
    // it only avoids an `unwrap`.
    let authority = without_scheme
        .split(|c: char| c == '/' || c == '?' || c == '#')
        .next()
        .unwrap_or(without_scheme);

    // Bracketed IPv6 literal: the colons inside the brackets are not port separators.
    if authority.starts_with('[') {
        if let Some(close) = authority.find(']') {
            return authority[..=close].to_string();
        }
    }

    authority
        .split(':')
        .next()
        .unwrap_or(authority)
        .to_string()
}
|
||||
|
||||
/// PD-specific routing policies.
#[derive(Clone, Debug, PartialEq)]
pub enum PDSelectionPolicy {
    /// Random selection.
    Random,
    /// Power-of-two selection.
    PowerOfTwo,
    /// Cache-aware selection, parameterised by three thresholds.
    ///
    /// NOTE(review): the meaning of each threshold is defined by the policy
    /// implementation, which is not visible from this file.
    CacheAware {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
    },
}
|
||||
// Bootstrap types from PDLB
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum SingleOrBatch<T> {
|
||||
Single(T),
|
||||
Batch(Vec<T>),
|
||||
}
|
||||
|
||||
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
||||
pub type InputText = SingleOrBatch<String>;
|
||||
pub type BootstrapHost = SingleOrBatch<String>;
|
||||
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
||||
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||
|
||||
// Bootstrap trait for request handling
|
||||
pub trait Bootstrap: Send + Sync {
|
||||
fn is_stream(&self) -> bool;
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String>;
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
);
|
||||
|
||||
fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> {
|
||||
let batch_size = self.get_batch_size()?;
|
||||
|
||||
// Extract bootstrap port from prefill worker if it's a prefill type
|
||||
let bootstrap_port = match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let hostname = get_hostname(prefill_worker.url());
|
||||
|
||||
if let Some(batch_size) = batch_size {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Batch(vec![hostname; batch_size]),
|
||||
BootstrapPort::Batch(vec![bootstrap_port; batch_size]),
|
||||
// Use high-quality random numbers to minimize collision risk
|
||||
BootstrapRoom::Batch(
|
||||
(0..batch_size)
|
||||
.map(|_| {
|
||||
// Combine multiple sources of randomness for better distribution
|
||||
let r1 = rand::random::<u64>();
|
||||
let r2 = rand::random::<u64>();
|
||||
r1.wrapping_add(r2.rotate_left(32))
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
} else {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Single(hostname),
|
||||
BootstrapPort::Single(bootstrap_port),
|
||||
BootstrapRoom::Single({
|
||||
// Use high-quality random number for single requests too
|
||||
let r1 = rand::random::<u64>();
|
||||
let r2 = rand::random::<u64>();
|
||||
r1.wrapping_add(r2.rotate_left(32))
|
||||
}),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct GenerateReqInput {
|
||||
pub text: Option<InputText>,
|
||||
pub input_ids: Option<InputIds>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub bootstrap_host: Option<BootstrapHost>,
|
||||
pub bootstrap_port: Option<BootstrapPort>,
|
||||
pub bootstrap_room: Option<BootstrapRoom>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub other: Value,
|
||||
}
|
||||
|
||||
impl GenerateReqInput {
|
||||
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
if self.text.is_some() && self.input_ids.is_some() {
|
||||
return Err("Both text and input_ids are present in the request".to_string());
|
||||
}
|
||||
|
||||
// Check text batch
|
||||
if let Some(InputText::Batch(texts)) = &self.text {
|
||||
if texts.is_empty() {
|
||||
return Err("Batch text array is empty".to_string());
|
||||
}
|
||||
if texts.len() > 10000 {
|
||||
// Reasonable limit for production
|
||||
return Err(format!(
|
||||
"Batch size {} exceeds maximum allowed (10000)",
|
||||
texts.len()
|
||||
));
|
||||
}
|
||||
return Ok(Some(texts.len()));
|
||||
}
|
||||
|
||||
// Check input_ids batch
|
||||
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
||||
if ids.is_empty() {
|
||||
return Err("Batch input_ids array is empty".to_string());
|
||||
}
|
||||
if ids.len() > 10000 {
|
||||
// Reasonable limit for production
|
||||
return Err(format!(
|
||||
"Batch size {} exceeds maximum allowed (10000)",
|
||||
ids.len()
|
||||
));
|
||||
}
|
||||
// Validate each sequence is not empty
|
||||
for (i, seq) in ids.iter().enumerate() {
|
||||
if seq.is_empty() {
|
||||
return Err(format!("Input sequence at index {} is empty", i));
|
||||
}
|
||||
}
|
||||
return Ok(Some(ids.len()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bootstrap for GenerateReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
self.get_batch_size()
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
pub struct ChatReqInput {
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub bootstrap_host: Option<BootstrapHost>,
|
||||
pub bootstrap_port: Option<BootstrapPort>,
|
||||
pub bootstrap_room: Option<BootstrapRoom>,
|
||||
|
||||
#[serde(flatten)]
|
||||
pub other: Value,
|
||||
}
|
||||
|
||||
impl Bootstrap for ChatReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
// Check if 'n' parameter is present and > 1
|
||||
if let Some(n_value) = self.other.get("n") {
|
||||
if let Some(n) = n_value.as_u64() {
|
||||
if n > 1 {
|
||||
return Ok(Some(n as usize));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
264
sgl-router/src/routers/request_adapter.rs
Normal file
264
sgl-router/src/routers/request_adapter.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
// Request adapter to bridge OpenAI API types with PD routing requirements
|
||||
|
||||
use super::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
|
||||
use crate::openai_api_types::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
|
||||
};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Adapter trait to convert OpenAI requests to PD-compatible requests
|
||||
pub trait ToPdRequest {
|
||||
type Output: Bootstrap;
|
||||
fn to_pd_request(self) -> Self::Output;
|
||||
}
|
||||
|
||||
// Helper macro: for each `source => key` pair, insert the value into the map
// only when `source` is `Some`. The inner value is converted with
// `serde_json::to_value`; a failed conversion is stored as `Value::Null`.
macro_rules! insert_if_some {
    ($target:expr, $($source:expr => $name:expr),* $(,)?) => {
        $(
            if let Some(inner) = $source {
                $target.insert(
                    $name.to_string(),
                    serde_json::to_value(inner).unwrap_or(Value::Null),
                );
            }
        )*
    };
}
|
||||
|
||||
// Helper macro: for each `source => key` pair, insert `source.into()` into
// the map unconditionally under `key`.
macro_rules! insert_value {
    ($target:expr, $($source:expr => $name:expr),* $(,)?) => {
        $(
            $target.insert($name.to_string(), $source.into());
        )*
    };
}
|
||||
|
||||
// ============= Generate Request Adapter =============
|
||||
|
||||
impl ToPdRequest for GenerateRequest {
|
||||
type Output = GenerateReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
// Build the other fields first
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
|
||||
let (text, input_ids) = if let Some(text_str) = self.text {
|
||||
// SGLang native format
|
||||
(Some(SingleOrBatch::Single(text_str)), None)
|
||||
} else if let Some(prompt) = self.prompt {
|
||||
// OpenAI style prompt
|
||||
let text = match prompt {
|
||||
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
|
||||
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
|
||||
};
|
||||
(text, None)
|
||||
} else if let Some(ids) = self.input_ids {
|
||||
// Input IDs case
|
||||
let input_ids = match ids {
|
||||
crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
|
||||
crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
|
||||
};
|
||||
(None, input_ids)
|
||||
} else {
|
||||
// No input provided
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Add parameters to other - handle both old and new style
|
||||
if let Some(params) = self.parameters {
|
||||
// For generate endpoint, extract max_new_tokens to top level if present
|
||||
let mut params_value = serde_json::to_value(¶ms).unwrap_or(Value::Null);
|
||||
if let Value::Object(ref mut params_map) = params_value {
|
||||
// Move max_new_tokens to top level if it exists
|
||||
if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
|
||||
other.insert("max_new_tokens".to_string(), max_new_tokens);
|
||||
}
|
||||
// Move temperature to top level if it exists
|
||||
if let Some(temperature) = params_map.remove("temperature") {
|
||||
other.insert("temperature".to_string(), temperature);
|
||||
}
|
||||
}
|
||||
// Only add parameters if there are remaining fields
|
||||
if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
|
||||
{
|
||||
other.insert("parameters".to_string(), params_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Add sampling_params if present
|
||||
if let Some(sampling_params) = self.sampling_params {
|
||||
let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
|
||||
if !params_value.is_null() {
|
||||
// Extract commonly used fields to top level
|
||||
if let Value::Object(ref params_map) = params_value {
|
||||
if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
|
||||
other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
|
||||
}
|
||||
if let Some(temperature) = params_map.get("temperature") {
|
||||
other.insert("temperature".to_string(), temperature.clone());
|
||||
}
|
||||
}
|
||||
other.insert("sampling_params".to_string(), params_value);
|
||||
}
|
||||
}
|
||||
|
||||
// Add other fields
|
||||
insert_value!(other,
|
||||
self.stream => "stream",
|
||||
self.return_logprob => "return_logprob"
|
||||
);
|
||||
|
||||
GenerateReqInput {
|
||||
text,
|
||||
input_ids,
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Completion Request Adapter =============
|
||||
|
||||
impl ToPdRequest for CompletionRequest {
|
||||
type Output = GenerateReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
// Convert CompletionRequest to GenerateReqInput
|
||||
let text = match self.prompt {
|
||||
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
|
||||
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
|
||||
};
|
||||
|
||||
// Map OpenAI parameters to generate parameters
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Create parameters object
|
||||
let mut params = serde_json::Map::new();
|
||||
|
||||
// Map OpenAI fields to internal parameter names
|
||||
insert_if_some!(params,
|
||||
self.max_tokens => "max_new_tokens",
|
||||
self.temperature => "temperature",
|
||||
self.top_p => "top_p",
|
||||
self.n => "best_of",
|
||||
self.logprobs => "top_n_tokens",
|
||||
self.seed => "seed"
|
||||
);
|
||||
|
||||
// Special handling for fields that need transformation
|
||||
if let Some(presence_penalty) = self.presence_penalty {
|
||||
params.insert(
|
||||
"repetition_penalty".to_string(),
|
||||
(1.0 + presence_penalty).into(),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(stop) = self.stop {
|
||||
let stop_sequences = match stop {
|
||||
StringOrArray::String(s) => vec![s],
|
||||
StringOrArray::Array(v) => v,
|
||||
};
|
||||
params.insert("stop".to_string(), stop_sequences.into());
|
||||
}
|
||||
|
||||
if self.echo {
|
||||
params.insert("return_full_text".to_string(), true.into());
|
||||
}
|
||||
|
||||
other.insert("parameters".to_string(), Value::Object(params));
|
||||
|
||||
// Store original model and stream flag
|
||||
insert_value!(other,
|
||||
self.model => "model",
|
||||
self.stream => "stream"
|
||||
);
|
||||
|
||||
GenerateReqInput {
|
||||
text,
|
||||
input_ids: None,
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Chat Completion Request Adapter =============
|
||||
|
||||
impl ToPdRequest for ChatCompletionRequest {
|
||||
type Output = ChatReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Add required fields
|
||||
insert_if_some!(other,
|
||||
Some(&self.messages) => "messages"
|
||||
);
|
||||
|
||||
insert_value!(other,
|
||||
self.model => "model",
|
||||
self.stream => "stream"
|
||||
);
|
||||
|
||||
// Add all optional fields
|
||||
insert_if_some!(other,
|
||||
self.temperature => "temperature",
|
||||
self.top_p => "top_p",
|
||||
self.n => "n",
|
||||
self.stop => "stop",
|
||||
self.max_tokens => "max_tokens",
|
||||
self.max_completion_tokens => "max_completion_tokens",
|
||||
self.presence_penalty => "presence_penalty",
|
||||
self.frequency_penalty => "frequency_penalty",
|
||||
self.logit_bias => "logit_bias",
|
||||
self.user => "user",
|
||||
self.seed => "seed",
|
||||
self.top_logprobs => "top_logprobs",
|
||||
self.response_format => "response_format",
|
||||
self.tools => "tools",
|
||||
self.tool_choice => "tool_choice",
|
||||
self.parallel_tool_calls => "parallel_tool_calls",
|
||||
self.functions => "functions",
|
||||
self.function_call => "function_call"
|
||||
);
|
||||
|
||||
// Handle boolean logprobs flag
|
||||
if self.logprobs {
|
||||
other.insert("logprobs".to_string(), true.into());
|
||||
}
|
||||
|
||||
ChatReqInput {
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Direct routing support for regular router =============
|
||||
|
||||
/// Extension trait for routing without PD conversion
|
||||
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
|
||||
/// Convert to JSON for sending to backend
|
||||
fn to_json(&self) -> Result<Value, serde_json::Error> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Convert to bytes for legacy routing
|
||||
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
|
||||
let json = serde_json::to_vec(self)?;
|
||||
Ok(bytes::Bytes::from(json))
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteableRequest for GenerateRequest {}
|
||||
impl RouteableRequest for CompletionRequest {}
|
||||
impl RouteableRequest for ChatCompletionRequest {}
|
||||
1055
sgl-router/src/routers/router.rs
Normal file
1055
sgl-router/src/routers/router.rs
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user