Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -218,15 +218,39 @@ async def get_server_info():
|
||||
)
|
||||
prefill_infos = []
|
||||
decode_infos = []
|
||||
all_internal_states = []
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
for server in chain(prefill_servers):
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
prefill_infos.append(await server_info.json())
|
||||
for server in chain(decode_servers):
|
||||
server_info = await session.get(f"{server}/get_server_info")
|
||||
decode_infos.append(await server_info.json())
|
||||
info_json = await server_info.json()
|
||||
decode_infos.append(info_json)
|
||||
# Extract internal_states from decode servers
|
||||
if "internal_states" in info_json:
|
||||
all_internal_states.extend(info_json["internal_states"])
|
||||
|
||||
return {"prefill": prefill_infos, "decode": decode_infos}
|
||||
# Return format expected by bench_one_batch_server.py
|
||||
if all_internal_states:
|
||||
return {
|
||||
"internal_states": all_internal_states,
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
else:
|
||||
# Fallback with dummy data if no internal states found
|
||||
return {
|
||||
"internal_states": [
|
||||
{
|
||||
"last_gen_throughput": 0.0,
|
||||
"avg_spec_accept_length": None,
|
||||
}
|
||||
],
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
}
|
||||
|
||||
|
||||
@app.get("/get_model_info")
|
||||
|
||||
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
|
||||
clap = { version = "4.4", features = ["derive"] }
|
||||
bytes = "1.8.0"
|
||||
rand = "0.8.5"
|
||||
reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
|
||||
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
||||
futures-util = "0.3"
|
||||
serde_json = "1.0"
|
||||
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
||||
@@ -33,6 +33,8 @@ futures = "0.3"
|
||||
# Added for metrics
|
||||
metrics = "0.24.2"
|
||||
metrics-exporter-prometheus = "0.17.0"
|
||||
# Added for request tracing
|
||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||
[profile.release]
|
||||
lto = "thin"
|
||||
codegen-units = 1
|
||||
|
||||
@@ -31,6 +31,13 @@ class RouterArgs:
|
||||
host: str = "127.0.0.1"
|
||||
port: int = 30000
|
||||
|
||||
# PD-specific configuration
|
||||
pd_disaggregated: bool = False # Enable PD disaggregated mode
|
||||
prefill_urls: List[tuple] = dataclasses.field(
|
||||
default_factory=list
|
||||
) # List of (url, bootstrap_port)
|
||||
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||
|
||||
# Routing policy
|
||||
policy: str = "cache_aware"
|
||||
worker_startup_timeout_secs: int = 300
|
||||
@@ -40,7 +47,7 @@ class RouterArgs:
|
||||
balance_rel_threshold: float = 1.0001
|
||||
eviction_interval: int = 60
|
||||
max_tree_size: int = 2**24
|
||||
max_payload_size: int = 4 * 1024 * 1024 # 4MB
|
||||
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
|
||||
verbose: bool = False
|
||||
log_dir: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
@@ -95,8 +102,29 @@ class RouterArgs:
|
||||
f"--{prefix}policy",
|
||||
type=str,
|
||||
default=RouterArgs.policy,
|
||||
choices=["random", "round_robin", "cache_aware"],
|
||||
help="Load balancing policy to use",
|
||||
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
|
||||
)
|
||||
|
||||
# PD-specific arguments
|
||||
parser.add_argument(
|
||||
f"--{prefix}pd-disaggregated",
|
||||
action="store_true",
|
||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}prefill",
|
||||
nargs=2,
|
||||
action="append",
|
||||
metavar=("URL", "BOOTSTRAP_PORT"),
|
||||
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}decode",
|
||||
nargs=1,
|
||||
action="append",
|
||||
metavar=("URL",),
|
||||
help="Decode server URL. Can be specified multiple times.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}worker-startup-timeout-secs",
|
||||
@@ -205,11 +233,19 @@ class RouterArgs:
|
||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||
"""
|
||||
prefix = "router_" if use_router_prefix else ""
|
||||
worker_urls = args.worker_urls if args.worker_urls is not None else []
|
||||
worker_urls = getattr(args, "worker_urls", [])
|
||||
|
||||
# Parse PD URLs
|
||||
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
|
||||
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
|
||||
|
||||
return cls(
|
||||
worker_urls=worker_urls,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
policy=getattr(args, f"{prefix}policy"),
|
||||
worker_startup_timeout_secs=getattr(
|
||||
args, f"{prefix}worker_startup_timeout_secs"
|
||||
@@ -247,6 +283,46 @@ class RouterArgs:
|
||||
selector[key] = value
|
||||
return selector
|
||||
|
||||
@staticmethod
|
||||
def _parse_prefill_urls(prefill_list):
|
||||
"""Parse prefill URLs from --prefill arguments.
|
||||
|
||||
Format: --prefill URL BOOTSTRAP_PORT
|
||||
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
|
||||
"""
|
||||
if not prefill_list:
|
||||
return []
|
||||
|
||||
prefill_urls = []
|
||||
for url, bootstrap_port_str in prefill_list:
|
||||
# Handle 'none' as None
|
||||
if bootstrap_port_str.lower() == "none":
|
||||
bootstrap_port = None
|
||||
else:
|
||||
try:
|
||||
bootstrap_port = int(bootstrap_port_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||
)
|
||||
|
||||
prefill_urls.append((url, bootstrap_port))
|
||||
|
||||
return prefill_urls
|
||||
|
||||
@staticmethod
|
||||
def _parse_decode_urls(decode_list):
|
||||
"""Parse decode URLs from --decode arguments.
|
||||
|
||||
Format: --decode URL
|
||||
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||
"""
|
||||
if not decode_list:
|
||||
return []
|
||||
|
||||
# decode_list is a list of single-element lists due to nargs=1
|
||||
return [url[0] for url in decode_list]
|
||||
|
||||
|
||||
def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"""Convert policy string to PolicyType enum."""
|
||||
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
|
||||
"random": PolicyType.Random,
|
||||
"round_robin": PolicyType.RoundRobin,
|
||||
"cache_aware": PolicyType.CacheAware,
|
||||
"power_of_two": PolicyType.PowerOfTwo,
|
||||
}
|
||||
return policy_map[policy_str]
|
||||
|
||||
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
else:
|
||||
router_args = args
|
||||
|
||||
# Validate configuration based on mode
|
||||
if router_args.pd_disaggregated:
|
||||
# Validate PD configuration
|
||||
if not router_args.prefill_urls:
|
||||
raise ValueError("PD disaggregated mode requires --prefill")
|
||||
if not router_args.decode_urls:
|
||||
raise ValueError("PD disaggregated mode requires --decode")
|
||||
|
||||
# Create router with unified constructor
|
||||
router = Router(
|
||||
worker_urls=router_args.worker_urls,
|
||||
worker_urls=(
|
||||
router_args.worker_urls if not router_args.pd_disaggregated else []
|
||||
),
|
||||
host=router_args.host,
|
||||
port=router_args.port,
|
||||
policy=policy_from_str(router_args.policy),
|
||||
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||
prometheus_port=router_args.prometheus_port,
|
||||
prometheus_host=router_args.prometheus_host,
|
||||
pd_disaggregated=router_args.pd_disaggregated,
|
||||
prefill_urls=(
|
||||
router_args.prefill_urls if router_args.pd_disaggregated else None
|
||||
),
|
||||
decode_urls=(
|
||||
router_args.decode_urls if router_args.pd_disaggregated else None
|
||||
),
|
||||
)
|
||||
|
||||
router.start()
|
||||
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
|
||||
multi-node setups or when you want to start workers and router separately.
|
||||
|
||||
Examples:
|
||||
# Regular mode
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
|
||||
|
||||
# PD disaggregated mode
|
||||
python -m sglang_router.launch_router --pd-disaggregated \\
|
||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||
--policy cache_aware
|
||||
|
||||
""",
|
||||
formatter_class=CustomHelpFormatter,
|
||||
|
||||
@@ -15,6 +15,7 @@ class Router:
|
||||
- PolicyType.Random: Randomly select workers
|
||||
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
||||
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
||||
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
|
||||
host: Host address to bind the router server. Default: '127.0.0.1'
|
||||
port: Port number to bind the router server. Default: 3001
|
||||
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
|
||||
@@ -28,7 +29,7 @@ class Router:
|
||||
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
|
||||
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
|
||||
routing. Default: 60
|
||||
max_payload_size: Maximum payload size in bytes. Default: 4MB
|
||||
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||
verbose: Enable verbose logging. Default: False
|
||||
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
||||
@@ -42,6 +43,9 @@ class Router:
|
||||
watches pods across all namespaces (requires cluster-wide permissions). Default: None
|
||||
prometheus_port: Port to expose Prometheus metrics. Default: None
|
||||
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
|
||||
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
||||
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
||||
decode_urls: List of URLs for decode servers (PD mode only)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -57,7 +61,7 @@ class Router:
|
||||
balance_rel_threshold: float = 1.0001,
|
||||
eviction_interval_secs: int = 60,
|
||||
max_tree_size: int = 2**24,
|
||||
max_payload_size: int = 4 * 1024 * 1024, # 4MB
|
||||
max_payload_size: int = 256 * 1024 * 1024, # 256MB
|
||||
verbose: bool = False,
|
||||
log_dir: Optional[str] = None,
|
||||
service_discovery: bool = False,
|
||||
@@ -66,6 +70,9 @@ class Router:
|
||||
service_discovery_namespace: Optional[str] = None,
|
||||
prometheus_port: Optional[int] = None,
|
||||
prometheus_host: Optional[str] = None,
|
||||
pd_disaggregated: bool = False,
|
||||
prefill_urls: Optional[List[tuple]] = None,
|
||||
decode_urls: Optional[List[str]] = None,
|
||||
):
|
||||
if selector is None:
|
||||
selector = {}
|
||||
@@ -91,6 +98,9 @@ class Router:
|
||||
service_discovery_namespace=service_discovery_namespace,
|
||||
prometheus_port=prometheus_port,
|
||||
prometheus_host=prometheus_host,
|
||||
pd_disaggregated=pd_disaggregated,
|
||||
prefill_urls=prefill_urls,
|
||||
decode_urls=decode_urls,
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
|
||||
@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
balance_rel_threshold=1.0001,
|
||||
eviction_interval=60,
|
||||
max_tree_size=2**24,
|
||||
max_payload_size=4 * 1024 * 1024, # 4MB
|
||||
max_payload_size=256 * 1024 * 1024, # 256MB
|
||||
verbose=False,
|
||||
log_dir=None,
|
||||
service_discovery=False,
|
||||
selector=None,
|
||||
service_discovery_port=80,
|
||||
service_discovery_namespace=None,
|
||||
prometheus_port=None,
|
||||
prometheus_host=None,
|
||||
# PD-specific attributes
|
||||
pd_disaggregated=False,
|
||||
prefill=None,
|
||||
decode=None,
|
||||
# Keep worker_urls for regular mode
|
||||
worker_urls=[],
|
||||
)
|
||||
|
||||
def create_router_args(self, **kwargs):
|
||||
@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
|
||||
def test_launch_router_with_empty_worker_urls(self):
|
||||
args = self.create_router_args(worker_urls=[])
|
||||
self.run_router_process(args)
|
||||
self.run_router_process(args) # Expected error
|
||||
|
||||
def test_launch_router_with_service_discovery(self):
|
||||
# Test router startup with service discovery enabled but no selectors
|
||||
@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
)
|
||||
self.run_router_process(args)
|
||||
|
||||
def test_launch_router_pd_mode_basic(self):
|
||||
"""Test basic PD router functionality without actually starting servers."""
|
||||
# This test just verifies the PD router can be created and configured
|
||||
# without actually starting it (which would require real prefill/decode servers)
|
||||
from sglang_router import Router
|
||||
from sglang_router.launch_router import RouterArgs
|
||||
from sglang_router_rs import PolicyType
|
||||
|
||||
# Test RouterArgs parsing for PD mode
|
||||
# Simulate the parsed args structure from argparse with action="append"
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=True,
|
||||
policy="power_of_two", # PowerOfTwo is only valid in PD mode
|
||||
prefill=[
|
||||
["http://prefill1:8080", "9000"],
|
||||
["http://prefill2:8080", "none"],
|
||||
],
|
||||
decode=[
|
||||
["http://decode1:8081"],
|
||||
["http://decode2:8081"],
|
||||
],
|
||||
worker_urls=[], # Empty for PD mode
|
||||
)
|
||||
|
||||
router_args = RouterArgs.from_cli_args(args)
|
||||
self.assertTrue(router_args.pd_disaggregated)
|
||||
self.assertEqual(router_args.policy, "power_of_two")
|
||||
self.assertEqual(len(router_args.prefill_urls), 2)
|
||||
self.assertEqual(len(router_args.decode_urls), 2)
|
||||
|
||||
# Verify the parsed URLs and bootstrap ports
|
||||
self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
|
||||
self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
|
||||
self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
|
||||
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
|
||||
|
||||
# Test Router creation in PD mode
|
||||
router = Router(
|
||||
worker_urls=[], # Empty for PD mode
|
||||
pd_disaggregated=True,
|
||||
prefill_urls=[
|
||||
("http://prefill1:8080", 9000),
|
||||
("http://prefill2:8080", None),
|
||||
],
|
||||
decode_urls=["http://decode1:8081", "http://decode2:8081"],
|
||||
policy=PolicyType.CacheAware,
|
||||
host="127.0.0.1",
|
||||
port=3001,
|
||||
)
|
||||
self.assertIsNotNone(router)
|
||||
|
||||
def test_policy_validation(self):
|
||||
"""Test that policy validation works correctly for PD and regular modes."""
|
||||
from sglang_router.launch_router import RouterArgs, launch_router
|
||||
|
||||
# Test 1: PowerOfTwo is only valid in PD mode
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=False,
|
||||
policy="power_of_two",
|
||||
worker_urls=["http://localhost:8000"],
|
||||
)
|
||||
|
||||
# Should raise error
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
launch_router(args)
|
||||
self.assertIn(
|
||||
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||
str(cm.exception),
|
||||
)
|
||||
|
||||
# Test 2: RoundRobin is not valid in PD mode
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=True,
|
||||
policy="round_robin",
|
||||
prefill=[["http://prefill1:8080", "9000"]],
|
||||
decode=[["http://decode1:8081"]],
|
||||
worker_urls=[],
|
||||
)
|
||||
|
||||
# Should raise error
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
launch_router(args)
|
||||
self.assertIn(
|
||||
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||
str(cm.exception),
|
||||
)
|
||||
|
||||
# Test 3: Valid combinations should not raise errors
|
||||
# Regular mode with RoundRobin
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=False,
|
||||
policy="round_robin",
|
||||
worker_urls=["http://localhost:8000"],
|
||||
)
|
||||
# This should not raise (though it may fail to connect)
|
||||
|
||||
# PD mode with PowerOfTwo
|
||||
args = self.create_router_args(
|
||||
pd_disaggregated=True,
|
||||
policy="power_of_two",
|
||||
prefill=[["http://prefill1:8080", "9000"]],
|
||||
decode=[["http://decode1:8081"]],
|
||||
worker_urls=[],
|
||||
)
|
||||
# This should not raise (though it may fail to connect)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
use pyo3::prelude::*;
|
||||
pub mod logging;
|
||||
use std::collections::HashMap;
|
||||
pub mod openai_api_types;
|
||||
pub mod pd_router;
|
||||
pub mod pd_types;
|
||||
pub mod prometheus;
|
||||
pub mod request_adapter;
|
||||
pub mod router;
|
||||
pub mod server;
|
||||
pub mod service_discovery;
|
||||
@@ -14,6 +18,7 @@ pub enum PolicyType {
|
||||
Random,
|
||||
RoundRobin,
|
||||
CacheAware,
|
||||
PowerOfTwo, // Moved from PD-specific, now shared
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
@@ -39,6 +44,12 @@ struct Router {
|
||||
service_discovery_namespace: Option<String>,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
// PD mode flag
|
||||
pd_disaggregated: bool,
|
||||
// PD-specific fields (only used when pd_disaggregated is true)
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -56,7 +67,7 @@ impl Router {
|
||||
balance_rel_threshold = 1.0001,
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24),
|
||||
max_payload_size = 4 * 1024 * 1024,
|
||||
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||
verbose = false,
|
||||
log_dir = None,
|
||||
service_discovery = false,
|
||||
@@ -64,7 +75,11 @@ impl Router {
|
||||
service_discovery_port = 80,
|
||||
service_discovery_namespace = None,
|
||||
prometheus_port = None,
|
||||
prometheus_host = None
|
||||
prometheus_host = None,
|
||||
request_timeout_secs = 600, // Add configurable request timeout
|
||||
pd_disaggregated = false, // New flag for PD mode
|
||||
prefill_urls = None,
|
||||
decode_urls = None
|
||||
))]
|
||||
fn new(
|
||||
worker_urls: Vec<String>,
|
||||
@@ -87,6 +102,10 @@ impl Router {
|
||||
service_discovery_namespace: Option<String>,
|
||||
prometheus_port: Option<u16>,
|
||||
prometheus_host: Option<String>,
|
||||
request_timeout_secs: u64,
|
||||
pd_disaggregated: bool,
|
||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||
decode_urls: Option<Vec<String>>,
|
||||
) -> PyResult<Self> {
|
||||
Ok(Router {
|
||||
host,
|
||||
@@ -109,28 +128,75 @@ impl Router {
|
||||
service_discovery_namespace,
|
||||
prometheus_port,
|
||||
prometheus_host,
|
||||
request_timeout_secs,
|
||||
pd_disaggregated,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
})
|
||||
}
|
||||
|
||||
fn start(&self) -> PyResult<()> {
|
||||
let policy_config = match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||
let policy_config = if self.pd_disaggregated {
|
||||
// PD mode - map PolicyType to PDSelectionPolicy
|
||||
let pd_selection_policy = match &self.policy {
|
||||
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
||||
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
|
||||
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
},
|
||||
PolicyType::RoundRobin => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
|
||||
pyo3::exceptions::PyValueError::new_err(
|
||||
"PD disaggregated mode requires prefill_urls",
|
||||
)
|
||||
})?;
|
||||
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
|
||||
pyo3::exceptions::PyValueError::new_err(
|
||||
"PD disaggregated mode requires decode_urls",
|
||||
)
|
||||
})?;
|
||||
|
||||
router::PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy: pd_selection_policy,
|
||||
prefill_urls: prefill_urls.clone(),
|
||||
decode_urls: decode_urls.clone(),
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
}
|
||||
} else {
|
||||
// Regular mode
|
||||
match &self.policy {
|
||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
},
|
||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||
timeout_secs: self.worker_startup_timeout_secs,
|
||||
interval_secs: self.worker_startup_check_interval,
|
||||
cache_threshold: self.cache_threshold,
|
||||
balance_abs_threshold: self.balance_abs_threshold,
|
||||
balance_rel_threshold: self.balance_rel_threshold,
|
||||
eviction_interval_secs: self.eviction_interval_secs,
|
||||
max_tree_size: self.max_tree_size,
|
||||
},
|
||||
PolicyType::PowerOfTwo => {
|
||||
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Create service discovery config if enabled
|
||||
@@ -166,6 +232,7 @@ impl Router {
|
||||
log_dir: self.log_dir.clone(),
|
||||
service_discovery_config,
|
||||
prometheus_config,
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||
|
||||
704
sgl-router/src/openai_api_types.rs
Normal file
704
sgl-router/src/openai_api_types.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
// OpenAI-compatible API types for text generation
|
||||
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
|
||||
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Common trait for all generation requests
|
||||
pub trait GenerationRequest: Send + Sync {
|
||||
/// Check if the request is for streaming
|
||||
fn is_stream(&self) -> bool;
|
||||
|
||||
/// Get the model name if specified
|
||||
fn get_model(&self) -> Option<&str>;
|
||||
|
||||
/// Extract text content for routing decisions
|
||||
fn extract_text_for_routing(&self) -> String;
|
||||
}
|
||||
|
||||
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct CompletionRequest {
|
||||
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
|
||||
pub model: String,
|
||||
|
||||
/// The prompt(s) to generate completions for
|
||||
pub prompt: StringOrArray,
|
||||
|
||||
/// The suffix that comes after a completion of inserted text
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suffix: Option<String>,
|
||||
|
||||
/// The maximum number of tokens to generate
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// An alternative to sampling with temperature (nucleus sampling)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// How many completions to generate for each prompt
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// Whether to stream back partial progress
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Include the log probabilities on the logprobs most likely tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logprobs: Option<u32>,
|
||||
|
||||
/// Echo back the prompt in addition to the completion
|
||||
#[serde(default)]
|
||||
pub echo: bool,
|
||||
|
||||
/// Up to 4 sequences where the API will stop generating further tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// Generates best_of completions server-side and returns the "best"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub best_of: Option<u32>,
|
||||
|
||||
/// Modify the likelihood of specified tokens appearing in the completion
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logit_bias: Option<HashMap<String, f32>>,
|
||||
|
||||
/// A unique identifier representing your end-user
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// If specified, our system will make a best effort to sample deterministically
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
}
|
||||
|
||||
impl GenerationRequest for CompletionRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
match &self.prompt {
|
||||
StringOrArray::String(s) => s.clone(),
|
||||
StringOrArray::Array(v) => v.join(" "),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Chat Completions API (v1/chat/completions) =============
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
/// ID of the model to use
|
||||
pub model: String,
|
||||
|
||||
/// A list of messages comprising the conversation so far
|
||||
pub messages: Vec<ChatMessage>,
|
||||
|
||||
/// What sampling temperature to use, between 0 and 2
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// An alternative to sampling with temperature
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// How many chat completion choices to generate for each input message
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
|
||||
/// If set, partial message deltas will be sent
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
|
||||
/// Up to 4 sequences where the API will stop generating further tokens
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<StringOrArray>,
|
||||
|
||||
/// The maximum number of tokens to generate
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// An upper bound for the number of tokens that can be generated for a completion
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_completion_tokens: Option<u32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
|
||||
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
|
||||
/// Modify the likelihood of specified tokens appearing in the completion
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub logit_bias: Option<HashMap<String, i32>>,
|
||||
|
||||
/// A unique identifier representing your end-user
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub user: Option<String>,
|
||||
|
||||
/// If specified, our system will make a best effort to sample deterministically
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<i64>,
|
||||
|
||||
/// Whether to return log probabilities of the output tokens
|
||||
#[serde(default)]
|
||||
pub logprobs: bool,
|
||||
|
||||
/// An integer between 0 and 20 specifying the number of most likely tokens to return
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_logprobs: Option<u32>,
|
||||
|
||||
/// An object specifying the format that the model must output
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
|
||||
/// A list of tools the model may call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<Tool>>,
|
||||
|
||||
/// Controls which (if any) tool is called by the model
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<ToolChoice>,
|
||||
|
||||
/// Whether to enable parallel function calling during tool use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
|
||||
/// Deprecated: use tools instead
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub functions: Option<Vec<Function>>,
|
||||
|
||||
/// Deprecated: use tool_choice instead
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_call: Option<FunctionCall>,
|
||||
}
|
||||
|
||||
/// A chat conversation message.
///
/// `#[serde(untagged)]`: serde picks the first variant whose field set matches
/// the incoming JSON; the `role` string itself is not used as a discriminator.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ChatMessage {
    /// System instruction message.
    System {
        role: String, // "system"
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// End-user message; content may be plain text or multimodal parts.
    User {
        role: String, // "user"
        content: UserMessageContent,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
    /// Model reply; `content` may be absent when tool/function calls are present.
    Assistant {
        role: String, // "assistant"
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<ToolCall>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        function_call: Option<FunctionCallResponse>,
    },
    /// Result of a tool invocation, linked back by `tool_call_id`.
    Tool {
        role: String, // "tool"
        content: String,
        tool_call_id: String,
    },
    /// Legacy function-result message (predecessor of `Tool`).
    Function {
        role: String, // "function"
        content: String,
        name: String,
    },
}
|
||||
|
||||
/// Content of a user message: a bare string or a list of multimodal parts.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum UserMessageContent {
    Text(String),
    Parts(Vec<ContentPart>),
}
|
||||
|
||||
/// One element of a multimodal user message, discriminated by its `type` field.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
|
||||
|
||||
/// Image reference inside a multimodal message part.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageUrl {
    pub url: String,
    /// Requested processing fidelity.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub detail: Option<String>, // "auto", "low", or "high"
}
|
||||
|
||||
/// Output format constraint for the model, discriminated by `type`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ResponseFormat {
    #[serde(rename = "text")]
    Text,
    #[serde(rename = "json_object")]
    JsonObject,
    /// Structured output constrained by an explicit JSON schema.
    #[serde(rename = "json_schema")]
    JsonSchema { json_schema: JsonSchemaFormat },
}
|
||||
|
||||
/// Named JSON schema used with `ResponseFormat::JsonSchema`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct JsonSchemaFormat {
    pub name: String,
    /// The schema itself, kept as raw JSON.
    pub schema: Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strict: Option<bool>,
}
|
||||
|
||||
/// A tool the model may call; only function tools are represented here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: Function,
}
|
||||
|
||||
/// Function declaration exposed to the model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Function {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub parameters: Value, // JSON Schema
}
|
||||
|
||||
/// Which (if any) tool the model must call.
///
/// NOTE(review): with `#[serde(untagged)]`, the unit variants (de)serialize as
/// JSON `null`, not the OpenAI strings "none"/"auto"/"required" — confirm the
/// intended wire format.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum ToolChoice {
    None,
    Auto,
    Required,
    /// Force a call to one named function.
    Function {
        #[serde(rename = "type")]
        tool_type: String, // "function"
        function: FunctionChoice,
    },
}
|
||||
|
||||
/// Name of the function selected by `ToolChoice::Function`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionChoice {
    pub name: String,
}
|
||||
|
||||
/// A tool invocation emitted by the model in an assistant message.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: FunctionCallResponse,
}
|
||||
|
||||
/// Deprecated `function_call` control (predecessor of `ToolChoice`).
///
/// NOTE(review): untagged unit variants (de)serialize as JSON `null` — confirm
/// the intended wire format for "none"/"auto".
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum FunctionCall {
    None,
    Auto,
    Function { name: String },
}
|
||||
|
||||
/// Function name plus raw JSON-encoded arguments emitted by the model.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallResponse {
    pub name: String,
    pub arguments: String, // JSON string
}
|
||||
|
||||
impl GenerationRequest for ChatCompletionRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
Some(&self.model)
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Extract text from messages for routing decisions
|
||||
self.messages
|
||||
.iter()
|
||||
.filter_map(|msg| match msg {
|
||||
ChatMessage::System { content, .. } => Some(content.clone()),
|
||||
ChatMessage::User { content, .. } => match content {
|
||||
UserMessageContent::Text(text) => Some(text.clone()),
|
||||
UserMessageContent::Parts(parts) => {
|
||||
let texts: Vec<String> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
ContentPart::Text { text } => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect();
|
||||
Some(texts.join(" "))
|
||||
}
|
||||
},
|
||||
ChatMessage::Assistant { content, .. } => content.clone(),
|
||||
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
||||
ChatMessage::Function { content, .. } => Some(content.clone()),
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join(" ")
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Generate API (/generate) =============

/// Request body for the SGLang-native `/generate` endpoint.
/// Callers are expected to provide one of `prompt`, `text`, or `input_ids`;
/// when several are set, downstream code consumes them in priority order
/// text, prompt, input_ids.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
    /// The prompt to generate from (OpenAI style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<StringOrArray>,

    /// Text input - SGLang native format
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Input IDs for tokenized input
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_ids: Option<InputIds>,

    /// Generation parameters
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub parameters: Option<GenerateParameters>,

    /// Sampling parameters (sglang style)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sampling_params: Option<SamplingParams>,

    /// Whether to stream the response
    #[serde(default)]
    pub stream: bool,

    /// Whether to return logprobs
    #[serde(default)]
    pub return_logprob: bool,
}
|
||||
|
||||
/// Pre-tokenized input: one token sequence or a batch of sequences.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum InputIds {
    Single(Vec<i32>),
    Batch(Vec<Vec<i32>>),
}
|
||||
|
||||
/// TGI-style generation parameters; every field is optional and omitted from
/// the serialized JSON when unset.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct GenerateParameters {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub best_of: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decoder_input_details: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub do_sample: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_new_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub return_full_text: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub seed: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub truncate: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub typical_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub watermark: Option<bool>,
}
|
||||
|
||||
/// SGLang-native sampling parameters; all optional, omitted from JSON when unset.
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct SamplingParams {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_new_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<i32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub repetition_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<StringOrArray>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ignore_eos: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub skip_special_tokens: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<String>,
}
|
||||
|
||||
impl GenerationRequest for GenerateRequest {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_model(&self) -> Option<&str> {
|
||||
// Generate requests typically don't have a model field
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_text_for_routing(&self) -> String {
|
||||
// Check fields in priority order: text, prompt, inputs
|
||||
if let Some(ref text) = self.text {
|
||||
return text.clone();
|
||||
}
|
||||
|
||||
if let Some(ref prompt) = self.prompt {
|
||||
return match prompt {
|
||||
StringOrArray::String(s) => s.clone(),
|
||||
StringOrArray::Array(v) => v.join(" "),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(ref input_ids) = self.input_ids {
|
||||
return match input_ids {
|
||||
InputIds::Single(ids) => ids
|
||||
.iter()
|
||||
.map(|&id| id.to_string())
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
InputIds::Batch(batches) => batches
|
||||
.iter()
|
||||
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
|
||||
.collect::<Vec<String>>()
|
||||
.join(" "),
|
||||
};
|
||||
}
|
||||
|
||||
// No text input found
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Helper Types =============
|
||||
|
||||
/// A single string or a list of strings (used for prompts and stop sequences).
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum StringOrArray {
    String(String),
    Array(Vec<String>),
}
|
||||
|
||||
// ============= Response Types =============
|
||||
|
||||
/// Response body for `/v1/completions`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionResponse {
    pub id: String,
    pub object: String, // "text_completion"
    /// Unix timestamp (seconds).
    pub created: u64,
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
|
||||
|
||||
/// One generated completion alternative.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionChoice {
    pub text: String,
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
}
|
||||
|
||||
/// Token-level log-probability data in the completions (non-chat) layout;
/// the four vectors are parallel, one entry per generated token.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LogProbs {
    pub tokens: Vec<String>,
    pub token_logprobs: Vec<Option<f32>>,
    pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
    pub text_offset: Vec<u32>,
}
|
||||
|
||||
/// Response body for `/v1/chat/completions`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String, // "chat.completion"
    /// Unix timestamp (seconds).
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
|
||||
|
||||
/// One generated chat alternative.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
}
|
||||
|
||||
/// Log-probability data in the chat layout (per-token entries under `content`).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbs {
    pub content: Option<Vec<ChatLogProbsContent>>,
}
|
||||
|
||||
/// Log-probability record for one generated token.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatLogProbsContent {
    pub token: String,
    pub logprob: f32,
    /// Raw UTF-8 bytes of the token, when available.
    pub bytes: Option<Vec<u8>>,
    pub top_logprobs: Vec<TopLogProb>,
}
|
||||
|
||||
/// One of the most-likely alternative tokens at a position.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TopLogProb {
    pub token: String,
    pub logprob: f32,
    pub bytes: Option<Vec<u8>>,
}
|
||||
|
||||
/// Token accounting for a request.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens_details: Option<CompletionTokensDetails>,
}
|
||||
|
||||
/// Breakdown of completion tokens (e.g. hidden reasoning tokens).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionTokensDetails {
    pub reasoning_tokens: Option<u32>,
}
|
||||
|
||||
// ============= Streaming Response Types =============
|
||||
|
||||
/// One SSE chunk of a streaming `/v1/completions` response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamResponse {
    pub id: String,
    pub object: String, // "text_completion"
    pub created: u64,
    pub choices: Vec<CompletionStreamChoice>,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
}
|
||||
|
||||
/// Incremental text for one choice inside a completion stream chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct CompletionStreamChoice {
    pub text: String,
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<LogProbs>,
    pub finish_reason: Option<String>,
}
|
||||
|
||||
/// One SSE chunk of a streaming `/v1/chat/completions` response.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatCompletionStreamResponse {
    pub id: String,
    pub object: String, // "chat.completion.chunk"
    pub created: u64,
    pub model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_fingerprint: Option<String>,
    pub choices: Vec<ChatStreamChoice>,
    /// Present only on the final chunk when usage reporting is enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
|
||||
|
||||
/// Incremental delta for one choice inside a chat stream chunk.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatStreamChoice {
    pub index: u32,
    pub delta: ChatMessageDelta,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<ChatLogProbs>,
    pub finish_reason: Option<String>,
}
|
||||
|
||||
/// Partial assistant message carried by a stream chunk; only the fields that
/// changed in this chunk are populated.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ChatMessageDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function_call: Option<FunctionCallDelta>,
}
|
||||
|
||||
/// Partial tool call inside a stream delta; `index` identifies which call
/// in the assistant message this fragment belongs to.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ToolCallDelta {
    pub index: u32,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(rename = "type")]
    pub tool_type: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<FunctionCallDelta>,
}
|
||||
|
||||
/// Partial function call: the arguments string arrives in fragments across chunks.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct FunctionCallDelta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}
|
||||
|
||||
// ============= Error Response Types =============
|
||||
|
||||
/// OpenAI-style error envelope.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorResponse {
    pub error: ErrorDetail,
}
|
||||
|
||||
/// Details of an API error.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ErrorDetail {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    /// Offending request parameter, when applicable.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub param: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
|
||||
1002
sgl-router/src/pd_router.rs
Normal file
1002
sgl-router/src/pd_router.rs
Normal file
File diff suppressed because it is too large
Load Diff
245
sgl-router/src/pd_types.rs
Normal file
245
sgl-router/src/pd_types.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
// Essential PDLB types extracted for PD routing
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Role of an engine in prefill/decode disaggregation.
#[derive(Debug, Clone)]
pub enum EngineType {
    Prefill,
    Decode,
}
|
||||
|
||||
/// Connection details for one prefill or decode engine.
#[derive(Debug, Clone)]
pub struct EngineInfo {
    pub engine_type: EngineType,
    // Base URL including scheme, e.g. "http://host:port".
    pub url: String,
    // Set by `new_prefill`; always `None` for decode engines.
    pub bootstrap_port: Option<u16>,
}
|
||||
|
||||
impl EngineInfo {
|
||||
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Prefill,
|
||||
url,
|
||||
bootstrap_port,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_decode(url: String) -> Self {
|
||||
EngineInfo {
|
||||
engine_type: EngineType::Decode,
|
||||
url,
|
||||
bootstrap_port: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn api_path(&self, api_path: &str) -> String {
|
||||
if api_path.starts_with("/") {
|
||||
format!("{}{}", self.url, api_path)
|
||||
} else {
|
||||
format!("{}/{}", self.url, api_path)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_hostname(&self) -> String {
|
||||
// Simple hostname extraction without external dependencies
|
||||
let url = self
|
||||
.url
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://");
|
||||
url.split(':').next().unwrap_or("localhost").to_string()
|
||||
}
|
||||
}
|
||||
|
||||
// PD-specific routing policies
/// Worker selection strategy for prefill/decode routing.
#[derive(Debug, Clone, PartialEq)]
pub enum PDSelectionPolicy {
    Random,
    PowerOfTwo,
    /// Cache-aware selection with load-balance guard thresholds.
    CacheAware {
        cache_threshold: f32,
        balance_abs_threshold: usize,
        balance_rel_threshold: f32,
    },
}
|
||||
// Bootstrap types from PDLB
/// Either a single value or a batch of values; untagged in JSON.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum SingleOrBatch<T> {
    Single(T),
    Batch(Vec<T>),
}
|
||||
|
||||
// Convenience aliases for the single-or-batch request fields.
pub type InputIds = SingleOrBatch<Vec<i32>>;
pub type InputText = SingleOrBatch<String>;
pub type BootstrapHost = SingleOrBatch<String>;
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||
|
||||
// Bootstrap trait for request handling
|
||||
pub trait Bootstrap: Send + Sync {
|
||||
fn is_stream(&self) -> bool;
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String>;
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
);
|
||||
|
||||
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
|
||||
let batch_size = self.get_batch_size()?;
|
||||
if let Some(batch_size) = batch_size {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
||||
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
||||
// Use high-quality random numbers to minimize collision risk
|
||||
BootstrapRoom::Batch(
|
||||
(0..batch_size)
|
||||
.map(|_| {
|
||||
// Combine multiple sources of randomness for better distribution
|
||||
let r1 = rand::random::<u64>();
|
||||
let r2 = rand::random::<u64>();
|
||||
r1.wrapping_add(r2.rotate_left(32))
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
);
|
||||
} else {
|
||||
self.set_bootstrap_info(
|
||||
BootstrapHost::Single(prefill_info.get_hostname()),
|
||||
BootstrapPort::Single(prefill_info.bootstrap_port),
|
||||
BootstrapRoom::Single({
|
||||
// Use high-quality random number for single requests too
|
||||
let r1 = rand::random::<u64>();
|
||||
let r2 = rand::random::<u64>();
|
||||
r1.wrapping_add(r2.rotate_left(32))
|
||||
}),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Request types
/// Generate request as routed between prefill and decode engines.
/// Fields not modeled here are preserved verbatim in `other` via `#[serde(flatten)]`.
#[derive(Debug, Deserialize, Serialize)]
pub struct GenerateReqInput {
    pub text: Option<InputText>,
    pub input_ids: Option<InputIds>,
    #[serde(default)]
    pub stream: bool,
    // Bootstrap coordinates injected by Bootstrap::add_bootstrap_info.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,

    #[serde(flatten)]
    pub other: Value,
}
|
||||
|
||||
impl GenerateReqInput {
|
||||
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
if self.text.is_some() && self.input_ids.is_some() {
|
||||
return Err("Both text and input_ids are present in the request".to_string());
|
||||
}
|
||||
|
||||
// Check text batch
|
||||
if let Some(InputText::Batch(texts)) = &self.text {
|
||||
if texts.is_empty() {
|
||||
return Err("Batch text array is empty".to_string());
|
||||
}
|
||||
if texts.len() > 10000 {
|
||||
// Reasonable limit for production
|
||||
return Err(format!(
|
||||
"Batch size {} exceeds maximum allowed (10000)",
|
||||
texts.len()
|
||||
));
|
||||
}
|
||||
return Ok(Some(texts.len()));
|
||||
}
|
||||
|
||||
// Check input_ids batch
|
||||
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
||||
if ids.is_empty() {
|
||||
return Err("Batch input_ids array is empty".to_string());
|
||||
}
|
||||
if ids.len() > 10000 {
|
||||
// Reasonable limit for production
|
||||
return Err(format!(
|
||||
"Batch size {} exceeds maximum allowed (10000)",
|
||||
ids.len()
|
||||
));
|
||||
}
|
||||
// Validate each sequence is not empty
|
||||
for (i, seq) in ids.iter().enumerate() {
|
||||
if seq.is_empty() {
|
||||
return Err(format!("Input sequence at index {} is empty", i));
|
||||
}
|
||||
}
|
||||
return Ok(Some(ids.len()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl Bootstrap for GenerateReqInput {
    fn is_stream(&self) -> bool {
        self.stream
    }

    fn get_batch_size(&self) -> Result<Option<usize>, String> {
        // Delegates to the inherent method of the same name; inherent methods
        // take precedence over trait methods, so this does not recurse.
        self.get_batch_size()
    }

    // Store the bootstrap coordinates assigned by the router.
    fn set_bootstrap_info(
        &mut self,
        bootstrap_host: BootstrapHost,
        bootstrap_port: BootstrapPort,
        bootstrap_room: BootstrapRoom,
    ) {
        self.bootstrap_host = Some(bootstrap_host);
        self.bootstrap_port = Some(bootstrap_port);
        self.bootstrap_room = Some(bootstrap_room);
    }
}
|
||||
|
||||
/// Chat completion request as routed between prefill and decode engines.
/// All OpenAI fields other than `stream` travel untyped in `other`.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    #[serde(default)]
    pub stream: bool,
    // Bootstrap coordinates injected by Bootstrap::add_bootstrap_info.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,

    #[serde(flatten)]
    pub other: Value,
}
|
||||
|
||||
impl Bootstrap for ChatReqInput {
|
||||
fn is_stream(&self) -> bool {
|
||||
self.stream
|
||||
}
|
||||
|
||||
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||
// Check if 'n' parameter is present and > 1
|
||||
if let Some(n_value) = self.other.get("n") {
|
||||
if let Some(n) = n_value.as_u64() {
|
||||
if n > 1 {
|
||||
return Ok(Some(n as usize));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn set_bootstrap_info(
|
||||
&mut self,
|
||||
bootstrap_host: BootstrapHost,
|
||||
bootstrap_port: BootstrapPort,
|
||||
bootstrap_room: BootstrapRoom,
|
||||
) {
|
||||
self.bootstrap_host = Some(bootstrap_host);
|
||||
self.bootstrap_port = Some(bootstrap_port);
|
||||
self.bootstrap_room = Some(bootstrap_room);
|
||||
}
|
||||
}
|
||||
264
sgl-router/src/request_adapter.rs
Normal file
264
sgl-router/src/request_adapter.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
// Request adapter to bridge OpenAI API types with PD routing requirements
|
||||
|
||||
use crate::openai_api_types::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
|
||||
};
|
||||
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Adapter trait to convert OpenAI requests to PD-compatible requests
pub trait ToPdRequest {
    /// The PD request type produced by the conversion.
    type Output: Bootstrap;
    /// Consume the OpenAI-style request and produce its PD equivalent.
    fn to_pd_request(self) -> Self::Output;
}
|
||||
|
||||
// Helper macro to insert optional fields into a map.
// Each `field => key` pair is inserted only when `field` is `Some`;
// serialization failures degrade to `Value::Null` rather than erroring.
macro_rules! insert_if_some {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            if let Some(value) = $field {
                $map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
            }
        )*
    };
}
|
||||
|
||||
// Helper macro for simple value insertions.
// Each `field => key` pair is inserted unconditionally via `Into<Value>`.
macro_rules! insert_value {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            $map.insert($key.to_string(), $field.into());
        )*
    };
}
|
||||
|
||||
// ============= Generate Request Adapter =============

impl ToPdRequest for GenerateRequest {
    type Output = GenerateReqInput;

    /// Convert a /generate request to the PD wire format.
    ///
    /// Input selection priority: `text` (SGLang), then `prompt` (OpenAI),
    /// then `input_ids`; later fields are ignored once one matches.
    /// Note the asymmetry below: hoisted fields are *removed* from
    /// `parameters` but only *copied* (duplicated) out of `sampling_params`.
    fn to_pd_request(self) -> Self::Output {
        // Build the other fields first
        let mut other = serde_json::Map::new();

        // Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
        let (text, input_ids) = if let Some(text_str) = self.text {
            // SGLang native format
            (Some(SingleOrBatch::Single(text_str)), None)
        } else if let Some(prompt) = self.prompt {
            // OpenAI style prompt
            let text = match prompt {
                StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
                StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
            };
            (text, None)
        } else if let Some(ids) = self.input_ids {
            // Input IDs case
            let input_ids = match ids {
                crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
                crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
            };
            (None, input_ids)
        } else {
            // No input provided
            (None, None)
        };

        // Add parameters to other - handle both old and new style
        if let Some(params) = self.parameters {
            // For generate endpoint, extract max_new_tokens to top level if present
            let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
            if let Value::Object(ref mut params_map) = params_value {
                // Move max_new_tokens to top level if it exists
                if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
                    other.insert("max_new_tokens".to_string(), max_new_tokens);
                }
                // Move temperature to top level if it exists
                if let Some(temperature) = params_map.remove("temperature") {
                    other.insert("temperature".to_string(), temperature);
                }
            }
            // Only add parameters if there are remaining fields
            if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
            {
                other.insert("parameters".to_string(), params_value);
            }
        }

        // Add sampling_params if present
        if let Some(sampling_params) = self.sampling_params {
            let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
            if !params_value.is_null() {
                // Extract commonly used fields to top level (copies: the fields
                // also remain inside the nested sampling_params object)
                if let Value::Object(ref params_map) = params_value {
                    if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
                        other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
                    }
                    if let Some(temperature) = params_map.get("temperature") {
                        other.insert("temperature".to_string(), temperature.clone());
                    }
                }
                other.insert("sampling_params".to_string(), params_value);
            }
        }

        // Add other fields
        insert_value!(other,
            self.stream => "stream",
            self.return_logprob => "return_logprob"
        );

        GenerateReqInput {
            text,
            input_ids,
            stream: self.stream,
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
|
||||
|
||||
// ============= Completion Request Adapter =============
|
||||
|
||||
impl ToPdRequest for CompletionRequest {
|
||||
type Output = GenerateReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
// Convert CompletionRequest to GenerateReqInput
|
||||
let text = match self.prompt {
|
||||
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
|
||||
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
|
||||
};
|
||||
|
||||
// Map OpenAI parameters to generate parameters
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Create parameters object
|
||||
let mut params = serde_json::Map::new();
|
||||
|
||||
// Map OpenAI fields to internal parameter names
|
||||
insert_if_some!(params,
|
||||
self.max_tokens => "max_new_tokens",
|
||||
self.temperature => "temperature",
|
||||
self.top_p => "top_p",
|
||||
self.n => "best_of",
|
||||
self.logprobs => "top_n_tokens",
|
||||
self.seed => "seed"
|
||||
);
|
||||
|
||||
// Special handling for fields that need transformation
|
||||
if let Some(presence_penalty) = self.presence_penalty {
|
||||
params.insert(
|
||||
"repetition_penalty".to_string(),
|
||||
(1.0 + presence_penalty).into(),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(stop) = self.stop {
|
||||
let stop_sequences = match stop {
|
||||
StringOrArray::String(s) => vec![s],
|
||||
StringOrArray::Array(v) => v,
|
||||
};
|
||||
params.insert("stop".to_string(), stop_sequences.into());
|
||||
}
|
||||
|
||||
if self.echo {
|
||||
params.insert("return_full_text".to_string(), true.into());
|
||||
}
|
||||
|
||||
other.insert("parameters".to_string(), Value::Object(params));
|
||||
|
||||
// Store original model and stream flag
|
||||
insert_value!(other,
|
||||
self.model => "model",
|
||||
self.stream => "stream"
|
||||
);
|
||||
|
||||
GenerateReqInput {
|
||||
text,
|
||||
input_ids: None,
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Chat Completion Request Adapter =============
|
||||
|
||||
impl ToPdRequest for ChatCompletionRequest {
|
||||
type Output = ChatReqInput;
|
||||
|
||||
fn to_pd_request(self) -> Self::Output {
|
||||
let mut other = serde_json::Map::new();
|
||||
|
||||
// Add required fields
|
||||
insert_if_some!(other,
|
||||
Some(&self.messages) => "messages"
|
||||
);
|
||||
|
||||
insert_value!(other,
|
||||
self.model => "model",
|
||||
self.stream => "stream"
|
||||
);
|
||||
|
||||
// Add all optional fields
|
||||
insert_if_some!(other,
|
||||
self.temperature => "temperature",
|
||||
self.top_p => "top_p",
|
||||
self.n => "n",
|
||||
self.stop => "stop",
|
||||
self.max_tokens => "max_tokens",
|
||||
self.max_completion_tokens => "max_completion_tokens",
|
||||
self.presence_penalty => "presence_penalty",
|
||||
self.frequency_penalty => "frequency_penalty",
|
||||
self.logit_bias => "logit_bias",
|
||||
self.user => "user",
|
||||
self.seed => "seed",
|
||||
self.top_logprobs => "top_logprobs",
|
||||
self.response_format => "response_format",
|
||||
self.tools => "tools",
|
||||
self.tool_choice => "tool_choice",
|
||||
self.parallel_tool_calls => "parallel_tool_calls",
|
||||
self.functions => "functions",
|
||||
self.function_call => "function_call"
|
||||
);
|
||||
|
||||
// Handle boolean logprobs flag
|
||||
if self.logprobs {
|
||||
other.insert("logprobs".to_string(), true.into());
|
||||
}
|
||||
|
||||
ChatReqInput {
|
||||
stream: self.stream,
|
||||
bootstrap_host: None,
|
||||
bootstrap_port: None,
|
||||
bootstrap_room: None,
|
||||
other: Value::Object(other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============= Direct routing support for regular router =============
|
||||
|
||||
/// Extension trait for routing without PD conversion
|
||||
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
|
||||
/// Convert to JSON for sending to backend
|
||||
fn to_json(&self) -> Result<Value, serde_json::Error> {
|
||||
serde_json::to_value(self)
|
||||
}
|
||||
|
||||
/// Convert to bytes for legacy routing
|
||||
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
|
||||
let json = serde_json::to_vec(self)?;
|
||||
Ok(bytes::Bytes::from(json))
|
||||
}
|
||||
}
|
||||
|
||||
impl RouteableRequest for GenerateRequest {}
|
||||
impl RouteableRequest for CompletionRequest {}
|
||||
impl RouteableRequest for ChatCompletionRequest {}
|
||||
@@ -1,10 +1,10 @@
|
||||
use crate::pd_router::PDRouter;
|
||||
use crate::pd_types::PDSelectionPolicy;
|
||||
use crate::tree::Tree;
|
||||
use ::metrics::{counter, gauge, histogram};
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
use actix_web::{HttpRequest, HttpResponse};
|
||||
use bytes::Bytes;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
@@ -15,7 +15,7 @@ use std::time::Instant;
|
||||
use tokio;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||
req.headers()
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
@@ -40,6 +40,9 @@ pub enum Router {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
PrefillDecode {
|
||||
pd_router: Arc<PDRouter>,
|
||||
},
|
||||
CacheAware {
|
||||
/*
|
||||
Cache-Aware Load Balancing Router
|
||||
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
PrefillDecodeConfig {
|
||||
selection_policy: PDSelectionPolicy,
|
||||
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
|
||||
decode_urls: Vec<String>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -155,10 +165,24 @@ impl Router {
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
};
|
||||
|
||||
// Wait until all workers are healthy
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
// For PrefillDecode, we need to handle workers differently
|
||||
match &policy_config {
|
||||
PolicyConfig::PrefillDecodeConfig { .. } => {
|
||||
// PD mode doesn't use the worker_urls parameter
|
||||
// We'll validate PD workers separately
|
||||
}
|
||||
_ => {
|
||||
// Wait until all workers are healthy for regular modes
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create router based on policy...
|
||||
Ok(match policy_config {
|
||||
@@ -226,7 +250,7 @@ impl Router {
|
||||
});
|
||||
|
||||
for url in &worker_urls {
|
||||
tree.lock().unwrap().insert(&"".to_string(), url);
|
||||
tree.lock().unwrap().insert("", url);
|
||||
}
|
||||
|
||||
Router::CacheAware {
|
||||
@@ -242,6 +266,26 @@ impl Router {
|
||||
_eviction_thread: Some(eviction_thread),
|
||||
}
|
||||
}
|
||||
PolicyConfig::PrefillDecodeConfig {
|
||||
selection_policy,
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
} => {
|
||||
// Create PDRouter instance
|
||||
let pd_router = PDRouter::new(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
selection_policy,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
)?;
|
||||
|
||||
Router::PrefillDecode {
|
||||
pd_router: Arc::new(pd_router),
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -251,16 +295,23 @@ impl Router {
|
||||
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, return empty list since we manage workers differently
|
||||
Arc::new(RwLock::new(Vec::new()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn wait_for_healthy_workers(
|
||||
pub fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
let start_time = std::time::Instant::now();
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
let sync_client = reqwest::blocking::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
@@ -323,10 +374,14 @@ impl Router {
|
||||
Ok(worker_urls.read().unwrap()[0].clone())
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't need this method as routing is handled by PDRouter
|
||||
Err("PrefillDecode mode doesn't use select_first_worker".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_request(
|
||||
pub async fn send_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
worker_url: &str,
|
||||
@@ -339,7 +394,11 @@ impl Router {
|
||||
// Copy all headers from original request except for /health because it does not need authorization
|
||||
if route != "/health" {
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -433,50 +492,193 @@ impl Router {
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
|
||||
// Convert body to JSON
|
||||
let json: Value = match serde_json::from_slice(body) {
|
||||
Ok(j) => j,
|
||||
Err(_) => {
|
||||
warn!("Failed to parse JSON from request body.");
|
||||
return String::new();
|
||||
pub async fn route_to_all(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
route: &str,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
// Get all worker URLs based on router type
|
||||
let worker_urls = match self {
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, route_to_all is not supported directly
|
||||
// It should be handled by PDRouter if needed
|
||||
return HttpResponse::NotImplemented()
|
||||
.body("route_to_all not implemented for PrefillDecode mode");
|
||||
}
|
||||
_ => self.get_worker_urls().read().unwrap().clone(),
|
||||
};
|
||||
|
||||
match route {
|
||||
"/generate" => {
|
||||
// For /generate, always use the "text" field.
|
||||
match json.get("text").and_then(Value::as_str) {
|
||||
Some(text) => text.to_string(),
|
||||
None => {
|
||||
warn!("No 'text' field found in request body for route /generate.");
|
||||
String::new()
|
||||
}
|
||||
}
|
||||
// Send requests to all workers concurrently
|
||||
let mut tasks = Vec::new();
|
||||
for worker_url in &worker_urls {
|
||||
let mut request_builder = client.post(format!("{}{}", worker_url, route));
|
||||
|
||||
// Copy headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
}
|
||||
"/v1/chat/completions" | "/v1/completions" => {
|
||||
// For these routes, try "messages", then "prompt", then "text".
|
||||
if let Some(messages) = json.get("messages") {
|
||||
serde_json::to_string(messages).unwrap_or_default()
|
||||
} else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
|
||||
prompt.to_string()
|
||||
} else {
|
||||
warn!("Failed to find 'messages', 'prompt' in request body.");
|
||||
String::new()
|
||||
}
|
||||
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
|
||||
// Wait for all responses
|
||||
let results = futures_util::future::join_all(tasks).await;
|
||||
|
||||
// Check if all succeeded
|
||||
let all_success = results.iter().all(|r| {
|
||||
r.as_ref()
|
||||
.map(|res| res.status().is_success())
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
if all_success {
|
||||
HttpResponse::Ok().body("Operation completed on all servers")
|
||||
} else {
|
||||
HttpResponse::InternalServerError().body("Operation failed on one or more servers")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_all_loads(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
// For PD mode, delegate to PDRouter
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
return pd_router.get_loads(client).await;
|
||||
}
|
||||
_ => {
|
||||
warn!("Unknown route: {} - defaulting to fallback string", route);
|
||||
String::new()
|
||||
// For non-PD routers, handle normally
|
||||
}
|
||||
}
|
||||
|
||||
let urls = self.get_worker_urls().read().unwrap().clone();
|
||||
let prefill_urls: Vec<String> = Vec::new();
|
||||
let decode_urls = urls;
|
||||
|
||||
// Collect loads from all servers
|
||||
let mut prefill_loads = Vec::new();
|
||||
let mut decode_loads = Vec::new();
|
||||
|
||||
// Get prefill loads
|
||||
for url in &prefill_urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
prefill_loads.push(serde_json::json!({
|
||||
"engine": format!("(Prefill@{})", url),
|
||||
"load": load as i64
|
||||
}));
|
||||
}
|
||||
|
||||
// Get decode loads
|
||||
for url in &decode_urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
decode_loads.push(serde_json::json!({
|
||||
"engine": format!("(Decode@{})", url),
|
||||
"load": load as i64
|
||||
}));
|
||||
}
|
||||
|
||||
HttpResponse::Ok().json(serde_json::json!({
|
||||
"prefill": prefill_loads,
|
||||
"decode": decode_loads
|
||||
}))
|
||||
}
|
||||
|
||||
// New method to route typed requests directly
|
||||
pub async fn route_typed_request<
|
||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||
>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
|
||||
.body("PD routing should use specialized typed handlers"),
|
||||
_ => {
|
||||
// Handle retries like the original implementation
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
// Extract routing text directly from typed request
|
||||
let text = typed_req.extract_text_for_routing();
|
||||
let is_stream = typed_req.is_stream();
|
||||
|
||||
// Select worker based on text
|
||||
let worker_url = self.select_generate_worker_from_text(&text);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
counter!("sgl_router_retries_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
}
|
||||
|
||||
// Send typed request directly
|
||||
let response = self
|
||||
.send_typed_request(
|
||||
client,
|
||||
req,
|
||||
typed_req,
|
||||
route,
|
||||
&worker_url,
|
||||
is_stream,
|
||||
)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: return Result<String, String> instead of panicking
|
||||
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
|
||||
let text = self.get_text_from_request(&body, route);
|
||||
|
||||
let worker_url = match self {
|
||||
// Helper method to select worker from text
|
||||
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
||||
match self {
|
||||
Router::RoundRobin {
|
||||
worker_urls,
|
||||
current_index,
|
||||
@@ -506,8 +708,6 @@ impl Router {
|
||||
balance_rel_threshold,
|
||||
..
|
||||
} => {
|
||||
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
|
||||
|
||||
let tree = tree.lock().unwrap();
|
||||
let mut running_queue = running_queue.lock().unwrap();
|
||||
|
||||
@@ -572,35 +772,48 @@ impl Router {
|
||||
|
||||
selected_url
|
||||
}
|
||||
};
|
||||
|
||||
worker_url
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't use this method
|
||||
return "PD_MODE_ERROR".to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_generate_request(
|
||||
// Send typed request directly without conversion
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
worker_url: &str,
|
||||
is_stream: bool,
|
||||
) -> HttpResponse {
|
||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
||||
.unwrap_or(false);
|
||||
let start = Instant::now();
|
||||
|
||||
// Debug: Log what we're sending
|
||||
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
|
||||
debug!("Sending request to {}: {}", route, json_str);
|
||||
}
|
||||
|
||||
let mut request_builder = client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.body(body.to_vec());
|
||||
.json(typed_req); // Use json() directly with typed request
|
||||
|
||||
// Copy all headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
request_builder = request_builder.header(name, value);
|
||||
// Skip Content-Type and Content-Length as .json() sets them
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
||||
request_builder = request_builder.header(&name, &value);
|
||||
}
|
||||
}
|
||||
|
||||
let res = match request_builder.send().await {
|
||||
Ok(res) => res,
|
||||
Err(_) => return HttpResponse::InternalServerError().finish(),
|
||||
Err(e) => {
|
||||
error!("Failed to send request to {}: {}", worker_url, e);
|
||||
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||
@@ -625,6 +838,12 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// Record metrics
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
||||
|
||||
response
|
||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||
let running_queue = Arc::clone(running_queue);
|
||||
@@ -660,70 +879,6 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_generate_request(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
body: &Bytes,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
let start = Instant::now();
|
||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||
let mut total_retries = 0;
|
||||
|
||||
while total_retries < MAX_TOTAL_RETRIES {
|
||||
let worker_url = self.select_generate_worker(body, route);
|
||||
let mut request_retries = 0;
|
||||
|
||||
// Try the same worker multiple times
|
||||
while request_retries < MAX_REQUEST_RETRIES {
|
||||
if total_retries >= 1 {
|
||||
info!("Retrying request after {} failed attempts", total_retries);
|
||||
counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.send_generate_request(client, req, body, route, &worker_url)
|
||||
.await;
|
||||
|
||||
if response.status().is_success() {
|
||||
let duration = start.elapsed();
|
||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64());
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response =
|
||||
self.send_request(client, &worker_url, "/health", req).await;
|
||||
if health_response.status().is_success() {
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||
.increment(1);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
warn!(
|
||||
"Generate request to {} failed (attempt {}/{})",
|
||||
worker_url,
|
||||
request_retries + 1,
|
||||
MAX_REQUEST_RETRIES
|
||||
);
|
||||
|
||||
request_retries += 1;
|
||||
total_retries += 1;
|
||||
|
||||
if request_retries == MAX_REQUEST_RETRIES {
|
||||
warn!("Removing failed worker: {}", worker_url);
|
||||
self.remove_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
|
||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||
}
|
||||
|
||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
let (timeout_secs, interval_secs) = match self {
|
||||
Router::Random {
|
||||
@@ -741,10 +896,17 @@ impl Router {
|
||||
interval_secs,
|
||||
..
|
||||
} => (*timeout_secs, *interval_secs),
|
||||
Router::PrefillDecode { .. } => {
|
||||
// For PD mode, we don't support adding workers via this method
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
};
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let client = reqwest::Client::new();
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
loop {
|
||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||
@@ -774,6 +936,9 @@ impl Router {
|
||||
urls.push(worker_url.to_string());
|
||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// If cache aware, initialize the queues for the new worker
|
||||
@@ -797,7 +962,7 @@ impl Router {
|
||||
.insert(worker_url.to_string(), 0);
|
||||
|
||||
// Add worker to tree
|
||||
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
|
||||
tree.lock().unwrap().insert("", worker_url);
|
||||
}
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
@@ -850,6 +1015,10 @@ impl Router {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Router::PrefillDecode { .. } => {
|
||||
warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// if cache aware, remove the worker from the tree
|
||||
@@ -875,4 +1044,133 @@ impl Router {
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
.get("load")
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as isize),
|
||||
Err(e) => {
|
||||
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Ok(res) => {
|
||||
debug!(
|
||||
"Worker {} returned non-success status: {}",
|
||||
worker_url,
|
||||
res.status()
|
||||
);
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PD-specific wrapper methods that delegate to PDRouter
|
||||
pub async fn route_pd_health_generate(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.health_generate(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_generate_typed(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: crate::pd_types::GenerateReqInput,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router
|
||||
.route_generate(&pd_router.http_client, req, typed_req, route)
|
||||
.await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_chat_typed(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
typed_req: crate::pd_types::ChatReqInput,
|
||||
route: &str,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router
|
||||
.route_chat(&pd_router.http_client, req, typed_req, route)
|
||||
.await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_server_info(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
_req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_server_info(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_models(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_models(&pd_router.http_client, req).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.flush_cache(&pd_router.http_client).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_pd_model_info(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
req: &HttpRequest,
|
||||
) -> HttpResponse {
|
||||
match self {
|
||||
Router::PrefillDecode { pd_router } => {
|
||||
pd_router.get_model_info(&pd_router.http_client, req).await
|
||||
}
|
||||
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::prometheus::{self, PrometheusConfig};
|
||||
use crate::request_adapter::ToPdRequest;
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use actix_web::{
|
||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
|
||||
pub struct AppState {
|
||||
router: Arc<Router>,
|
||||
client: Client,
|
||||
is_pd_mode: bool, // Add flag to track PD mode
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -28,9 +30,16 @@ impl AppState {
|
||||
client: Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Result<Self, String> {
|
||||
// Check if this is PD mode from policy config
|
||||
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
|
||||
|
||||
// Create router based on policy
|
||||
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
||||
Ok(Self { router, client })
|
||||
Ok(Self {
|
||||
router,
|
||||
client,
|
||||
is_pd_mode,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
|
||||
}
|
||||
|
||||
// Custom error handler for JSON payload errors.
|
||||
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||
error::ErrorPayloadTooLarge("Payload too large")
|
||||
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||
error!("JSON payload error: {:?}", err);
|
||||
match &err {
|
||||
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
||||
error!(
|
||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||
length, limit
|
||||
);
|
||||
error::ErrorPayloadTooLarge(format!(
|
||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||
length, limit
|
||||
))
|
||||
}
|
||||
error::JsonPayloadError::Overflow { limit } => {
|
||||
error!("Payload overflow: exceeds limit of {} bytes", limit);
|
||||
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
||||
}
|
||||
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
// Check if we're in PD mode
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, check health on all servers
|
||||
data.router
|
||||
.route_pd_health_generate(&data.client, &req)
|
||||
.await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, aggregate info from both prefill and decode servers
|
||||
data.router.get_pd_server_info(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode - return first server's info
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, return models from the first prefill server
|
||||
data.router.get_pd_models(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, get model info from the first prefill server
|
||||
data.router.get_pd_model_info(&data.client, &req).await
|
||||
} else {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/generate")
|
||||
.await
|
||||
async fn generate(
|
||||
req: HttpRequest,
|
||||
body: web::Json<GenerateRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/generate")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
async fn v1_chat_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
|
||||
.await
|
||||
body: web::Json<ChatCompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
async fn v1_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/completions")
|
||||
.await
|
||||
body: web::Json<CompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/completions")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/add_worker")]
|
||||
@@ -153,6 +254,25 @@ async fn remove_worker(
|
||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, flush cache on both prefill and decode servers
|
||||
data.router.route_pd_flush_cache(&data.client).await
|
||||
} else {
|
||||
// Route to all workers for cache flushing
|
||||
data.router
|
||||
.route_to_all(&data.client, "/flush_cache", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Get loads from all workers
|
||||
data.router.get_all_loads(&data.client, &req).await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
@@ -163,6 +283,7 @@ pub struct ServerConfig {
|
||||
pub log_dir: Option<String>,
|
||||
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
||||
pub prometheus_config: Option<PrometheusConfig>,
|
||||
pub request_timeout_secs: u64,
|
||||
}
|
||||
|
||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
|
||||
let client = Client::builder()
|
||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.service(add_worker)
|
||||
.service(remove_worker)
|
||||
.service(list_workers)
|
||||
// Default handler for unmatched routes.
|
||||
.service(flush_cache)
|
||||
.service(get_loads)
|
||||
.default_service(web::route().to(sink_handler))
|
||||
})
|
||||
.bind_auto_h2c((config.host, config.port))?
|
||||
|
||||
904
sgl-router/tests/test_pd_routing.rs
Normal file
904
sgl-router/tests/test_pd_routing.rs
Normal file
@@ -0,0 +1,904 @@
|
||||
//! Comprehensive tests for PrefillDecode (PD) routing functionality
|
||||
//!
|
||||
//! This test suite covers:
|
||||
//! - Phase 1: Basic PD router creation and configuration
|
||||
//! - Phase 2: Bootstrap injection and request handling
|
||||
//! - Phase 3: Cache-aware selection (when implemented)
|
||||
//!
|
||||
//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type.
|
||||
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_pd_routing {
|
||||
use rand::Rng;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
|
||||
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
#[derive(Debug)]
|
||||
struct PDRequest {
|
||||
pub is_stream: bool,
|
||||
pub batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl PDRequest {
|
||||
// Extract PD-relevant info from JSON for testing
|
||||
pub fn from_json(json: &serde_json::Value) -> Self {
|
||||
let is_stream = json
|
||||
.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Detect batch size from text or input_ids
|
||||
let batch_size = if let Some(text) = json.get("text") {
|
||||
text.as_array().map(|arr| arr.len())
|
||||
} else if let Some(input_ids) = json.get("input_ids") {
|
||||
input_ids.as_array().map(|arr| arr.len())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
PDRequest {
|
||||
is_stream,
|
||||
batch_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 1: Basic PD Components and Router Creation
|
||||
// ========================================================================
|
||||
|
||||
#[test]
fn test_engine_info_creation() {
    // Prefill engine: carries an optional bootstrap port.
    let prefill = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
    assert!(
        matches!(prefill.engine_type, EngineType::Prefill),
        "Expected Prefill engine type"
    );
    assert_eq!(prefill.url, "http://prefill:8080");
    assert_eq!(prefill.bootstrap_port, Some(9000));
    assert_eq!(prefill.get_hostname(), "prefill");

    // Decode engine: never has a bootstrap port.
    let decode = EngineInfo::new_decode("http://decode:8080".to_string());
    assert!(
        matches!(decode.engine_type, EngineType::Decode),
        "Expected Decode engine type"
    );
    assert_eq!(decode.url, "http://decode:8080");
    assert_eq!(decode.bootstrap_port, None);
    assert_eq!(decode.get_hostname(), "decode");

    // api_path must yield a full URL whether or not the route carries a
    // leading slash.
    assert_eq!(prefill.api_path("/generate"), "http://prefill:8080/generate");
    assert_eq!(prefill.api_path("health"), "http://prefill:8080/health");
    assert_eq!(
        decode.api_path("/v1/chat/completions"),
        "http://decode:8080/v1/chat/completions"
    );
}
|
||||
|
||||
#[test]
fn test_pd_selection_policies() {
    // All PD selection-policy variants; these are only consulted when
    // pd_disaggregated=true.
    let all_policies = vec![
        PDSelectionPolicy::Random,
        PDSelectionPolicy::PowerOfTwo,
        PDSelectionPolicy::CacheAware {
            cache_threshold: 0.5,
            balance_abs_threshold: 32,
            balance_rel_threshold: 1.1,
        },
    ];

    // Every variant must be constructible and pattern-matchable.
    for policy in all_policies {
        match &policy {
            PDSelectionPolicy::Random => {
                assert!(matches!(policy, PDSelectionPolicy::Random));
            }
            PDSelectionPolicy::PowerOfTwo => {
                assert!(matches!(policy, PDSelectionPolicy::PowerOfTwo));
            }
            PDSelectionPolicy::CacheAware {
                cache_threshold, ..
            } => {
                // The cache threshold is a fraction, so it must stay in [0, 1].
                assert!(*cache_threshold >= 0.0 && *cache_threshold <= 1.0);
            }
        }
    }
}
|
||||
|
||||
#[test]
fn test_pd_router_configuration() {
    // PrefillDecodeConfig instances covering each selection policy; this
    // config shape is used when pd_disaggregated=true.
    let random_cfg = PolicyConfig::PrefillDecodeConfig {
        selection_policy: PDSelectionPolicy::Random,
        prefill_urls: vec![
            ("http://prefill1:8080".to_string(), Some(9000)),
            ("http://prefill2:8080".to_string(), None),
        ],
        decode_urls: vec![
            "http://decode1:8080".to_string(),
            "http://decode2:8080".to_string(),
        ],
        timeout_secs: 10,
        interval_secs: 1,
    };

    let po2_cfg = PolicyConfig::PrefillDecodeConfig {
        selection_policy: PDSelectionPolicy::PowerOfTwo,
        prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
        decode_urls: vec!["http://decode:8080".to_string()],
        timeout_secs: 5,
        interval_secs: 1,
    };

    let cache_cfg = PolicyConfig::PrefillDecodeConfig {
        selection_policy: PDSelectionPolicy::CacheAware {
            cache_threshold: 0.7,
            balance_abs_threshold: 20,
            balance_rel_threshold: 1.2,
        },
        prefill_urls: vec![
            ("http://p1:8080".to_string(), Some(9000)),
            ("http://p2:8080".to_string(), Some(9001)),
            ("http://p3:8080".to_string(), Some(9002)),
        ],
        decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
        timeout_secs: 10,
        interval_secs: 2,
    };

    for cfg in vec![random_cfg, po2_cfg, cache_cfg] {
        // No servers are actually running, so creation must fail at the
        // health check / timeout stage -- never on the configuration itself.
        let outcome = Router::new(vec![], cfg);
        assert!(outcome.is_err());
        let msg = outcome.unwrap_err();
        assert!(
            msg.contains("healthy") || msg.contains("timeout"),
            "Unexpected error: {}",
            msg
        );
    }
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 2: Bootstrap Injection and Request Handling
|
||||
// ========================================================================
|
||||
|
||||
#[test]
fn test_pd_request_from_json() {
    // Single-prompt request: no batch dimension.
    let single = json!({
        "text": "Hello world",
        "stream": false,
        "temperature": 0.7,
        "max_tokens": 100
    });
    let parsed = PDRequest::from_json(&single);
    assert!(!parsed.is_stream);
    assert_eq!(parsed.batch_size, None);

    // Batched text request: batch size equals the array length.
    let batched = json!({
        "text": ["Hello", "World", "Test"],
        "stream": true,
        "temperature": 0.5
    });
    let parsed = PDRequest::from_json(&batched);
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, Some(3));

    // input_ids request: the outer array length is the batch size.
    let with_ids = json!({
        "input_ids": [[1, 2, 3], [4, 5, 6]],
        "stream": false
    });
    let parsed = PDRequest::from_json(&with_ids);
    assert!(!parsed.is_stream);
    assert_eq!(parsed.batch_size, Some(2));

    // Chat request: messages carry no batch information.
    let chat = json!({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"}
        ],
        "stream": true
    });
    let parsed = PDRequest::from_json(&chat);
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, None);
}
|
||||
|
||||
#[test]
fn test_bootstrap_injection_simulation() {
    // inject_bootstrap_fields is private to the router module, so this
    // test reproduces its expected effect on the JSON payload instead.

    // --- Single request ---
    let mut single = json!({
        "text": "Hello world",
        "stream": false,
        "temperature": 0.7
    });

    let prefill = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
    single["bootstrap_host"] = json!(prefill.get_hostname());
    single["bootstrap_port"] = json!(prefill.bootstrap_port);
    single["bootstrap_room"] = json!(12345u64); // random room id in production

    assert_eq!(single["bootstrap_host"], "prefill1");
    assert_eq!(single["bootstrap_port"], 9000);
    assert!(single["bootstrap_room"].is_u64());
    // Pre-existing fields must survive the injection.
    assert_eq!(single["temperature"], 0.7);

    // --- Batch request: one bootstrap entry per batch element ---
    let mut batch = json!({
        "text": ["Hello", "World", "Test"],
        "stream": true
    });

    let n = 3;
    batch["bootstrap_host"] = json!(vec![prefill.get_hostname(); n]);
    batch["bootstrap_port"] = json!(vec![prefill.bootstrap_port; n]);
    batch["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);

    assert!(batch["bootstrap_host"].is_array());
    assert_eq!(batch["bootstrap_host"].as_array().unwrap().len(), n);
    assert!(batch["bootstrap_port"].is_array());
    assert!(batch["bootstrap_room"].is_array());
    // Original field preserved.
    assert_eq!(batch["stream"], true);
}
|
||||
|
||||
#[test]
fn test_request_serialization() {
    // A request must survive the bytes round-trip the router performs.
    let original = json!({
        "text": "Test prompt",
        "stream": false,
        "temperature": 0.7,
        "max_tokens": 100,
        "top_p": 0.9,
        "frequency_penalty": 0.5,
        "bootstrap_host": "prefill1",
        "bootstrap_port": 9000,
        "bootstrap_room": 12345u64
    });

    // Serialize to the wire format, then parse back.
    let wire = serde_json::to_vec(&original).unwrap();
    let round_tripped: serde_json::Value = serde_json::from_slice(&wire).unwrap();

    // No field may be lost or altered by the round trip.
    assert_eq!(round_tripped["text"], "Test prompt");
    assert_eq!(round_tripped["stream"], false);
    assert_eq!(round_tripped["temperature"], 0.7);
    assert_eq!(round_tripped["max_tokens"], 100);
    assert_eq!(round_tripped["bootstrap_host"], "prefill1");
    assert_eq!(round_tripped["bootstrap_port"], 9000);
    assert_eq!(round_tripped["bootstrap_room"], 12345);
}
|
||||
|
||||
#[test]
fn test_engine_info_hostname_extraction() {
    // (url, expected hostname) pairs covering common and degenerate forms.
    let cases = [
        ("http://localhost:8080", "localhost"),
        ("http://10.0.0.1:8080", "10.0.0.1"),
        ("https://api.example.com:443", "api.example.com"),
        ("http://prefill-server", "prefill-server"),
        ("http://[::1]:8080", "["), // known IPv6 limitation
        ("prefill:8080", "prefill"), // scheme-less URL
    ];

    for (url, expected) in cases {
        let engine = EngineInfo::new_prefill(url.to_string(), None);
        assert_eq!(engine.get_hostname(), expected);
    }
}
|
||||
|
||||
#[test]
fn test_pd_request_edge_cases() {
    // Completely empty body: all defaults apply.
    let parsed = PDRequest::from_json(&json!({}));
    assert!(!parsed.is_stream);
    assert_eq!(parsed.batch_size, None);

    // Only a stream flag, nothing else.
    let parsed = PDRequest::from_json(&json!({ "stream": true }));
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, None);

    // An empty text array is still a batch -- of size zero.
    let parsed = PDRequest::from_json(&json!({ "text": [] }));
    assert_eq!(parsed.batch_size, Some(0));

    // A scalar `text` value is not a batch.
    let parsed = PDRequest::from_json(&json!({ "text": "single string" }));
    assert_eq!(parsed.batch_size, None);
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 2: Background Load Monitoring Tests
|
||||
// ========================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_background_load_monitoring() {
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Create a watch channel for testing
|
||||
let (tx, rx) = watch::channel(HashMap::new());
|
||||
|
||||
// Simulate load updates
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("http://prefill1:8080".to_string(), 10);
|
||||
loads.insert("http://prefill2:8080".to_string(), 20);
|
||||
loads.insert("http://decode1:8080".to_string(), 5);
|
||||
loads.insert("http://decode2:8080".to_string(), 15);
|
||||
|
||||
// Send the loads
|
||||
tx.send(loads.clone()).unwrap();
|
||||
|
||||
// Verify receiver gets the update
|
||||
let received_loads = rx.borrow();
|
||||
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
|
||||
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
|
||||
assert_eq!(received_loads.get("http://decode1:8080"), Some(&5));
|
||||
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
|
||||
}
|
||||
|
||||
#[test]
fn test_power_of_two_load_selection() {
    // Power-of-two picks the lower-loaded of two sampled workers; these
    // scenarios document the comparison semantics.

    // Scenario 1: a clear winner on each side.
    let _loads = [
        ("prefill1", 100),
        ("prefill2", 10), // expected pick
        ("decode1", 50),
        ("decode2", 5), // expected pick
    ];
    assert!(10 < 100);
    assert!(5 < 50);

    // Scenario 2: a tie -- the `<=` comparison keeps the first candidate.
    let _equal_loads = [
        ("prefill1", 20),
        ("prefill2", 20),
        ("decode1", 30),
        ("decode2", 30),
    ];
    assert!(20 <= 20);
    assert!(30 <= 30);

    // Scenario 3: an unknown load falls back to usize::MAX via
    // unwrap_or, so any known load wins.
    let fallback = usize::MAX;
    assert!(10 < fallback);
    assert!(fallback > 0);
}
|
||||
|
||||
#[test]
fn test_load_monitoring_configuration() {
    // Only the PowerOfTwo policy requires the background load monitor.
    let cases = vec![
        (PDSelectionPolicy::Random, false),
        (PDSelectionPolicy::PowerOfTwo, true),
        (
            PDSelectionPolicy::CacheAware {
                cache_threshold: 0.5,
                balance_abs_threshold: 32,
                balance_rel_threshold: 1.1,
            },
            false,
        ),
    ];

    for (policy, expect_monitoring) in cases {
        let monitors = matches!(policy, PDSelectionPolicy::PowerOfTwo);
        assert_eq!(monitors, expect_monitoring);
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_watch_channel_behavior() {
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Test watch channel's broadcast behavior
|
||||
let (tx, rx1) = watch::channel(HashMap::new());
|
||||
let rx2 = rx1.clone();
|
||||
|
||||
// Initial state - empty map
|
||||
assert!(rx1.borrow().is_empty());
|
||||
assert!(rx2.borrow().is_empty());
|
||||
|
||||
// Update 1
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("worker1".to_string(), 10);
|
||||
tx.send(loads.clone()).unwrap();
|
||||
|
||||
// Both receivers see the update
|
||||
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
|
||||
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
|
||||
|
||||
// Update 2 - overwrites previous
|
||||
loads.insert("worker1".to_string(), 20);
|
||||
loads.insert("worker2".to_string(), 30);
|
||||
tx.send(loads).unwrap();
|
||||
|
||||
// Both receivers see the latest state
|
||||
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
|
||||
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Tests based on bench_one_batch_server.py patterns
|
||||
// ========================================================================
|
||||
|
||||
#[test]
fn test_generate_request_formats() {
    // Request shapes mirroring bench_one_batch_server.py.

    // 1) Streaming batch of input_ids -- the benchmark's common case.
    let batch = json!({
        "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 16,
            "ignore_eos": true,
        },
        "return_logprob": false,
        "stream": true
    });
    let parsed = PDRequest::from_json(&batch);
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, Some(3));

    // 2) Logprob request -- return_logprob handling is critical for PD.
    let with_logprob = json!({
        "input_ids": [[1, 2, 3]],
        "sampling_params": {
            "temperature": 0.7,
            "max_new_tokens": 8,
        },
        "return_logprob": true,
        "stream": false
    });
    assert_eq!(with_logprob["return_logprob"], true);
    assert_eq!(with_logprob["stream"], false);

    // 3) Batch sizes exercised by the benchmark script.
    for bs in [1usize, 16, 64] {
        let request = json!({
            "input_ids": vec![vec![1, 2, 3]; bs],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
            },
            "stream": true
        });
        assert_eq!(PDRequest::from_json(&request).batch_size, Some(bs));
    }
}
|
||||
|
||||
#[test]
fn test_sampling_params_handling() {
    // Sampling-parameter shapes from bench_one_batch_server.py; each must
    // pass through a request body unmodified.
    let variations = [
        json!({
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": true
        }),
        json!({
            "temperature": 0.7,
            "max_new_tokens": 16,
            "ignore_eos": false,
            "top_p": 0.9,
            "frequency_penalty": 0.5
        }),
        json!({
            "temperature": 1.0,
            "max_new_tokens": 64,
            "json_schema": "$$ANY$$" // structured-output marker
        }),
    ];

    for params in variations {
        let request = json!({
            "input_ids": [[1, 2, 3]],
            "sampling_params": params.clone(),
            "stream": false
        });
        // The embedded params must equal the originals exactly.
        assert_eq!(request["sampling_params"], params);
    }
}
|
||||
|
||||
#[test]
fn test_streaming_response_parsing() {
    // SSE chunks as the decode server emits them, ending with [DONE].
    let chunks = [
        "data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
        "data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
        "data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
        "data: [DONE]",
    ];

    // Every chunk before the terminator carries a JSON payload with a
    // completion-token count.
    for chunk in &chunks[..3] {
        assert!(chunk.starts_with("data: "));
        let payload: serde_json::Value = serde_json::from_str(&chunk[6..]).unwrap();
        assert!(payload["meta_info"]["completion_tokens"].is_u64());
    }

    // The stream terminator is a literal, non-JSON marker.
    assert_eq!(chunks[3], "data: [DONE]");
}
|
||||
|
||||
#[test]
fn test_ttft_calculation() {
    // TTFT is measured at the chunk whose completion_tokens reaches 1
    // while generation is still running (finish_reason == null).
    let first_chunk = json!({
        "text": "Hello",
        "meta_info": {
            "completion_tokens": 1,
            "finish_reason": null
        }
    });

    assert_eq!(first_chunk["meta_info"]["completion_tokens"], 1);
    assert!(first_chunk["meta_info"]["finish_reason"].is_null());
}
|
||||
|
||||
#[test]
fn test_throughput_metrics() {
    // Throughput formulas used by bench_one_batch_server.py.
    let (batch_size, input_len, output_len) = (16usize, 1024usize, 16usize);
    let ttft = 0.5_f64; // seconds to first token
    let total_latency = 2.0_f64; // seconds end to end

    // Prefill: all input tokens are processed before the first token.
    let input_throughput = (batch_size * input_len) as f64 / ttft;
    assert!((input_throughput - 32768.0).abs() < 0.01);

    // Decode: the remaining time covers all output tokens.
    let output_throughput = (batch_size * output_len) as f64 / (total_latency - ttft);
    assert!((output_throughput - 170.67).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_error_response_handling() {
    // Error payload shape expected by bench_one_batch_server.py.
    let error_body = json!({
        "error": "Request has failed. Invalid input format."
    });

    assert!(error_body.get("error").is_some());
    assert!(error_body["error"].as_str().unwrap().contains("failed"));
}
|
||||
|
||||
#[test]
fn test_structured_output_request() {
    // Structured output is requested via sampling_params.json_schema.
    let request = json!({
        "text": "What is the capital of France? Answer in JSON.",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 64,
            "json_schema": "$$ANY$$"
        },
        "stream": false
    });

    assert_eq!(request["sampling_params"]["json_schema"], "$$ANY$$");
}
|
||||
|
||||
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
    // Bootstrap injection applied to a realistic benchmark payload.
    let batch = 16usize;
    let mut request = json!({
        "input_ids": vec![vec![1, 2, 3, 4]; 16], // batch of 16
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": true
        },
        "return_logprob": true,
        "stream": true
    });

    // Simulate what the router's bootstrap injection would add.
    let prefill = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
    request["bootstrap_host"] = json!(vec![prefill.get_hostname(); batch]);
    request["bootstrap_port"] = json!(vec![prefill.bootstrap_port; batch]);
    request["bootstrap_room"] = json!((0..batch).map(|_| 12345u64).collect::<Vec<_>>());

    // Each bootstrap field must have exactly one entry per batch element.
    for field in ["bootstrap_host", "bootstrap_port", "bootstrap_room"] {
        assert_eq!(request[field].as_array().unwrap().len(), batch);
    }

    // The injection must not disturb the original fields.
    assert_eq!(request["return_logprob"], true);
    assert_eq!(request["stream"], true);
}
|
||||
|
||||
#[test]
fn test_server_info_response_format() {
    // /get_server_info response shape consumed by bench_one_batch_server.py.
    let info = json!({
        "internal_states": [{
            "avg_spec_accept_length": 3.5,
            "last_gen_throughput": 2048.5,
            "load": 16
        }],
        "prefill": [
            {"url": "http://prefill1:8080", "load": 10},
            {"url": "http://prefill2:8080", "load": 20}
        ],
        "decode": [
            {"url": "http://decode1:8080", "load": 5},
            {"url": "http://decode2:8080", "load": 15}
        ]
    });

    // The benchmark reads these exact fields with these exact types.
    assert!(info["internal_states"][0]["avg_spec_accept_length"].is_f64());
    assert!(info["internal_states"][0]["last_gen_throughput"].is_f64());
    assert!(info["prefill"].is_array());
    assert!(info["decode"].is_array());
}
|
||||
|
||||
// ========================================================================
|
||||
// Comprehensive Endpoint Coverage Test
|
||||
// ========================================================================
|
||||
|
||||
#[test]
fn test_pd_endpoints_coverage() {
    // Endpoints from the Python mini_lb.py with their implementation
    // status in this router.
    let endpoints = [
        ("/health", "GET", true),
        ("/health_generate", "GET", true), // Note: Python uses POST, we use GET
        ("/get_server_info", "GET", true),
        ("/v1/models", "GET", true),
        ("/get_model_info", "GET", true),
        ("/generate", "POST", true),
        ("/v1/chat/completions", "POST", true),
        ("/v1/completions", "POST", true),
        ("/flush_cache", "POST", true),
        ("/get_loads", "GET", true),
        ("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
    ];

    let done = endpoints.iter().filter(|(_, _, ok)| *ok).count();
    // 10 of 11 implemented; /register is out of scope for Phase 1/2.
    assert_eq!(done, 10);
    assert_eq!(endpoints.len(), 11);

    // Document exactly which endpoint is still missing.
    let pending: Vec<_> = endpoints
        .iter()
        .filter(|(_, _, ok)| !ok)
        .map(|(path, method, _)| format!("{} {}", method, path))
        .collect();
    assert_eq!(pending, vec!["POST /register"]);
}
|
||||
|
||||
#[test]
fn test_large_batch_bootstrap_injection() {
    // Bootstrap injection must stay cheap even at benchmark-scale batch
    // sizes (the bench_one_batch_server.py scenario).
    for batch in [1024usize, 4096, 8192] {
        let timer = std::time::Instant::now();

        let mut request = json!({
            "input_ids": vec![vec![1, 2, 3, 4]; batch],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
            },
            "stream": true
        });

        // Simulate the router's bootstrap injection.
        let prefill = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
        request["bootstrap_host"] = json!(vec![prefill.get_hostname(); batch]);
        request["bootstrap_port"] = json!(vec![prefill.bootstrap_port; batch]);
        request["bootstrap_room"] = json!((0..batch)
            .map(|_| rand::thread_rng().gen::<u64>())
            .collect::<Vec<_>>());

        let elapsed = timer.elapsed();

        // One bootstrap entry per batch element.
        for field in ["bootstrap_host", "bootstrap_port", "bootstrap_room"] {
            assert_eq!(request[field].as_array().unwrap().len(), batch);
        }

        println!(
            "Bootstrap injection for batch_size {} took {:?}",
            batch, elapsed
        );
        assert!(
            elapsed.as_millis() < 1000,
            "Bootstrap injection took too long for batch size {}",
            batch
        );
    }
}
|
||||
|
||||
#[test]
fn test_payload_size_calculation() {
    // Sanity-check payload size estimation for bench_one_batch_server.py
    // scenarios. Each token is modeled as ~4 bytes (i32) plus ~100 bytes of
    // JSON framing overhead per request in the batch.
    const BYTES_PER_TOKEN: i32 = 4;
    const OVERHEAD_PER_REQUEST: i32 = 100;
    const MB: i32 = 1024 * 1024;

    let scenarios = [
        (1, 1024, 16),   // Small batch
        (16, 1024, 16),  // Medium batch
        (64, 1024, 16),  // Large batch
        (8192, 4096, 5), // Benchmark scenario
    ];

    for &(batch, input_len, _output_len) in scenarios.iter() {
        // Rough estimate: token bytes plus per-request JSON overhead.
        let estimated = batch * input_len * BYTES_PER_TOKEN + batch * OVERHEAD_PER_REQUEST;

        println!(
            "Batch size: {}, Input len: {}, Estimated payload: {} MB",
            batch,
            input_len,
            estimated / MB
        );

        // The benchmark case (8192 x 4096) should come out around ~134 MB.
        if batch == 8192 && input_len == 4096 {
            assert!(estimated > 100 * MB, "Benchmark payload should be > 100MB");
            assert!(estimated < 200 * MB, "Benchmark payload should be < 200MB");
        }
    }
}
|
||||
|
||||
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
    // Documents how PolicyType maps onto PDSelectionPolicy when
    // pd_disaggregated=true (the conversion itself lives in lib.rs):
    //
    //   PolicyType::Random     -> PDSelectionPolicy::Random
    //   PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
    //   PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
    //   PolicyType::RoundRobin -> ERROR (not supported in PD mode)

    // PDSelectionPolicy intentionally has no RoundRobin variant.
    let expected_variant_count = 3; // Random, PowerOfTwo, CacheAware
    assert_eq!(
        expected_variant_count, 3,
        "PDSelectionPolicy should have exactly 3 variants"
    );

    // Each supported variant must remain constructible.
    let _ = PDSelectionPolicy::Random;
    let _ = PDSelectionPolicy::PowerOfTwo;
    let _ = PDSelectionPolicy::CacheAware {
        cache_threshold: 0.5,
        balance_abs_threshold: 32,
        balance_rel_threshold: 1.1,
    };
}
|
||||
}
|
||||
Reference in New Issue
Block a user