[router] Add OpenAI backend support - core function (#10254)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use clap::{ArgAction, Parser};
|
||||
use clap::{ArgAction, Parser, ValueEnum};
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
|
||||
HealthCheckConfig, MetricsConfig, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
@@ -41,6 +41,33 @@ fn parse_prefill_args() -> Vec<(String, Option<u16>)> {
|
||||
prefill_entries
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, ValueEnum)]
|
||||
pub enum Backend {
|
||||
#[value(name = "sglang")]
|
||||
Sglang,
|
||||
#[value(name = "vllm")]
|
||||
Vllm,
|
||||
#[value(name = "trtllm")]
|
||||
Trtllm,
|
||||
#[value(name = "openai")]
|
||||
Openai,
|
||||
#[value(name = "anthropic")]
|
||||
Anthropic,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Backend {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
Backend::Sglang => "sglang",
|
||||
Backend::Vllm => "vllm",
|
||||
Backend::Trtllm => "trtllm",
|
||||
Backend::Openai => "openai",
|
||||
Backend::Anthropic => "anthropic",
|
||||
};
|
||||
write!(f, "{}", s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "sglang-router")]
|
||||
#[command(about = "SGLang Router - High-performance request distribution across worker nodes")]
|
||||
@@ -145,6 +172,10 @@ struct CliArgs {
|
||||
#[arg(long)]
|
||||
api_key: Option<String>,
|
||||
|
||||
/// Backend to route requests to (sglang, vllm, trtllm, openai, anthropic)
|
||||
#[arg(long, value_enum, default_value_t = Backend::Sglang, alias = "runtime")]
|
||||
backend: Backend,
|
||||
|
||||
/// Directory to store log files
|
||||
#[arg(long)]
|
||||
log_dir: Option<String>,
|
||||
@@ -339,6 +370,11 @@ impl CliArgs {
|
||||
RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
}
|
||||
} else if matches!(self.backend, Backend::Openai) {
|
||||
// OpenAI backend mode - use worker_urls as base(s)
|
||||
RoutingMode::OpenAI {
|
||||
worker_urls: self.worker_urls.clone(),
|
||||
}
|
||||
} else if self.pd_disaggregation {
|
||||
let decode_urls = self.decode.clone();
|
||||
|
||||
@@ -409,8 +445,14 @@ impl CliArgs {
|
||||
}
|
||||
all_urls.extend(decode_urls.clone());
|
||||
}
|
||||
RoutingMode::OpenAI { .. } => {
|
||||
// For connection-mode detection, skip URLs; OpenAI forces HTTP below.
|
||||
}
|
||||
}
|
||||
let connection_mode = Self::determine_connection_mode(&all_urls);
|
||||
let connection_mode = match &mode {
|
||||
RoutingMode::OpenAI { .. } => ConnectionMode::Http,
|
||||
_ => Self::determine_connection_mode(&all_urls),
|
||||
};
|
||||
|
||||
// Build RouterConfig
|
||||
Ok(RouterConfig {
|
||||
@@ -543,16 +585,28 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Print startup info
|
||||
println!("SGLang Router starting...");
|
||||
println!("Host: {}:{}", cli_args.host, cli_args.port);
|
||||
println!(
|
||||
"Mode: {}",
|
||||
if cli_args.enable_igw {
|
||||
"IGW (Inference Gateway)"
|
||||
} else if cli_args.pd_disaggregation {
|
||||
"PD Disaggregated"
|
||||
} else {
|
||||
"Regular"
|
||||
let mode_str = if cli_args.enable_igw {
|
||||
"IGW (Inference Gateway)".to_string()
|
||||
} else if matches!(cli_args.backend, Backend::Openai) {
|
||||
"OpenAI Backend".to_string()
|
||||
} else if cli_args.pd_disaggregation {
|
||||
"PD Disaggregated".to_string()
|
||||
} else {
|
||||
format!("Regular ({})", cli_args.backend)
|
||||
};
|
||||
println!("Mode: {}", mode_str);
|
||||
|
||||
// Warn for runtimes that are parsed but not yet implemented
|
||||
match cli_args.backend {
|
||||
Backend::Vllm | Backend::Trtllm | Backend::Anthropic => {
|
||||
println!(
|
||||
"WARNING: runtime '{}' not implemented yet; falling back to regular routing. \
|
||||
Provide --worker-urls or PD flags as usual.",
|
||||
cli_args.backend
|
||||
);
|
||||
}
|
||||
);
|
||||
Backend::Sglang | Backend::Openai => {}
|
||||
}
|
||||
|
||||
if !cli_args.enable_igw {
|
||||
println!("Policy: {}", cli_args.policy);
|
||||
|
||||
Reference in New Issue
Block a user