From 09ae5b20f3123487f36097d284a1f535cd267e7b Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Wed, 18 Jun 2025 11:28:15 -0700 Subject: [PATCH] Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096) --- python/sglang/srt/disaggregation/mini_lb.py | 28 +- sgl-router/Cargo.toml | 4 +- .../py_src/sglang_router/launch_router.py | 113 +- sgl-router/py_src/sglang_router/router.py | 14 +- sgl-router/py_test/test_launch_router.py | 118 +- sgl-router/src/lib.rs | 103 +- sgl-router/src/openai_api_types.rs | 704 ++++++++++++ sgl-router/src/pd_router.rs | 1002 +++++++++++++++++ sgl-router/src/pd_types.rs | 245 ++++ sgl-router/src/request_adapter.rs | 264 +++++ sgl-router/src/router.rs | 544 +++++++-- sgl-router/src/server.rs | 189 +++- sgl-router/tests/test_pd_routing.rs | 904 +++++++++++++++ 13 files changed, 4045 insertions(+), 187 deletions(-) create mode 100644 sgl-router/src/openai_api_types.rs create mode 100644 sgl-router/src/pd_router.rs create mode 100644 sgl-router/src/pd_types.rs create mode 100644 sgl-router/src/request_adapter.rs create mode 100644 sgl-router/tests/test_pd_routing.rs diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py index 8e3371c73..d91598e4f 100644 --- a/python/sglang/srt/disaggregation/mini_lb.py +++ b/python/sglang/srt/disaggregation/mini_lb.py @@ -218,15 +218,39 @@ async def get_server_info(): ) prefill_infos = [] decode_infos = [] + all_internal_states = [] + async with aiohttp.ClientSession() as session: for server in chain(prefill_servers): server_info = await session.get(f"{server}/get_server_info") prefill_infos.append(await server_info.json()) for server in chain(decode_servers): server_info = await session.get(f"{server}/get_server_info") - decode_infos.append(await server_info.json()) + info_json = await server_info.json() + decode_infos.append(info_json) + # Extract internal_states from decode servers + if "internal_states" in info_json: + all_internal_states.extend(info_json["internal_states"]) - return {"prefill": prefill_infos, "decode": decode_infos} + # Return format expected by bench_one_batch_server.py + if all_internal_states: + return { + "internal_states": all_internal_states, + "prefill": prefill_infos, + "decode": decode_infos, + } + else: + # Fallback with dummy data if no internal states found + return { + "internal_states": [ + { + "last_gen_throughput": 0.0, + "avg_spec_accept_length": None, + } + ], + "prefill": prefill_infos, + "decode": decode_infos, + } @app.get("/get_model_info") diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml index 30b248e87..afc558f1c 100644 --- a/sgl-router/Cargo.toml +++ b/sgl-router/Cargo.toml @@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] } clap = { version = "4.4", features = ["derive"] } bytes = "1.8.0" rand = "0.8.5" -reqwest = { version = "0.12.8", features = ["stream", "blocking"] } +reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] } futures-util = "0.3" serde_json = "1.0" pyo3 = { version = "0.22.5", features = ["extension-module"] } @@ -33,6 +33,8 @@ futures = "0.3" # Added for metrics metrics = "0.24.2" metrics-exporter-prometheus = "0.17.0" +# Added for request tracing +uuid = { version = "1.10", features = ["v4", "serde"] } [profile.release] lto = "thin" codegen-units = 1 diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index 4f036a253..74000ccbe 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ 
b/sgl-router/py_src/sglang_router/launch_router.py @@ -31,6 +31,13 @@ class RouterArgs: host: str = "127.0.0.1" port: int = 30000 + # PD-specific configuration + pd_disaggregated: bool = False # Enable PD disaggregated mode + prefill_urls: List[tuple] = dataclasses.field( + default_factory=list + ) # List of (url, bootstrap_port) + decode_urls: List[str] = dataclasses.field(default_factory=list) + # Routing policy policy: str = "cache_aware" worker_startup_timeout_secs: int = 300 @@ -40,7 +47,7 @@ class RouterArgs: balance_rel_threshold: float = 1.0001 eviction_interval: int = 60 max_tree_size: int = 2**24 - max_payload_size: int = 4 * 1024 * 1024 # 4MB + max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches verbose: bool = False log_dir: Optional[str] = None # Service discovery configuration @@ -95,8 +102,29 @@ class RouterArgs: f"--{prefix}policy", type=str, default=RouterArgs.policy, - choices=["random", "round_robin", "cache_aware"], - help="Load balancing policy to use", + choices=["random", "round_robin", "cache_aware", "power_of_two"], + help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode", + ) + + # PD-specific arguments + parser.add_argument( + f"--{prefix}pd-disaggregated", + action="store_true", + help="Enable PD (Prefill-Decode) disaggregated mode", + ) + parser.add_argument( + f"--{prefix}prefill", + nargs=2, + action="append", + metavar=("URL", "BOOTSTRAP_PORT"), + help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.", + ) + parser.add_argument( + f"--{prefix}decode", + nargs=1, + action="append", + metavar=("URL",), + help="Decode server URL. Can be specified multiple times.", ) parser.add_argument( f"--{prefix}worker-startup-timeout-secs", @@ -205,11 +233,19 @@ class RouterArgs: use_router_prefix: If True, look for arguments with 'router-' prefix """ prefix = "router_" if use_router_prefix else "" - worker_urls = args.worker_urls if args.worker_urls is not None else [] + worker_urls = getattr(args, "worker_urls", []) + + # Parse PD URLs + prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None)) + decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None)) + return cls( worker_urls=worker_urls, host=args.host, port=args.port, + pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False), + prefill_urls=prefill_urls, + decode_urls=decode_urls, policy=getattr(args, f"{prefix}policy"), worker_startup_timeout_secs=getattr( args, f"{prefix}worker_startup_timeout_secs" @@ -247,6 +283,46 @@ class RouterArgs: selector[key] = value return selector + @staticmethod + def _parse_prefill_urls(prefill_list): + """Parse prefill URLs from --prefill arguments. + + Format: --prefill URL BOOTSTRAP_PORT + Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none + """ + if not prefill_list: + return [] + + prefill_urls = [] + for url, bootstrap_port_str in prefill_list: + # Handle 'none' as None + if bootstrap_port_str.lower() == "none": + bootstrap_port = None + else: + try: + bootstrap_port = int(bootstrap_port_str) + except ValueError: + raise ValueError( + f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'" + ) + + prefill_urls.append((url, bootstrap_port)) + + return prefill_urls + + @staticmethod + def _parse_decode_urls(decode_list): + """Parse decode URLs from --decode arguments. 
+ + Format: --decode URL + Example: --decode http://decode1:8081 --decode http://decode2:8081 + """ + if not decode_list: + return [] + + # decode_list is a list of single-element lists due to nargs=1 + return [url[0] for url in decode_list] + def policy_from_str(policy_str: str) -> PolicyType: """Convert policy string to PolicyType enum.""" @@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType: "random": PolicyType.Random, "round_robin": PolicyType.RoundRobin, "cache_aware": PolicyType.CacheAware, + "power_of_two": PolicyType.PowerOfTwo, } return policy_map[policy_str] @@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: else: router_args = args + # Validate configuration based on mode + if router_args.pd_disaggregated: + # Validate PD configuration + if not router_args.prefill_urls: + raise ValueError("PD disaggregated mode requires --prefill") + if not router_args.decode_urls: + raise ValueError("PD disaggregated mode requires --decode") + + # Create router with unified constructor router = Router( - worker_urls=router_args.worker_urls, + worker_urls=( + router_args.worker_urls if not router_args.pd_disaggregated else [] + ), host=router_args.host, port=router_args.port, policy=policy_from_str(router_args.policy), @@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: service_discovery_namespace=router_args.service_discovery_namespace, prometheus_port=router_args.prometheus_port, prometheus_host=router_args.prometheus_host, + pd_disaggregated=router_args.pd_disaggregated, + prefill_urls=( + router_args.prefill_urls if router_args.pd_disaggregated else None + ), + decode_urls=( + router_args.decode_urls if router_args.pd_disaggregated else None + ), ) router.start() @@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is multi-node setups or when you want to start workers and router separately. Examples: + # Regular mode python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 - python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2 + + # PD disaggregated mode + python -m sglang_router.launch_router --pd-disaggregated \\ + --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\ + --decode http://decode1:8001 --decode http://decode2:8001 \\ + --policy cache_aware """, formatter_class=CustomHelpFormatter, diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index c189cd587..5fd5d8788 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -15,6 +15,7 @@ class Router: - PolicyType.Random: Randomly select workers - PolicyType.RoundRobin: Distribute requests in round-robin fashion - PolicyType.CacheAware: Distribute requests based on cache state and load balance + - PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only) host: Host address to bind the router server. Default: '127.0.0.1' port: Port number to bind the router server. Default: 3001 worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 @@ -28,7 +29,7 @@ class Router: AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001 eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware routing. 
Default: 60 - max_payload_size: Maximum payload size in bytes. Default: 4MB + max_payload_size: Maximum payload size in bytes. Default: 256MB max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24 verbose: Enable verbose logging. Default: False log_dir: Directory to store log files. If None, logs are only output to console. Default: None @@ -42,6 +43,9 @@ class Router: watches pods across all namespaces (requires cluster-wide permissions). Default: None prometheus_port: Port to expose Prometheus metrics. Default: None prometheus_host: Host address to bind the Prometheus metrics server. Default: None + pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False + prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only) + decode_urls: List of URLs for decode servers (PD mode only) """ def __init__( @@ -57,7 +61,7 @@ class Router: balance_rel_threshold: float = 1.0001, eviction_interval_secs: int = 60, max_tree_size: int = 2**24, - max_payload_size: int = 4 * 1024 * 1024, # 4MB + max_payload_size: int = 256 * 1024 * 1024, # 256MB verbose: bool = False, log_dir: Optional[str] = None, service_discovery: bool = False, @@ -66,6 +70,9 @@ class Router: service_discovery_namespace: Optional[str] = None, prometheus_port: Optional[int] = None, prometheus_host: Optional[str] = None, + pd_disaggregated: bool = False, + prefill_urls: Optional[List[tuple]] = None, + decode_urls: Optional[List[str]] = None, ): if selector is None: selector = {} @@ -91,6 +98,9 @@ class Router: service_discovery_namespace=service_discovery_namespace, prometheus_port=prometheus_port, prometheus_host=prometheus_host, + pd_disaggregated=pd_disaggregated, + prefill_urls=prefill_urls, + decode_urls=decode_urls, ) def start(self) -> None: diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index c6f0444f4..26b3c33d9 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase): balance_rel_threshold=1.0001, eviction_interval=60, max_tree_size=2**24, - max_payload_size=4 * 1024 * 1024, # 4MB + max_payload_size=256 * 1024 * 1024, # 256MB verbose=False, log_dir=None, service_discovery=False, selector=None, service_discovery_port=80, service_discovery_namespace=None, + prometheus_port=None, + prometheus_host=None, + # PD-specific attributes + pd_disaggregated=False, + prefill=None, + decode=None, + # Keep worker_urls for regular mode + worker_urls=[], ) def create_router_args(self, **kwargs): @@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase): def test_launch_router_with_empty_worker_urls(self): args = self.create_router_args(worker_urls=[]) - self.run_router_process(args) + self.run_router_process(args) # Expected error def test_launch_router_with_service_discovery(self): # Test router startup with service discovery enabled but no selectors @@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase): ) self.run_router_process(args) + def test_launch_router_pd_mode_basic(self): + """Test basic PD router functionality without actually starting servers.""" + # This test just verifies the PD router can be created and configured + # without actually starting it (which would require real prefill/decode servers) + from sglang_router import Router + from sglang_router.launch_router import RouterArgs + from sglang_router_rs import PolicyType + + # Test RouterArgs parsing for PD mode + # Simulate the 
parsed args structure from argparse with action="append" + args = self.create_router_args( + pd_disaggregated=True, + policy="power_of_two", # PowerOfTwo is only valid in PD mode + prefill=[ + ["http://prefill1:8080", "9000"], + ["http://prefill2:8080", "none"], + ], + decode=[ + ["http://decode1:8081"], + ["http://decode2:8081"], + ], + worker_urls=[], # Empty for PD mode + ) + + router_args = RouterArgs.from_cli_args(args) + self.assertTrue(router_args.pd_disaggregated) + self.assertEqual(router_args.policy, "power_of_two") + self.assertEqual(len(router_args.prefill_urls), 2) + self.assertEqual(len(router_args.decode_urls), 2) + + # Verify the parsed URLs and bootstrap ports + self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000)) + self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None)) + self.assertEqual(router_args.decode_urls[0], "http://decode1:8081") + self.assertEqual(router_args.decode_urls[1], "http://decode2:8081") + + # Test Router creation in PD mode + router = Router( + worker_urls=[], # Empty for PD mode + pd_disaggregated=True, + prefill_urls=[ + ("http://prefill1:8080", 9000), + ("http://prefill2:8080", None), + ], + decode_urls=["http://decode1:8081", "http://decode2:8081"], + policy=PolicyType.CacheAware, + host="127.0.0.1", + port=3001, + ) + self.assertIsNotNone(router) + + def test_policy_validation(self): + """Test that policy validation works correctly for PD and regular modes.""" + from sglang_router.launch_router import RouterArgs, launch_router + + # Test 1: PowerOfTwo is only valid in PD mode + args = self.create_router_args( + pd_disaggregated=False, + policy="power_of_two", + worker_urls=["http://localhost:8000"], + ) + + # Should raise error + with self.assertRaises(ValueError) as cm: + launch_router(args) + self.assertIn( + "PowerOfTwo policy is only supported in PD disaggregated mode", + str(cm.exception), + ) + + # Test 2: RoundRobin is not valid in PD mode + args = self.create_router_args( + pd_disaggregated=True, + policy="round_robin", + prefill=[["http://prefill1:8080", "9000"]], + decode=[["http://decode1:8081"]], + worker_urls=[], + ) + + # Should raise error + with self.assertRaises(ValueError) as cm: + launch_router(args) + self.assertIn( + "RoundRobin policy is not supported in PD disaggregated mode", + str(cm.exception), + ) + + # Test 3: Valid combinations should not raise errors + # Regular mode with RoundRobin + args = self.create_router_args( + pd_disaggregated=False, + policy="round_robin", + worker_urls=["http://localhost:8000"], + ) + # This should not raise (though it may fail to connect) + + # PD mode with PowerOfTwo + args = self.create_router_args( + pd_disaggregated=True, + policy="power_of_two", + prefill=[["http://prefill1:8080", "9000"]], + decode=[["http://decode1:8081"]], + worker_urls=[], + ) + # This should not raise (though it may fail to connect) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 4915d3c52..439db1c4f 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -1,7 +1,11 @@ use pyo3::prelude::*; pub mod logging; use std::collections::HashMap; +pub mod openai_api_types; +pub mod pd_router; +pub mod pd_types; pub mod prometheus; +pub mod request_adapter; pub mod router; pub mod server; pub mod service_discovery; @@ -14,6 +18,7 @@ pub enum PolicyType { Random, RoundRobin, CacheAware, + PowerOfTwo, // Moved from PD-specific, now shared } #[pyclass] @@ -39,6 +44,12 @@ struct Router { 
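// For illustration (annotation mirroring the py_test above): the new fields
// below line up with the keyword arguments the Python wrapper forwards, e.g.
//     Router(worker_urls=[], pd_disaggregated=True,
//            prefill_urls=[("http://prefill1:8080", 9000), ("http://prefill2:8080", None)],
//            decode_urls=["http://decode1:8081", "http://decode2:8081"],
//            policy=PolicyType.CacheAware, host="127.0.0.1", port=3001)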
     service_discovery_namespace: Option<String>,
     prometheus_port: Option<u16>,
     prometheus_host: Option<String>,
+    request_timeout_secs: u64,
+    // PD mode flag
+    pd_disaggregated: bool,
+    // PD-specific fields (only used when pd_disaggregated is true)
+    prefill_urls: Option<Vec<(String, Option<u16>)>>,
+    decode_urls: Option<Vec<String>>,
 }
 
 #[pymethods]
 impl Router {
@@ -56,7 +67,7 @@ impl Router {
         balance_rel_threshold = 1.0001,
         eviction_interval_secs = 60,
         max_tree_size = 2usize.pow(24),
-        max_payload_size = 4 * 1024 * 1024,
+        max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
         verbose = false,
         log_dir = None,
         service_discovery = false,
@@ -64,7 +75,11 @@ impl Router {
         service_discovery_port = 80,
         service_discovery_namespace = None,
         prometheus_port = None,
-        prometheus_host = None
+        prometheus_host = None,
+        request_timeout_secs = 600, // Add configurable request timeout
+        pd_disaggregated = false, // New flag for PD mode
+        prefill_urls = None,
+        decode_urls = None
     ))]
     fn new(
         worker_urls: Vec<String>,
@@ -87,6 +102,10 @@ impl Router {
         service_discovery_namespace: Option<String>,
         prometheus_port: Option<u16>,
         prometheus_host: Option<String>,
+        request_timeout_secs: u64,
+        pd_disaggregated: bool,
+        prefill_urls: Option<Vec<(String, Option<u16>)>>,
+        decode_urls: Option<Vec<String>>,
     ) -> PyResult<Self> {
         Ok(Router {
             host,
@@ -109,28 +128,75 @@ impl Router {
             service_discovery_namespace,
             prometheus_port,
             prometheus_host,
+            request_timeout_secs,
+            pd_disaggregated,
+            prefill_urls,
+            decode_urls,
         })
     }
 
     fn start(&self) -> PyResult<()> {
-        let policy_config = match &self.policy {
-            PolicyType::Random => router::PolicyConfig::RandomConfig {
+        let policy_config = if self.pd_disaggregated {
+            // PD mode - map PolicyType to PDSelectionPolicy
+            let pd_selection_policy = match &self.policy {
+                PolicyType::Random => pd_types::PDSelectionPolicy::Random,
+                PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
+                PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
+                    cache_threshold: self.cache_threshold,
+                    balance_abs_threshold: self.balance_abs_threshold,
+                    balance_rel_threshold: self.balance_rel_threshold,
+                },
+                PolicyType::RoundRobin => {
+                    return Err(pyo3::exceptions::PyValueError::new_err(
+                        "RoundRobin policy is not supported in PD disaggregated mode",
+                    ));
+                }
+            };
+
+            let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
+                pyo3::exceptions::PyValueError::new_err(
+                    "PD disaggregated mode requires prefill_urls",
+                )
+            })?;
+            let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
+                pyo3::exceptions::PyValueError::new_err(
+                    "PD disaggregated mode requires decode_urls",
+                )
+            })?;
+
+            router::PolicyConfig::PrefillDecodeConfig {
+                selection_policy: pd_selection_policy,
+                prefill_urls: prefill_urls.clone(),
+                decode_urls: decode_urls.clone(),
                 timeout_secs: self.worker_startup_timeout_secs,
                 interval_secs: self.worker_startup_check_interval,
-            },
-            PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
-                timeout_secs: self.worker_startup_timeout_secs,
-                interval_secs: self.worker_startup_check_interval,
-            },
-            PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
-                timeout_secs: self.worker_startup_timeout_secs,
-                interval_secs: self.worker_startup_check_interval,
-                cache_threshold: self.cache_threshold,
-                balance_abs_threshold: self.balance_abs_threshold,
-                balance_rel_threshold: self.balance_rel_threshold,
-                eviction_interval_secs: self.eviction_interval_secs,
-                max_tree_size: self.max_tree_size,
-            },
+            }
+        } else {
+            // Regular mode
+            match &self.policy {
+                PolicyType::Random => router::PolicyConfig::RandomConfig {
+                    timeout_secs:
self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, + PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, + PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + cache_threshold: self.cache_threshold, + balance_abs_threshold: self.balance_abs_threshold, + balance_rel_threshold: self.balance_rel_threshold, + eviction_interval_secs: self.eviction_interval_secs, + max_tree_size: self.max_tree_size, + }, + PolicyType::PowerOfTwo => { + return Err(pyo3::exceptions::PyValueError::new_err( + "PowerOfTwo policy is only supported in PD disaggregated mode", + )); + } + } }; // Create service discovery config if enabled @@ -166,6 +232,7 @@ impl Router { log_dir: self.log_dir.clone(), service_discovery_config, prometheus_config, + request_timeout_secs: self.request_timeout_secs, }) .await .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; diff --git a/sgl-router/src/openai_api_types.rs b/sgl-router/src/openai_api_types.rs new file mode 100644 index 000000000..808f8b46f --- /dev/null +++ b/sgl-router/src/openai_api_types.rs @@ -0,0 +1,704 @@ +// OpenAI-compatible API types for text generation +// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference +// Reference: Azure OpenAI API documentation which follows OpenAI's specification + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +/// Common trait for all generation requests +pub trait GenerationRequest: Send + Sync { + /// Check if the request is for streaming + fn is_stream(&self) -> bool; + + /// Get the model name if specified + fn get_model(&self) -> Option<&str>; + + /// Extract text content for routing decisions + fn extract_text_for_routing(&self) -> String; +} + +// ============= Completions API (v1/completions) - DEPRECATED but still supported ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionRequest { + /// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang) + pub model: String, + + /// The prompt(s) to generate completions for + pub prompt: StringOrArray, + + /// The suffix that comes after a completion of inserted text + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature (nucleus sampling) + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many completions to generate for each prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// Whether to stream back partial progress + #[serde(default)] + pub stream: bool, + + /// Include the log probabilities on the logprobs most likely tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + + /// Echo back the prompt in addition to the completion + #[serde(default)] + pub echo: bool, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: 
Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Generates best_of completions server-side and returns the "best" + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, +} + +impl GenerationRequest for CompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + match &self.prompt { + StringOrArray::String(s) => s.clone(), + StringOrArray::Array(v) => v.join(" "), + } + } +} + +// ============= Chat Completions API (v1/chat/completions) ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionRequest { + /// ID of the model to use + pub model: String, + + /// A list of messages comprising the conversation so far + pub messages: Vec, + + /// What sampling temperature to use, between 0 and 2 + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + + /// An alternative to sampling with temperature + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + + /// How many chat completion choices to generate for each input message + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + + /// If set, partial message deltas will be sent + #[serde(default)] + pub stream: bool, + + /// Up to 4 sequences where the API will stop generating further tokens + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + + /// The maximum number of tokens to generate + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + + /// An upper bound for the number of tokens that can be generated for a completion + #[serde(skip_serializing_if = "Option::is_none")] + pub max_completion_tokens: Option, + + /// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + + /// Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their existing frequency in the text so far + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + + /// Modify the likelihood of specified tokens appearing in the completion + #[serde(skip_serializing_if = "Option::is_none")] + pub logit_bias: Option>, + + /// A unique identifier representing your end-user + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + + /// If specified, our system will make a best effort to sample deterministically + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + /// Whether to return log probabilities of the output tokens + #[serde(default)] + pub logprobs: bool, + + /// An integer between 0 and 20 specifying the number of most likely tokens to return + #[serde(skip_serializing_if = "Option::is_none")] + pub top_logprobs: Option, + + /// An object specifying the format that the model must output + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + + /// A list of tools the model may call + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Controls which (if any) tool is called by the model + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + + /// Whether to enable parallel function calling during tool use + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + + /// Deprecated: use tools instead + #[serde(skip_serializing_if = "Option::is_none")] + pub functions: Option>, + + /// Deprecated: use tool_choice instead + #[serde(skip_serializing_if = "Option::is_none")] + pub function_call: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ChatMessage { + System { + role: String, // "system" + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + User { + role: String, // "user" + content: UserMessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + Assistant { + role: String, // "assistant" + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + function_call: Option, + }, + Tool { + role: String, // "tool" + content: String, + tool_call_id: String, + }, + Function { + role: String, // "function" + content: String, + name: String, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum UserMessageContent { + Text(String), + Parts(Vec), +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { image_url: ImageUrl }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, // "auto", "low", or "high" +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ResponseFormat { + #[serde(rename = "text")] + Text, + #[serde(rename = "json_object")] + JsonObject, + #[serde(rename = "json_schema")] + JsonSchema { json_schema: JsonSchemaFormat }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct JsonSchemaFormat { + pub name: String, + pub schema: Value, + #[serde(skip_serializing_if 
= "Option::is_none")] + pub strict: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Tool { + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: Function, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Function { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub parameters: Value, // JSON Schema +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ToolChoice { + None, + Auto, + Required, + Function { + #[serde(rename = "type")] + tool_type: String, // "function" + function: FunctionChoice, + }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionChoice { + pub name: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub tool_type: String, // "function" + pub function: FunctionCallResponse, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum FunctionCall { + None, + Auto, + Function { name: String }, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct FunctionCallResponse { + pub name: String, + pub arguments: String, // JSON string +} + +impl GenerationRequest for ChatCompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + Some(&self.model) + } + + fn extract_text_for_routing(&self) -> String { + // Extract text from messages for routing decisions + self.messages + .iter() + .filter_map(|msg| match msg { + ChatMessage::System { content, .. } => Some(content.clone()), + ChatMessage::User { content, .. } => match content { + UserMessageContent::Text(text) => Some(text.clone()), + UserMessageContent::Parts(parts) => { + let texts: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + .collect(); + Some(texts.join(" ")) + } + }, + ChatMessage::Assistant { content, .. } => content.clone(), + ChatMessage::Tool { content, .. } => Some(content.clone()), + ChatMessage::Function { content, .. 
} => Some(content.clone()), + }) + .collect::>() + .join(" ") + } +} + +// ============= Generate API (/generate) ============= + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GenerateRequest { + /// The prompt to generate from (OpenAI style) + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, + + /// Text input - SGLang native format + #[serde(skip_serializing_if = "Option::is_none")] + pub text: Option, + + /// Input IDs for tokenized input + #[serde(skip_serializing_if = "Option::is_none")] + pub input_ids: Option, + + /// Generation parameters + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parameters: Option, + + /// Sampling parameters (sglang style) + #[serde(skip_serializing_if = "Option::is_none")] + pub sampling_params: Option, + + /// Whether to stream the response + #[serde(default)] + pub stream: bool, + + /// Whether to return logprobs + #[serde(default)] + pub return_logprob: bool, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum InputIds { + Single(Vec), + Batch(Vec>), +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct GenerateParameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub best_of: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub decoder_input_details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub details: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub do_sample: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_new_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub return_full_text: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub truncate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub typical_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub watermark: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] +pub struct SamplingParams { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_new_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repetition_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub ignore_eos: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub skip_special_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, +} + +impl GenerationRequest for GenerateRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_model(&self) -> Option<&str> { + // Generate requests typically don't have a model field + None + } + + fn 
extract_text_for_routing(&self) -> String { + // Check fields in priority order: text, prompt, inputs + if let Some(ref text) = self.text { + return text.clone(); + } + + if let Some(ref prompt) = self.prompt { + return match prompt { + StringOrArray::String(s) => s.clone(), + StringOrArray::Array(v) => v.join(" "), + }; + } + + if let Some(ref input_ids) = self.input_ids { + return match input_ids { + InputIds::Single(ids) => ids + .iter() + .map(|&id| id.to_string()) + .collect::>() + .join(" "), + InputIds::Batch(batches) => batches + .iter() + .flat_map(|batch| batch.iter().map(|&id| id.to_string())) + .collect::>() + .join(" "), + }; + } + + // No text input found + String::new() + } +} + +// ============= Helper Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum StringOrArray { + String(String), + Array(Vec), +} + +// ============= Response Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionResponse { + pub id: String, + pub object: String, // "text_completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionChoice { + pub text: String, + pub index: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "content_filter", etc. +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct LogProbs { + pub tokens: Vec, + pub token_logprobs: Vec>, + pub top_logprobs: Vec>>, + pub text_offset: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatCompletionResponse { + pub id: String, + pub object: String, // "chat.completion" + pub created: u64, + pub model: String, + pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_fingerprint: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatChoice { + pub index: u32, + pub message: ChatMessage, + #[serde(skip_serializing_if = "Option::is_none")] + pub logprobs: Option, + pub finish_reason: Option, // "stop", "length", "tool_calls", "content_filter", "function_call" +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbs { + pub content: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ChatLogProbsContent { + pub token: String, + pub logprob: f32, + pub bytes: Option>, + pub top_logprobs: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct TopLogProb { + pub token: String, + pub logprob: f32, + pub bytes: Option>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Usage { + pub prompt_tokens: u32, + pub completion_tokens: u32, + pub total_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_tokens_details: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionTokensDetails { + pub reasoning_tokens: Option, +} + +// ============= Streaming Response Types ============= + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct CompletionStreamResponse { + pub id: String, + pub object: String, // "text_completion" + pub created: u64, + pub choices: Vec, + pub model: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub 
system_fingerprint: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct CompletionStreamChoice {
+    pub text: String,
+    pub index: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<LogProbs>,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ChatCompletionStreamResponse {
+    pub id: String,
+    pub object: String, // "chat.completion.chunk"
+    pub created: u64,
+    pub model: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_fingerprint: Option<String>,
+    pub choices: Vec<ChatStreamChoice>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub usage: Option<Usage>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ChatStreamChoice {
+    pub index: u32,
+    pub delta: ChatMessageDelta,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub logprobs: Option<ChatLogProbs>,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ChatMessageDelta {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub role: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCallDelta>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function_call: Option<FunctionCallDelta>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ToolCallDelta {
+    pub index: u32,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    #[serde(rename = "type")]
+    pub tool_type: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub function: Option<FunctionCallDelta>,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct FunctionCallDelta {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub arguments: Option<String>,
+}
+
+// ============= Error Response Types =============
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ErrorResponse {
+    pub error: ErrorDetail,
+}
+
+#[derive(Debug, Clone, Deserialize, Serialize)]
+pub struct ErrorDetail {
+    pub message: String,
+    #[serde(rename = "type")]
+    pub error_type: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub param: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub code: Option<String>,
+}
diff --git a/sgl-router/src/pd_router.rs b/sgl-router/src/pd_router.rs
new file mode 100644
index 000000000..e06fa371a
--- /dev/null
+++ b/sgl-router/src/pd_router.rs
@@ -0,0 +1,1002 @@
+// PD (Prefill-Decode) Router Implementation
+// This module handles routing for disaggregated prefill-decode systems
+
+use crate::pd_types::{Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDSelectionPolicy};
+use crate::tree::Tree;
+use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
+use actix_web::{HttpRequest, HttpResponse};
+use futures_util::{StreamExt, TryStreamExt};
+use metrics::{counter, histogram};
+use serde_json::Value;
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::{Arc, Mutex, RwLock};
+use std::time::{Duration, Instant};
+use tracing::{debug, error, info, warn};
+use uuid::Uuid;
+
+// Removed over-engineered ProxyResponse - using HttpResponse directly
+
+#[derive(Debug)]
+pub struct PDRouter {
+    pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
+    pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
+    pub selection_policy: PDSelectionPolicy,
+    pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
+    pub prefill_tree: Option<Arc<Mutex<Tree>>>,
+    pub timeout_secs: u64,
+    pub interval_secs: u64,
+    pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
+    pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
+    pub http_client: reqwest::Client,
+}
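// Illustration, not patch content: each entry in `load_tracking` above is a
// plain atomic counter bumped when a request is dispatched and decremented
// when its guard drops (see LoadGuard below). A minimal sketch of that
// lifecycle, using only std:
//
//     use std::sync::atomic::{AtomicUsize, Ordering};
//     use std::sync::Arc;
//
//     fn main() {
//         let load = Arc::new(AtomicUsize::new(0));
//         load.fetch_add(1, Ordering::Relaxed);        // request dispatched
//         assert_eq!(load.load(Ordering::Relaxed), 1); // one in-flight request
//         load.fetch_sub(1, Ordering::Relaxed);        // guard dropped on completion
//         assert_eq!(load.load(Ordering::Relaxed), 0);
//     }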
+
+// RAII guard for load tracking to ensure cleanup even on panic
+struct LoadGuard<'a> {
+    tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
+    urls: Vec<String>,
+}
+
+impl<'a> LoadGuard<'a> {
+    fn new(
+        tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
+        urls: Vec<String>,
+    ) -> Self {
+        // Increment counters
+        for url in &urls {
+            let counter = tracking
+                .entry(url.clone())
+                .or_insert_with(|| Arc::new(AtomicUsize::new(0)));
+            counter.fetch_add(1, Ordering::Relaxed);
+        }
+        LoadGuard { tracking, urls }
+    }
+}
+
+impl Drop for LoadGuard<'_> {
+    fn drop(&mut self) {
+        // Guaranteed cleanup even on panic
+        for url in &self.urls {
+            if let Some(counter) = self.tracking.get(url) {
+                counter.fetch_sub(1, Ordering::Relaxed);
+            }
+        }
+    }
+}
+
+impl PDRouter {
+    // TODO: Add methods for dynamic worker management to support /register endpoint:
+    // - add_prefill_server(url: String, bootstrap_port: Option<u16>)
+    // - add_decode_server(url: String)
+    // - remove_prefill_server(url: &str)
+    // - remove_decode_server(url: &str)
+    // These methods will be used when service discovery is implemented for PD mode
+
+    pub fn new(
+        prefill_urls: Vec<(String, Option<u16>)>,
+        decode_urls: Vec<String>,
+        selection_policy: PDSelectionPolicy,
+        timeout_secs: u64,
+        interval_secs: u64,
+    ) -> Result<Self, String> {
+        // Convert URLs to EngineInfo
+        let prefill_workers: Vec<EngineInfo> = prefill_urls
+            .into_iter()
+            .map(|(url, port)| EngineInfo::new_prefill(url, port))
+            .collect();
+
+        let decode_workers: Vec<EngineInfo> = decode_urls
+            .into_iter()
+            .map(EngineInfo::new_decode)
+            .collect();
+
+        // Wait for PD workers to be healthy
+        let all_urls: Vec<String> = prefill_workers
+            .iter()
+            .chain(decode_workers.iter())
+            .map(|engine| engine.url.clone())
+            .collect();
+        crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
+
+        // Initialize load tracking with atomic counters
+        let load_tracking = Arc::new(dashmap::DashMap::new());
+        for engine in &prefill_workers {
+            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
+        }
+        for engine in &decode_workers {
+            load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
+        }
+
+        // Initialize cache-aware components if needed
+        let prefill_tree = match &selection_policy {
+            PDSelectionPolicy::CacheAware { ..
} => { + let tree = Arc::new(Mutex::new(Tree::new())); + // Initialize tree with prefill workers + for engine in &prefill_workers { + tree.lock().unwrap().insert("", &engine.url); + } + Some(tree) + } + _ => None, + }; + + // Set up background load monitoring for power-of-two selection + let (tx, rx) = tokio::sync::watch::channel(HashMap::new()); + let worker_loads = Arc::new(rx); + + // Create a shared HTTP client for all operations + let http_client = reqwest::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; + + let load_monitor_handle = if matches!(selection_policy, PDSelectionPolicy::PowerOfTwo) { + let monitor_urls = all_urls.clone(); + let monitor_interval = interval_secs; + let monitor_client = http_client.clone(); + + Some(Arc::new(tokio::spawn(async move { + Self::monitor_worker_loads_with_client( + monitor_urls, + tx, + monitor_interval, + monitor_client, + ) + .await; + }))) + } else { + None + }; + + Ok(PDRouter { + prefill_workers: Arc::new(RwLock::new(prefill_workers)), + decode_workers: Arc::new(RwLock::new(decode_workers)), + selection_policy, + load_tracking, + prefill_tree, + timeout_secs, + interval_secs, + worker_loads, + load_monitor_handle, + http_client, + }) + } + + // Route a typed generate request + pub async fn route_generate( + &self, + client: &reqwest::Client, + req: &HttpRequest, + mut typed_req: GenerateReqInput, + route: &str, + ) -> HttpResponse { + let start = Instant::now(); + let _request_id = Uuid::new_v4(); + + // Get stream flag and return_logprob flag before moving the request + let is_stream = typed_req.is_stream(); + let return_logprob = typed_req + .other + .get("return_logprob") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + // Select servers + let (prefill, decode) = match self.select_pd_pair(client).await { + Ok(pair) => pair, + Err(e) => { + error!("Failed to select PD pair: {}", e); + counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1); + return HttpResponse::ServiceUnavailable() + .body(format!("No available servers: {}", e)); + } + }; + + // Log routing decision + info!( + "PD routing: {} -> prefill={}, decode={}", + route, prefill.url, decode.url + ); + + // Add bootstrap info using the trait method + if let Err(e) = typed_req.add_bootstrap_info(&prefill) { + error!("Failed to add bootstrap info: {}", e); + counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1); + return HttpResponse::InternalServerError() + .body(format!("Bootstrap injection failed: {}", e)); + } + + // Convert to JSON after bootstrap injection + let json_with_bootstrap = match serde_json::to_value(&typed_req) { + Ok(json) => json, + Err(e) => { + error!("Failed to serialize request: {}", e); + return HttpResponse::InternalServerError().body("Failed to serialize request"); + } + }; + + // Execute dual dispatch + self.execute_dual_dispatch( + client, + req, + json_with_bootstrap, + route, + &prefill, + &decode, + is_stream, + return_logprob, + start, + ) + .await + } + + // Route a typed chat request + pub async fn route_chat( + &self, + client: &reqwest::Client, + req: &HttpRequest, + mut typed_req: ChatReqInput, + route: &str, + ) -> HttpResponse { + let start = Instant::now(); + + // Get stream flag and return_logprob flag before moving the request + let is_stream = typed_req.is_stream(); + let return_logprob = typed_req + .other + .get("return_logprob") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + // 
Select servers + let (prefill, decode) = match self.select_pd_pair(client).await { + Ok(pair) => pair, + Err(e) => { + error!("Failed to select PD pair: {}", e); + counter!("sgl_router_pd_errors_total", "error" => "server_selection").increment(1); + return HttpResponse::ServiceUnavailable() + .body(format!("No available servers: {}", e)); + } + }; + + // Log routing decision + info!( + "PD routing: {} -> prefill={}, decode={}", + route, prefill.url, decode.url + ); + + // Add bootstrap info using the trait method + if let Err(e) = typed_req.add_bootstrap_info(&prefill) { + error!("Failed to add bootstrap info: {}", e); + counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1); + return HttpResponse::InternalServerError() + .body(format!("Bootstrap injection failed: {}", e)); + } + + // Convert to JSON after bootstrap injection + let json_with_bootstrap = match serde_json::to_value(&typed_req) { + Ok(json) => json, + Err(e) => { + error!("Failed to serialize request: {}", e); + return HttpResponse::InternalServerError().body("Failed to serialize request"); + } + }; + + // Execute dual dispatch + self.execute_dual_dispatch( + client, + req, + json_with_bootstrap, + route, + &prefill, + &decode, + is_stream, + return_logprob, + start, + ) + .await + } + + // Execute the dual dispatch to prefill and decode servers + #[allow(clippy::too_many_arguments)] + async fn execute_dual_dispatch( + &self, + client: &reqwest::Client, + req: &HttpRequest, + json_request: serde_json::Value, + route: &str, + prefill: &EngineInfo, + decode: &EngineInfo, + is_stream: bool, + return_logprob: bool, + start_time: Instant, + ) -> HttpResponse { + // Update load tracking for both workers + let _guard = LoadGuard::new( + &self.load_tracking, + vec![prefill.url.clone(), decode.url.clone()], + ); + + // Build requests using .json() method + let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request); + + let mut decode_request = client.post(decode.api_path(route)).json(&json_request); + + // Copy headers from original request + for (name, value) in crate::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { + prefill_request = prefill_request.header(&name, &value); + decode_request = decode_request.header(&name, &value); + } + } + + // Send both requests concurrently + let (prefill_result, decode_result) = + tokio::join!(prefill_request.send(), decode_request.send()); + + // Update metrics + let duration = start_time.elapsed(); + histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1); + counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string()) + .increment(1); + counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string()) + .increment(1); + + // Process decode response + match decode_result { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + + if !status.is_success() { + counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1); + error!( + "Decode server {} returned error status: {}", + decode.url, status + ); + + // Return the error response from decode server + match res.bytes().await { + Ok(error_body) => { + return 
HttpResponse::build(status).body(error_body.to_vec()); + } + Err(e) => { + return HttpResponse::build(status) + .body(format!("Decode server error: {}", e)); + } + } + } + + // Log prefill errors for debugging + if let Err(e) = &prefill_result { + error!( + "Prefill server {} failed (non-critical): {}", + prefill.url, e + ); + counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1); + } + + if is_stream { + // Streaming response + if return_logprob { + // Get prefill logprobs for merging + let prefill_logprobs = + match prefill_result { + Ok(prefill_res) => match prefill_res.bytes().await { + Ok(body) => serde_json::from_slice::(&body) + .ok() + .and_then(|json| { + json.pointer("/meta_info/input_token_logprobs").cloned() + }), + Err(_) => None, + }, + Err(_) => None, + }; + + // Stream with logprob merging + HttpResponse::build(status) + .insert_header(( + CONTENT_TYPE, + HeaderValue::from_static("text/event-stream"), + )) + .streaming(res.bytes_stream().map(move |chunk_result| { + match chunk_result { + Ok(chunk) => { + // Try to merge logprobs + if let Ok(merged) = Self::merge_streaming_logprobs( + prefill_logprobs.clone(), + &chunk, + ) { + Ok(merged) + } else { + Ok(chunk) + } + } + Err(e) => Err(actix_web::error::ErrorInternalServerError( + format!("Stream error: {}", e), + )), + } + })) + } else { + // No logprob merging needed + HttpResponse::build(status) + .insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream"))) + .streaming({ + let decode_url = decode.url.clone(); + res.bytes_stream().map_err(move |e| { + error!("Stream error from decode server {}: {}", decode_url, e); + counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1); + actix_web::error::ErrorInternalServerError(format!("Stream error: {}", e)) + }) + }) + } + } else { + // Non-streaming response + match res.bytes().await { + Ok(decode_body) => { + if return_logprob { + self.merge_logprobs(prefill_result, decode_body, status) + .await + } else { + HttpResponse::build(status).body(decode_body.to_vec()) + } + } + Err(e) => { + error!("Failed to read decode response: {}", e); + HttpResponse::InternalServerError().body("Failed to read response") + } + } + } + } + Err(e) => { + error!("Decode request failed: {}", e); + counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()) + .increment(1); + HttpResponse::BadGateway().body(format!("Decode server error: {}", e)) + } + } + } + + // Merge logprobs from prefill and decode responses + async fn merge_logprobs( + &self, + prefill_result: Result, + decode_body: bytes::Bytes, + status: actix_web::http::StatusCode, + ) -> HttpResponse { + match prefill_result { + Ok(prefill_res) => { + match prefill_res.bytes().await { + Ok(prefill_body) => { + match ( + serde_json::from_slice::(&prefill_body), + serde_json::from_slice::(&decode_body), + ) { + (Ok(prefill_json), Ok(mut decode_json)) => { + // Merge input_token_logprobs + if let (Some(prefill_meta), Some(decode_meta)) = ( + prefill_json.get("meta_info"), + decode_json.get_mut("meta_info"), + ) { + if let (Some(prefill_logprobs), Some(decode_logprobs)) = ( + prefill_meta.get("input_token_logprobs"), + decode_meta.get_mut("input_token_logprobs"), + ) { + if let (Some(p_arr), Some(d_arr)) = ( + prefill_logprobs.as_array(), + decode_logprobs.as_array(), + ) { + let mut merged = p_arr.clone(); + merged.extend(d_arr.clone()); + decode_meta["input_token_logprobs"] = + Value::Array(merged); + } + } + } + 
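// Illustration with hypothetical values: if the prefill response carried
// meta_info.input_token_logprobs = [[-0.1, 5, null]] and the decode response
// carried [[-0.4, 9, null]], the decode JSON returned below now holds the
// concatenation [[-0.1, 5, null], [-0.4, 9, null]], covering the full prompt.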
HttpResponse::build(status).json(&decode_json) + } + _ => { + warn!("Failed to parse responses for logprob merging"); + HttpResponse::build(status).body(decode_body.to_vec()) + } + } + } + Err(e) => { + warn!("Failed to read prefill response: {}", e); + HttpResponse::build(status).body(decode_body.to_vec()) + } + } + } + Err(_) => HttpResponse::build(status).body(decode_body.to_vec()), + } + } + + // Select a pair of prefill and decode servers + async fn select_pd_pair( + &self, + _client: &reqwest::Client, + ) -> Result<(EngineInfo, EngineInfo), String> { + // Check we have workers + if self + .prefill_workers + .read() + .map_err(|e| format!("Failed to acquire prefill workers lock: {}", e))? + .is_empty() + { + return Err("No prefill workers available. Please check if prefill servers are configured and healthy.".to_string()); + } + if self + .decode_workers + .read() + .map_err(|e| format!("Failed to acquire decode workers lock: {}", e))? + .is_empty() + { + return Err("No decode workers available. Please check if decode servers are configured and healthy.".to_string()); + } + + match &self.selection_policy { + PDSelectionPolicy::Random => self.select_random(), + PDSelectionPolicy::PowerOfTwo => self.select_power_of_two().await, + PDSelectionPolicy::CacheAware { .. } => { + // TODO: Implement cache-aware selection + self.select_power_of_two().await + } + } + } + + fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> { + let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; + let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; + + let prefill = prefill_list[rand::random::() % prefill_list.len()].clone(); + let decode = decode_list[rand::random::() % decode_list.len()].clone(); + + Ok((prefill, decode)) + } + + async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> { + let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?; + let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?; + + let (p1_idx, p2_idx) = get_two_random_indices(prefill_list.len()); + let (d1_idx, d2_idx) = get_two_random_indices(decode_list.len()); + + let loads = self.worker_loads.borrow(); + + let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0); + let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0); + let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0); + let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0); + + info!( + "Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}", + prefill_list[p1_idx].url, + p1_load, + prefill_list[p2_idx].url, + p2_load, + decode_list[d1_idx].url, + d1_load, + decode_list[d2_idx].url, + d2_load + ); + + let selected_prefill = if p1_load <= p2_load { + prefill_list[p1_idx].clone() + } else { + prefill_list[p2_idx].clone() + }; + + let selected_decode = if d1_load <= d2_load { + decode_list[d1_idx].clone() + } else { + decode_list[d2_idx].clone() + }; + + Ok((selected_prefill, selected_decode)) + } + + // Background task to monitor worker loads with shared client + async fn monitor_worker_loads_with_client( + worker_urls: Vec, + tx: tokio::sync::watch::Sender>, + interval_secs: u64, + client: reqwest::Client, + ) { + loop { + let mut loads = HashMap::new(); + + let futures: Vec<_> = worker_urls + .iter() + .map(|url| { + let client = client.clone(); + let url = url.clone(); + async move { + let load = get_worker_load(&client, &url).await.unwrap_or(0); + (url, load) + } + }) 
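// Note: the probe futures built above are awaited together via join_all
// below, so every /get_load request runs concurrently; one slow worker
// delays at most a single monitoring tick, not each probe in turn.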
+                .collect();
+
+            let results = futures_util::future::join_all(futures).await;
+
+            for (url, load) in results {
+                loads.insert(url, load);
+            }
+
+            debug!("Worker loads updated: {:?}", loads);
+
+            // Check if receiver is still active
+            if tx.send(loads).is_err() {
+                info!("Load monitor receiver dropped, shutting down monitor task");
+                break;
+            }
+
+            tokio::time::sleep(Duration::from_secs(interval_secs)).await;
+        }
+    }
+
+    // Simple helper to merge logprobs in streaming responses
+    fn merge_streaming_logprobs(
+        prefill_logprobs: Option<Value>,
+        decode_chunk: &[u8],
+    ) -> Result<bytes::Bytes, ()> {
+        // Skip non-data chunks
+        let chunk_str = std::str::from_utf8(decode_chunk).map_err(|_| ())?;
+        if !chunk_str.starts_with("data: ") || chunk_str.contains("[DONE]") {
+            return Err(());
+        }
+
+        // Parse JSON from chunk
+        let json_str = chunk_str.trim_start_matches("data: ").trim();
+        let mut decode_json: Value = serde_json::from_str(json_str).map_err(|_| ())?;
+
+        // Merge prefill logprobs if available
+        if let Some(ref p_logprobs) = prefill_logprobs {
+            if let Some(meta) = decode_json.get_mut("meta_info") {
+                if let Some(d_logprobs) = meta.get_mut("input_token_logprobs") {
+                    if let (Some(p_arr), Some(d_arr)) =
+                        (p_logprobs.as_array(), d_logprobs.as_array())
+                    {
+                        let mut merged = p_arr.clone();
+                        merged.extend(d_arr.clone());
+                        *d_logprobs = Value::Array(merged);
+                    }
+                }
+            }
+        }
+
+        // Re-serialize
+        let merged_str = format!(
+            "data: {}\n\n",
+            serde_json::to_string(&decode_json).unwrap_or_default()
+        );
+        Ok(bytes::Bytes::from(merged_str))
+    }
+}
+
+// Helper functions
+fn get_two_random_indices(len: usize) -> (usize, usize) {
+    if len == 1 {
+        (0, 0)
+    } else {
+        let idx1 = rand::random::<usize>() % len;
+        let mut idx2 = rand::random::<usize>() % len;
+        while idx2 == idx1 {
+            idx2 = rand::random::<usize>() % len;
+        }
+        (idx1, idx2)
+    }
+}
+
+async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
+    match client.get(format!("{}/get_load", worker_url)).send().await {
+        Ok(res) if res.status().is_success() => match res.bytes().await {
+            Ok(bytes) => match serde_json::from_slice::<Value>(&bytes) {
+                Ok(data) => data
+                    .get("load")
+                    .and_then(|v| v.as_i64())
+                    .map(|v| v as isize),
+                Err(e) => {
+                    debug!("Failed to parse load response from {}: {}", worker_url, e);
+                    None
+                }
+            },
+            Err(e) => {
+                debug!("Failed to read load response from {}: {}", worker_url, e);
+                None
+            }
+        },
+        Ok(res) => {
+            debug!(
+                "Worker {} returned non-success status: {}",
+                worker_url,
+                res.status()
+            );
+            None
+        }
+        Err(e) => {
+            debug!("Failed to get load from {}: {}", worker_url, e);
+            None
+        }
+    }
+}
+
+// PD-specific endpoints
+impl PDRouter {
+    pub async fn health_generate(&self, client: &reqwest::Client) -> HttpResponse {
+        let mut all_healthy = true;
+        let mut unhealthy_servers = Vec::new();
+
+        // Collect all worker URLs with their types
+        let mut worker_infos = Vec::new();
+
+        for worker in self.prefill_workers.read().unwrap().iter() {
+            worker_infos.push((worker.url.clone(), "prefill"));
+        }
+
+        for worker in self.decode_workers.read().unwrap().iter() {
+            worker_infos.push((worker.url.clone(), "decode"));
+        }
+
+        // Create tasks with URL tracking
+        let tasks: Vec<_> = worker_infos
+            .iter()
+            .map(|(url, _)| {
+                let health_url = format!("{}/health_generate", url);
+                client.get(&health_url).send()
+            })
+            .collect();
+
+        let results = futures_util::future::join_all(tasks).await;
+
+        for ((url, worker_type), result) in worker_infos.iter().zip(results.into_iter()) {
+            match result {
+                Ok(res) if res.status().is_success() => {
debug!("Health check passed for {} server: {}", worker_type, url); + } + Ok(res) => { + all_healthy = false; + let msg = format!( + "{} server {} returned status {}", + worker_type, + url, + res.status() + ); + error!("{}", msg); + unhealthy_servers.push(msg); + } + Err(e) => { + all_healthy = false; + let msg = format!("{} server {} error: {}", worker_type, url, e); + error!("{}", msg); + unhealthy_servers.push(msg); + } + } + } + + if all_healthy { + HttpResponse::Ok().body("Health check passed on all servers") + } else { + HttpResponse::ServiceUnavailable() + .body(format!("Health check failed: {:?}", unhealthy_servers)) + } + } + + pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse { + // Get info from all decode servers (where generation happens) + let mut all_internal_states = Vec::new(); + let mut decode_infos = Vec::new(); + + // Clone URLs to avoid holding lock across await + let worker_urls: Vec = self + .decode_workers + .read() + .unwrap() + .iter() + .map(|w| w.url.clone()) + .collect(); + + for worker_url in worker_urls { + match client + .get(format!("{}/get_server_info", worker_url)) + .send() + .await + { + Ok(res) if res.status().is_success() => { + match res.json::().await { + Ok(info) => { + // Extract internal_states from each decode server + if let Some(states) = info.get("internal_states") { + if let Some(states_array) = states.as_array() { + all_internal_states.extend(states_array.clone()); + } + } + decode_infos.push(info); + } + Err(e) => error!("Failed to parse server info: {}", e), + } + } + _ => {} + } + } + + // If we have internal states, return in the format expected by bench_one_batch_server.py + if !all_internal_states.is_empty() { + // Use the first decode server's internal state (they should all be similar) + HttpResponse::Ok().json(serde_json::json!({ + "internal_states": all_internal_states, + // Include original format for compatibility + "decode_servers": decode_infos, + })) + } else { + // Fallback: create a dummy internal_states entry + HttpResponse::Ok().json(serde_json::json!({ + "internal_states": [{ + "last_gen_throughput": 0.0, + "avg_spec_accept_length": null, + }], + "decode_servers": decode_infos, + })) + } + } + + pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse { + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url.clone()) + } else { + return HttpResponse::InternalServerError().body("Failed to access prefill workers"); + }; + + if let Some(worker_url) = first_worker_url { + // Send request directly without going through Router + let mut request_builder = client.get(format!("{}/v1/models", worker_url)); + for (name, value) in crate::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to send request: {}", e)), + } + } else { + 
HttpResponse::ServiceUnavailable().body("No prefill servers available") + } + } + + pub async fn get_loads(&self, client: &reqwest::Client) -> HttpResponse { + let p_urls: Vec<_> = self + .prefill_workers + .read() + .unwrap() + .iter() + .map(|w| w.url.clone()) + .collect(); + let d_urls: Vec<_> = self + .decode_workers + .read() + .unwrap() + .iter() + .map(|w| w.url.clone()) + .collect(); + + let mut prefill_loads = Vec::new(); + let mut decode_loads = Vec::new(); + + for url in &p_urls { + let load = get_worker_load(client, url).await.unwrap_or(-1); + prefill_loads.push(serde_json::json!({ + "engine": format!("(Prefill@{})", url), + "load": load as i64 + })); + } + + for url in &d_urls { + let load = get_worker_load(client, url).await.unwrap_or(-1); + decode_loads.push(serde_json::json!({ + "engine": format!("(Decode@{})", url), + "load": load as i64 + })); + } + + HttpResponse::Ok().json(serde_json::json!({ + "prefill": prefill_loads, + "decode": decode_loads + })) + } + + pub async fn get_model_info( + &self, + client: &reqwest::Client, + req: &HttpRequest, + ) -> HttpResponse { + // Get model info from the first prefill server (matches original Rust PDLB behavior) + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url.clone()) + } else { + return HttpResponse::InternalServerError().body("Failed to access prefill workers"); + }; + + if let Some(worker_url) = first_worker_url { + let mut request_builder = client.get(format!("{}/get_model_info", worker_url)); + for (name, value) in crate::router::copy_request_headers(req) { + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } + } + match request_builder.send().await { + Ok(res) => { + let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); + match res.bytes().await { + Ok(body) => HttpResponse::build(status).body(body.to_vec()), + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to read response body: {}", e)), + } + } + Err(e) => HttpResponse::InternalServerError() + .body(format!("Failed to send request: {}", e)), + } + } else { + HttpResponse::ServiceUnavailable().body("No prefill servers available") + } + } + + pub async fn flush_cache(&self, client: &reqwest::Client) -> HttpResponse { + let mut tasks = Vec::new(); + + // Flush cache on all prefill servers + for worker in self.prefill_workers.read().unwrap().iter() { + let url = format!("{}/flush_cache", worker.url); + tasks.push(client.post(&url).send()); + } + + // Flush cache on all decode servers + for worker in self.decode_workers.read().unwrap().iter() { + let url = format!("{}/flush_cache", worker.url); + tasks.push(client.post(&url).send()); + } + + let results = futures_util::future::join_all(tasks).await; + + let mut all_success = true; + for (i, result) in results.into_iter().enumerate() { + match result { + Ok(res) if res.status().is_success() => {} + Ok(res) => { + all_success = false; + warn!( + "Server {} returned status {} for flush_cache", + i, + res.status() + ); + } + Err(e) => { + all_success = false; + error!("Server {} error during flush_cache: {}", i, e); + } + } + } + + if all_success { + HttpResponse::Ok().body("Cache flushed on all servers") + } else { + HttpResponse::InternalServerError().body("Cache flush failed on one or more 
servers")
+        }
+    }
+}
diff --git a/sgl-router/src/pd_types.rs b/sgl-router/src/pd_types.rs
new file mode 100644
index 000000000..98b104386
--- /dev/null
+++ b/sgl-router/src/pd_types.rs
@@ -0,0 +1,245 @@
+// Essential PDLB types extracted for PD routing
+
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+
+#[derive(Debug, Clone)]
+pub enum EngineType {
+    Prefill,
+    Decode,
+}
+
+#[derive(Debug, Clone)]
+pub struct EngineInfo {
+    pub engine_type: EngineType,
+    pub url: String,
+    pub bootstrap_port: Option<u16>,
+}
+
+impl EngineInfo {
+    pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
+        EngineInfo {
+            engine_type: EngineType::Prefill,
+            url,
+            bootstrap_port,
+        }
+    }
+
+    pub fn new_decode(url: String) -> Self {
+        EngineInfo {
+            engine_type: EngineType::Decode,
+            url,
+            bootstrap_port: None,
+        }
+    }
+
+    pub fn api_path(&self, api_path: &str) -> String {
+        if api_path.starts_with("/") {
+            format!("{}{}", self.url, api_path)
+        } else {
+            format!("{}/{}", self.url, api_path)
+        }
+    }
+
+    pub fn get_hostname(&self) -> String {
+        // Simple hostname extraction without external dependencies
+        let url = self
+            .url
+            .trim_start_matches("http://")
+            .trim_start_matches("https://");
+        url.split(':').next().unwrap_or("localhost").to_string()
+    }
+}
+
+// PD-specific routing policies
+#[derive(Debug, Clone, PartialEq)]
+pub enum PDSelectionPolicy {
+    Random,
+    PowerOfTwo,
+    CacheAware {
+        cache_threshold: f32,
+        balance_abs_threshold: usize,
+        balance_rel_threshold: f32,
+    },
+}
+// Bootstrap types from PDLB
+#[derive(Debug, Deserialize, Serialize)]
+#[serde(untagged)]
+pub enum SingleOrBatch<T> {
+    Single(T),
+    Batch(Vec<T>),
+}
+
+pub type InputIds = SingleOrBatch<Vec<i32>>;
+pub type InputText = SingleOrBatch<String>;
+pub type BootstrapHost = SingleOrBatch<String>;
+pub type BootstrapPort = SingleOrBatch<Option<u16>>;
+pub type BootstrapRoom = SingleOrBatch<u64>;
+
+// Bootstrap trait for request handling
+pub trait Bootstrap: Send + Sync {
+    fn is_stream(&self) -> bool;
+    fn get_batch_size(&self) -> Result<Option<usize>, String>;
+    fn set_bootstrap_info(
+        &mut self,
+        bootstrap_host: BootstrapHost,
+        bootstrap_port: BootstrapPort,
+        bootstrap_room: BootstrapRoom,
+    );
+
+    fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
+        let batch_size = self.get_batch_size()?;
+        if let Some(batch_size) = batch_size {
+            self.set_bootstrap_info(
+                BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
+                BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
+                // Use high-quality random numbers to minimize collision risk
+                BootstrapRoom::Batch(
+                    (0..batch_size)
+                        .map(|_| {
+                            // Combine multiple sources of randomness for better distribution
+                            let r1 = rand::random::<u64>();
+                            let r2 = rand::random::<u64>();
+                            r1.wrapping_add(r2.rotate_left(32))
+                        })
+                        .collect(),
+                ),
+            );
+        } else {
+            self.set_bootstrap_info(
+                BootstrapHost::Single(prefill_info.get_hostname()),
+                BootstrapPort::Single(prefill_info.bootstrap_port),
+                BootstrapRoom::Single({
+                    // Use high-quality random number for single requests too
+                    let r1 = rand::random::<u64>();
+                    let r2 = rand::random::<u64>();
+                    r1.wrapping_add(r2.rotate_left(32))
+                }),
+            );
+        }
+        Ok(())
+    }
+}
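A minimal usage sketch of the `Bootstrap` trait above (editor's illustration, not part of the patch): once a prefill engine has been selected, the router fills the bootstrap fields on the request via `add_bootstrap_info`. The request payload and URL below are invented for the example.

    // Illustrative sketch only; payload and URL are made up.
    use sglang_router_rs::pd_types::{Bootstrap, EngineInfo, GenerateReqInput};

    fn attach_bootstrap_example() -> Result<(), String> {
        let prefill = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
        let mut req: GenerateReqInput =
            serde_json::from_str(r#"{"text": "Hello", "stream": false}"#)
                .map_err(|e| e.to_string())?;

        // Fills bootstrap_host / bootstrap_port / bootstrap_room, using batch-shaped
        // values when get_batch_size() reports a batch and single values otherwise.
        req.add_bootstrap_info(&prefill)?;
        assert!(req.bootstrap_host.is_some() && req.bootstrap_room.is_some());
        Ok(())
    }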
+
+// Request types
+#[derive(Debug, Deserialize, Serialize)]
+pub struct GenerateReqInput {
+    pub text: Option<InputText>,
+    pub input_ids: Option<InputIds>,
+    #[serde(default)]
+    pub stream: bool,
+    pub bootstrap_host: Option<BootstrapHost>,
+    pub bootstrap_port: Option<BootstrapPort>,
+    pub bootstrap_room: Option<BootstrapRoom>,
+
+    #[serde(flatten)]
+    pub other: Value,
+}
+
+impl GenerateReqInput {
+    pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
+        if self.text.is_some() && self.input_ids.is_some() {
+            return Err("Both text and input_ids are present in the request".to_string());
+        }
+
+        // Check text batch
+        if let Some(InputText::Batch(texts)) = &self.text {
+            if texts.is_empty() {
+                return Err("Batch text array is empty".to_string());
+            }
+            if texts.len() > 10000 {
+                // Reasonable limit for production
+                return Err(format!(
+                    "Batch size {} exceeds maximum allowed (10000)",
+                    texts.len()
+                ));
+            }
+            return Ok(Some(texts.len()));
+        }
+
+        // Check input_ids batch
+        if let Some(InputIds::Batch(ids)) = &self.input_ids {
+            if ids.is_empty() {
+                return Err("Batch input_ids array is empty".to_string());
+            }
+            if ids.len() > 10000 {
+                // Reasonable limit for production
+                return Err(format!(
+                    "Batch size {} exceeds maximum allowed (10000)",
+                    ids.len()
+                ));
+            }
+            // Validate each sequence is not empty
+            for (i, seq) in ids.iter().enumerate() {
+                if seq.is_empty() {
+                    return Err(format!("Input sequence at index {} is empty", i));
+                }
+            }
+            return Ok(Some(ids.len()));
+        }
+
+        Ok(None)
+    }
+}
+
+impl Bootstrap for GenerateReqInput {
+    fn is_stream(&self) -> bool {
+        self.stream
+    }
+
+    fn get_batch_size(&self) -> Result<Option<usize>, String> {
+        self.get_batch_size()
+    }
+
+    fn set_bootstrap_info(
+        &mut self,
+        bootstrap_host: BootstrapHost,
+        bootstrap_port: BootstrapPort,
+        bootstrap_room: BootstrapRoom,
+    ) {
+        self.bootstrap_host = Some(bootstrap_host);
+        self.bootstrap_port = Some(bootstrap_port);
+        self.bootstrap_room = Some(bootstrap_room);
+    }
+}
+
+#[derive(Debug, Deserialize, Serialize)]
+pub struct ChatReqInput {
+    #[serde(default)]
+    pub stream: bool,
+    pub bootstrap_host: Option<BootstrapHost>,
+    pub bootstrap_port: Option<BootstrapPort>,
+    pub bootstrap_room: Option<BootstrapRoom>,
+
+    #[serde(flatten)]
+    pub other: Value,
+}
+
+impl Bootstrap for ChatReqInput {
+    fn is_stream(&self) -> bool {
+        self.stream
+    }
+
+    fn get_batch_size(&self) -> Result<Option<usize>, String> {
+        // Check if 'n' parameter is present and > 1
+        if let Some(n_value) = self.other.get("n") {
+            if let Some(n) = n_value.as_u64() {
+                if n > 1 {
+                    return Ok(Some(n as usize));
+                }
+            }
+        }
+        Ok(None)
+    }
+
+    fn set_bootstrap_info(
+        &mut self,
+        bootstrap_host: BootstrapHost,
+        bootstrap_port: BootstrapPort,
+        bootstrap_room: BootstrapRoom,
+    ) {
+        self.bootstrap_host = Some(bootstrap_host);
+        self.bootstrap_port = Some(bootstrap_port);
+        self.bootstrap_room = Some(bootstrap_room);
+    }
+}
diff --git a/sgl-router/src/request_adapter.rs b/sgl-router/src/request_adapter.rs
new file mode 100644
index 000000000..4396cc4d7
--- /dev/null
+++ b/sgl-router/src/request_adapter.rs
@@ -0,0 +1,264 @@
+// Request adapter to bridge OpenAI API types with PD routing requirements
+
+use crate::openai_api_types::{
+    ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
+};
+use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
+use serde_json::Value;
+
+/// Adapter trait to convert OpenAI requests to PD-compatible requests
+pub trait ToPdRequest {
+    type Output: Bootstrap;
+    fn to_pd_request(self) -> Self::Output;
+}
+
+// Helper macro to insert optional fields into a map
+macro_rules! insert_if_some {
+    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
+        $(
+            if let Some(value) = $field {
+                $map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
+            }
+        )*
+    };
+}
+
+// Helper macro for simple value insertions
+macro_rules! insert_value {
+    ($map:expr, $($field:expr => $key:expr),* $(,)?)
=> { + $( + $map.insert($key.to_string(), $field.into()); + )* + }; +} + +// ============= Generate Request Adapter ============= + +impl ToPdRequest for GenerateRequest { + type Output = GenerateReqInput; + + fn to_pd_request(self) -> Self::Output { + // Build the other fields first + let mut other = serde_json::Map::new(); + + // Handle text input - check in priority order: text (SGLang), prompt (OpenAI) + let (text, input_ids) = if let Some(text_str) = self.text { + // SGLang native format + (Some(SingleOrBatch::Single(text_str)), None) + } else if let Some(prompt) = self.prompt { + // OpenAI style prompt + let text = match prompt { + StringOrArray::String(s) => Some(SingleOrBatch::Single(s)), + StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)), + }; + (text, None) + } else if let Some(ids) = self.input_ids { + // Input IDs case + let input_ids = match ids { + crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)), + crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)), + }; + (None, input_ids) + } else { + // No input provided + (None, None) + }; + + // Add parameters to other - handle both old and new style + if let Some(params) = self.parameters { + // For generate endpoint, extract max_new_tokens to top level if present + let mut params_value = serde_json::to_value(¶ms).unwrap_or(Value::Null); + if let Value::Object(ref mut params_map) = params_value { + // Move max_new_tokens to top level if it exists + if let Some(max_new_tokens) = params_map.remove("max_new_tokens") { + other.insert("max_new_tokens".to_string(), max_new_tokens); + } + // Move temperature to top level if it exists + if let Some(temperature) = params_map.remove("temperature") { + other.insert("temperature".to_string(), temperature); + } + } + // Only add parameters if there are remaining fields + if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty()) + { + other.insert("parameters".to_string(), params_value); + } + } + + // Add sampling_params if present + if let Some(sampling_params) = self.sampling_params { + let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null); + if !params_value.is_null() { + // Extract commonly used fields to top level + if let Value::Object(ref params_map) = params_value { + if let Some(max_new_tokens) = params_map.get("max_new_tokens") { + other.insert("max_new_tokens".to_string(), max_new_tokens.clone()); + } + if let Some(temperature) = params_map.get("temperature") { + other.insert("temperature".to_string(), temperature.clone()); + } + } + other.insert("sampling_params".to_string(), params_value); + } + } + + // Add other fields + insert_value!(other, + self.stream => "stream", + self.return_logprob => "return_logprob" + ); + + GenerateReqInput { + text, + input_ids, + stream: self.stream, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(other), + } + } +} + +// ============= Completion Request Adapter ============= + +impl ToPdRequest for CompletionRequest { + type Output = GenerateReqInput; + + fn to_pd_request(self) -> Self::Output { + // Convert CompletionRequest to GenerateReqInput + let text = match self.prompt { + StringOrArray::String(s) => Some(SingleOrBatch::Single(s)), + StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)), + }; + + // Map OpenAI parameters to generate parameters + let mut other = serde_json::Map::new(); + + // Create parameters object + let mut params = serde_json::Map::new(); + + // Map 
OpenAI fields to internal parameter names + insert_if_some!(params, + self.max_tokens => "max_new_tokens", + self.temperature => "temperature", + self.top_p => "top_p", + self.n => "best_of", + self.logprobs => "top_n_tokens", + self.seed => "seed" + ); + + // Special handling for fields that need transformation + if let Some(presence_penalty) = self.presence_penalty { + params.insert( + "repetition_penalty".to_string(), + (1.0 + presence_penalty).into(), + ); + } + + if let Some(stop) = self.stop { + let stop_sequences = match stop { + StringOrArray::String(s) => vec![s], + StringOrArray::Array(v) => v, + }; + params.insert("stop".to_string(), stop_sequences.into()); + } + + if self.echo { + params.insert("return_full_text".to_string(), true.into()); + } + + other.insert("parameters".to_string(), Value::Object(params)); + + // Store original model and stream flag + insert_value!(other, + self.model => "model", + self.stream => "stream" + ); + + GenerateReqInput { + text, + input_ids: None, + stream: self.stream, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(other), + } + } +} + +// ============= Chat Completion Request Adapter ============= + +impl ToPdRequest for ChatCompletionRequest { + type Output = ChatReqInput; + + fn to_pd_request(self) -> Self::Output { + let mut other = serde_json::Map::new(); + + // Add required fields + insert_if_some!(other, + Some(&self.messages) => "messages" + ); + + insert_value!(other, + self.model => "model", + self.stream => "stream" + ); + + // Add all optional fields + insert_if_some!(other, + self.temperature => "temperature", + self.top_p => "top_p", + self.n => "n", + self.stop => "stop", + self.max_tokens => "max_tokens", + self.max_completion_tokens => "max_completion_tokens", + self.presence_penalty => "presence_penalty", + self.frequency_penalty => "frequency_penalty", + self.logit_bias => "logit_bias", + self.user => "user", + self.seed => "seed", + self.top_logprobs => "top_logprobs", + self.response_format => "response_format", + self.tools => "tools", + self.tool_choice => "tool_choice", + self.parallel_tool_calls => "parallel_tool_calls", + self.functions => "functions", + self.function_call => "function_call" + ); + + // Handle boolean logprobs flag + if self.logprobs { + other.insert("logprobs".to_string(), true.into()); + } + + ChatReqInput { + stream: self.stream, + bootstrap_host: None, + bootstrap_port: None, + bootstrap_room: None, + other: Value::Object(other), + } + } +} + +// ============= Direct routing support for regular router ============= + +/// Extension trait for routing without PD conversion +pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone { + /// Convert to JSON for sending to backend + fn to_json(&self) -> Result { + serde_json::to_value(self) + } + + /// Convert to bytes for legacy routing + fn to_bytes(&self) -> Result { + let json = serde_json::to_vec(self)?; + Ok(bytes::Bytes::from(json)) + } +} + +impl RouteableRequest for GenerateRequest {} +impl RouteableRequest for CompletionRequest {} +impl RouteableRequest for ChatCompletionRequest {} diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index 1666cc0db..9e4096311 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -1,10 +1,10 @@ +use crate::pd_router::PDRouter; +use crate::pd_types::PDSelectionPolicy; use crate::tree::Tree; use ::metrics::{counter, gauge, histogram}; use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, 
HttpResponse}; -use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; -use serde_json::Value; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; @@ -15,7 +15,7 @@ use std::time::Instant; use tokio; use tracing::{debug, error, info, warn}; -fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { +pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { req.headers() .iter() .filter_map(|(name, value)| { @@ -40,6 +40,9 @@ pub enum Router { timeout_secs: u64, interval_secs: u64, }, + PrefillDecode { + pd_router: Arc, + }, CacheAware { /* Cache-Aware Load Balancing Router @@ -133,6 +136,13 @@ pub enum PolicyConfig { timeout_secs: u64, interval_secs: u64, }, + PrefillDecodeConfig { + selection_policy: PDSelectionPolicy, + prefill_urls: Vec<(String, Option)>, // (url, bootstrap_port) + decode_urls: Vec, + timeout_secs: u64, + interval_secs: u64, + }, } impl Router { @@ -155,10 +165,24 @@ impl Router { interval_secs, .. } => (*timeout_secs, *interval_secs), + PolicyConfig::PrefillDecodeConfig { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), }; - // Wait until all workers are healthy - Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; + // For PrefillDecode, we need to handle workers differently + match &policy_config { + PolicyConfig::PrefillDecodeConfig { .. } => { + // PD mode doesn't use the worker_urls parameter + // We'll validate PD workers separately + } + _ => { + // Wait until all workers are healthy for regular modes + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; + } + } // Create router based on policy... Ok(match policy_config { @@ -226,7 +250,7 @@ impl Router { }); for url in &worker_urls { - tree.lock().unwrap().insert(&"".to_string(), url); + tree.lock().unwrap().insert("", url); } Router::CacheAware { @@ -242,6 +266,26 @@ impl Router { _eviction_thread: Some(eviction_thread), } } + PolicyConfig::PrefillDecodeConfig { + selection_policy, + prefill_urls, + decode_urls, + timeout_secs, + interval_secs, + } => { + // Create PDRouter instance + let pd_router = PDRouter::new( + prefill_urls, + decode_urls, + selection_policy, + timeout_secs, + interval_secs, + )?; + + Router::PrefillDecode { + pd_router: Arc::new(pd_router), + } + } }) } @@ -251,16 +295,23 @@ impl Router { Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls), Router::Random { worker_urls, .. } => Arc::clone(worker_urls), Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls), + Router::PrefillDecode { .. } => { + // For PD mode, return empty list since we manage workers differently + Arc::new(RwLock::new(Vec::new())) + } } } - fn wait_for_healthy_workers( + pub fn wait_for_healthy_workers( worker_urls: &[String], timeout_secs: u64, interval_secs: u64, ) -> Result<(), String> { let start_time = std::time::Instant::now(); - let sync_client = reqwest::blocking::Client::new(); + let sync_client = reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { @@ -323,10 +374,14 @@ impl Router { Ok(worker_urls.read().unwrap()[0].clone()) } } + Router::PrefillDecode { .. 
} => { + // For PD mode, we don't need this method as routing is handled by PDRouter + Err("PrefillDecode mode doesn't use select_first_worker".to_string()) + } } } - async fn send_request( + pub async fn send_request( &self, client: &reqwest::Client, worker_url: &str, @@ -339,7 +394,11 @@ impl Router { // Copy all headers from original request except for /health because it does not need authorization if route != "/health" { for (name, value) in copy_request_headers(req) { - request_builder = request_builder.header(name, value); + // Skip Content-Type and Content-Length as .json() sets them + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" + { + request_builder = request_builder.header(name, value); + } } } @@ -433,50 +492,193 @@ impl Router { HttpResponse::InternalServerError().body("All retry attempts failed") } - fn get_text_from_request(&self, body: &Bytes, route: &str) -> String { - // Convert body to JSON - let json: Value = match serde_json::from_slice(body) { - Ok(j) => j, - Err(_) => { - warn!("Failed to parse JSON from request body."); - return String::new(); + pub async fn route_to_all( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { + // Get all worker URLs based on router type + let worker_urls = match self { + Router::PrefillDecode { .. } => { + // For PD mode, route_to_all is not supported directly + // It should be handled by PDRouter if needed + return HttpResponse::NotImplemented() + .body("route_to_all not implemented for PrefillDecode mode"); } + _ => self.get_worker_urls().read().unwrap().clone(), }; - match route { - "/generate" => { - // For /generate, always use the "text" field. - match json.get("text").and_then(Value::as_str) { - Some(text) => text.to_string(), - None => { - warn!("No 'text' field found in request body for route /generate."); - String::new() - } - } + // Send requests to all workers concurrently + let mut tasks = Vec::new(); + for worker_url in &worker_urls { + let mut request_builder = client.post(format!("{}{}", worker_url, route)); + + // Copy headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); } - "/v1/chat/completions" | "/v1/completions" => { - // For these routes, try "messages", then "prompt", then "text". 
- if let Some(messages) = json.get("messages") { - serde_json::to_string(messages).unwrap_or_default() - } else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) { - prompt.to_string() - } else { - warn!("Failed to find 'messages', 'prompt' in request body."); - String::new() - } + + tasks.push(request_builder.send()); + } + + // Wait for all responses + let results = futures_util::future::join_all(tasks).await; + + // Check if all succeeded + let all_success = results.iter().all(|r| { + r.as_ref() + .map(|res| res.status().is_success()) + .unwrap_or(false) + }); + + if all_success { + HttpResponse::Ok().body("Operation completed on all servers") + } else { + HttpResponse::InternalServerError().body("Operation failed on one or more servers") + } + } + + pub async fn get_all_loads( + &self, + client: &reqwest::Client, + _req: &HttpRequest, + ) -> HttpResponse { + // For PD mode, delegate to PDRouter + match self { + Router::PrefillDecode { pd_router } => { + return pd_router.get_loads(client).await; } _ => { - warn!("Unknown route: {} - defaulting to fallback string", route); - String::new() + // For non-PD routers, handle normally + } + } + + let urls = self.get_worker_urls().read().unwrap().clone(); + let prefill_urls: Vec = Vec::new(); + let decode_urls = urls; + + // Collect loads from all servers + let mut prefill_loads = Vec::new(); + let mut decode_loads = Vec::new(); + + // Get prefill loads + for url in &prefill_urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + prefill_loads.push(serde_json::json!({ + "engine": format!("(Prefill@{})", url), + "load": load as i64 + })); + } + + // Get decode loads + for url in &decode_urls { + let load = self.get_worker_load(client, url).await.unwrap_or(-1); + decode_loads.push(serde_json::json!({ + "engine": format!("(Decode@{})", url), + "load": load as i64 + })); + } + + HttpResponse::Ok().json(serde_json::json!({ + "prefill": prefill_loads, + "decode": decode_loads + })) + } + + // New method to route typed requests directly + pub async fn route_typed_request< + T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone, + >( + &self, + client: &reqwest::Client, + req: &HttpRequest, + typed_req: &T, + route: &str, + ) -> HttpResponse { + match self { + Router::PrefillDecode { .. 
} => HttpResponse::InternalServerError() + .body("PD routing should use specialized typed handlers"), + _ => { + // Handle retries like the original implementation + let start = Instant::now(); + const MAX_REQUEST_RETRIES: u32 = 3; + const MAX_TOTAL_RETRIES: u32 = 6; + let mut total_retries = 0; + + while total_retries < MAX_TOTAL_RETRIES { + // Extract routing text directly from typed request + let text = typed_req.extract_text_for_routing(); + let is_stream = typed_req.is_stream(); + + // Select worker based on text + let worker_url = self.select_generate_worker_from_text(&text); + let mut request_retries = 0; + + // Try the same worker multiple times + while request_retries < MAX_REQUEST_RETRIES { + if total_retries >= 1 { + info!("Retrying request after {} failed attempts", total_retries); + counter!("sgl_router_retries_total", "route" => route.to_string()) + .increment(1); + } + + // Send typed request directly + let response = self + .send_typed_request( + client, + req, + typed_req, + route, + &worker_url, + is_stream, + ) + .await; + + if response.status().is_success() { + let duration = start.elapsed(); + histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + counter!("sgl_router_request_errors_total", "route" => route.to_string()) + .increment(1); + return response; + } + } + + warn!( + "Generate request to {} failed (attempt {}/{})", + worker_url, + request_retries + 1, + MAX_REQUEST_RETRIES + ); + + request_retries += 1; + total_retries += 1; + + if request_retries == MAX_REQUEST_RETRIES { + warn!("Removing failed worker: {}", worker_url); + self.remove_worker(&worker_url); + break; + } + } + } + + counter!("sgl_router_request_errors_total", "route" => route.to_string()) + .increment(1); + HttpResponse::InternalServerError().body("All retry attempts failed") } } } - // TODO: return Result instead of panicking - fn select_generate_worker(&self, body: &Bytes, route: &str) -> String { - let text = self.get_text_from_request(&body, route); - - let worker_url = match self { + // Helper method to select worker from text + fn select_generate_worker_from_text(&self, text: &str) -> String { + match self { Router::RoundRobin { worker_urls, current_index, @@ -506,8 +708,6 @@ impl Router { balance_rel_threshold, .. } => { - // TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones - let tree = tree.lock().unwrap(); let mut running_queue = running_queue.lock().unwrap(); @@ -572,35 +772,48 @@ impl Router { selected_url } - }; - - worker_url + Router::PrefillDecode { .. 
} => { + // For PD mode, we don't use this method + return "PD_MODE_ERROR".to_string(); + } + } } - async fn send_generate_request( + // Send typed request directly without conversion + async fn send_typed_request( &self, client: &reqwest::Client, req: &HttpRequest, - body: &Bytes, + typed_req: &T, route: &str, worker_url: &str, + is_stream: bool, ) -> HttpResponse { - let is_stream = serde_json::from_slice::(&body) - .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) - .unwrap_or(false); + let start = Instant::now(); + + // Debug: Log what we're sending + if let Ok(json_str) = serde_json::to_string_pretty(typed_req) { + debug!("Sending request to {}: {}", route, json_str); + } let mut request_builder = client .post(format!("{}{}", worker_url, route)) - .body(body.to_vec()); + .json(typed_req); // Use json() directly with typed request // Copy all headers from original request for (name, value) in copy_request_headers(req) { - request_builder = request_builder.header(name, value); + // Skip Content-Type and Content-Length as .json() sets them + if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" { + request_builder = request_builder.header(&name, &value); + } } let res = match request_builder.send().await { Ok(res) => res, - Err(_) => return HttpResponse::InternalServerError().finish(), + Err(e) => { + error!("Failed to send request to {}: {}", worker_url, e); + return HttpResponse::InternalServerError().body(format!("Request failed: {}", e)); + } }; let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) @@ -625,6 +838,12 @@ impl Router { } } + // Record metrics + let duration = start.elapsed(); + histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()) + .record(duration.as_secs_f64()); + counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1); + response } else if let Router::CacheAware { running_queue, .. 
} = self { let running_queue = Arc::clone(running_queue); @@ -660,70 +879,6 @@ impl Router { } } - pub async fn route_generate_request( - &self, - client: &reqwest::Client, - req: &HttpRequest, - body: &Bytes, - route: &str, - ) -> HttpResponse { - let start = Instant::now(); - const MAX_REQUEST_RETRIES: u32 = 3; - const MAX_TOTAL_RETRIES: u32 = 6; - let mut total_retries = 0; - - while total_retries < MAX_TOTAL_RETRIES { - let worker_url = self.select_generate_worker(body, route); - let mut request_retries = 0; - - // Try the same worker multiple times - while request_retries < MAX_REQUEST_RETRIES { - if total_retries >= 1 { - info!("Retrying request after {} failed attempts", total_retries); - counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1); - } - - let response = self - .send_generate_request(client, req, body, route, &worker_url) - .await; - - if response.status().is_success() { - let duration = start.elapsed(); - histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64()); - return response; - } else { - // if the worker is healthy, it means the request is bad, so return the error response - let health_response = - self.send_request(client, &worker_url, "/health", req).await; - if health_response.status().is_success() { - counter!("sgl_router_request_errors_total", "route" => route.to_string()) - .increment(1); - return response; - } - } - - warn!( - "Generate request to {} failed (attempt {}/{})", - worker_url, - request_retries + 1, - MAX_REQUEST_RETRIES - ); - - request_retries += 1; - total_retries += 1; - - if request_retries == MAX_REQUEST_RETRIES { - warn!("Removing failed worker: {}", worker_url); - self.remove_worker(&worker_url); - break; - } - } - } - - counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1); - HttpResponse::InternalServerError().body("All retry attempts failed") - } - pub async fn add_worker(&self, worker_url: &str) -> Result { let (timeout_secs, interval_secs) = match self { Router::Random { @@ -741,10 +896,17 @@ impl Router { interval_secs, .. } => (*timeout_secs, *interval_secs), + Router::PrefillDecode { .. } => { + // For PD mode, we don't support adding workers via this method + return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string()); + } }; let start_time = std::time::Instant::now(); - let client = reqwest::Client::new(); + let client = reqwest::Client::builder() + .timeout(Duration::from_secs(timeout_secs)) + .build() + .map_err(|e| format!("Failed to create HTTP client: {}", e))?; loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { @@ -774,6 +936,9 @@ impl Router { urls.push(worker_url.to_string()); gauge!("sgl_router_active_workers").set(urls.len() as f64); } + Router::PrefillDecode { .. } => { + return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string()); + } } // If cache aware, initialize the queues for the new worker @@ -797,7 +962,7 @@ impl Router { .insert(worker_url.to_string(), 0); // Add worker to tree - tree.lock().unwrap().insert(&"".to_string(), &worker_url); + tree.lock().unwrap().insert("", worker_url); } return Ok(format!("Successfully added worker: {}", worker_url)); @@ -850,6 +1015,10 @@ impl Router { return; } } + Router::PrefillDecode { .. } => { + warn!("Removing workers from PrefillDecode router not supported via remove_worker. 
Use dedicated PD management methods."); + return; + } } // if cache aware, remove the worker from the tree @@ -875,4 +1044,133 @@ impl Router { ); } } + + async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option { + match client.get(&format!("{}/get_load", worker_url)).send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(bytes) => match serde_json::from_slice::(&bytes) { + Ok(data) => data + .get("load") + .and_then(|v| v.as_i64()) + .map(|v| v as isize), + Err(e) => { + debug!("Failed to parse load response from {}: {}", worker_url, e); + None + } + }, + Err(e) => { + debug!("Failed to read load response from {}: {}", worker_url, e); + None + } + }, + Ok(res) => { + debug!( + "Worker {} returned non-success status: {}", + worker_url, + res.status() + ); + None + } + Err(e) => { + debug!("Failed to get load from {}: {}", worker_url, e); + None + } + } + } + + // PD-specific wrapper methods that delegate to PDRouter + pub async fn route_pd_health_generate( + &self, + _client: &reqwest::Client, + _req: &HttpRequest, + ) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router.health_generate(&pd_router.http_client).await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } + + pub async fn route_pd_generate_typed( + &self, + _client: &reqwest::Client, + req: &HttpRequest, + typed_req: crate::pd_types::GenerateReqInput, + route: &str, + ) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router + .route_generate(&pd_router.http_client, req, typed_req, route) + .await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } + + pub async fn route_pd_chat_typed( + &self, + _client: &reqwest::Client, + req: &HttpRequest, + typed_req: crate::pd_types::ChatReqInput, + route: &str, + ) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router + .route_chat(&pd_router.http_client, req, typed_req, route) + .await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } + + pub async fn get_pd_server_info( + &self, + _client: &reqwest::Client, + _req: &HttpRequest, + ) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router.get_server_info(&pd_router.http_client).await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } + + pub async fn get_pd_models( + &self, + _client: &reqwest::Client, + req: &HttpRequest, + ) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router.get_models(&pd_router.http_client, req).await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } + + pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router.flush_cache(&pd_router.http_client).await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } + + pub async fn get_pd_model_info( + &self, + _client: &reqwest::Client, + req: &HttpRequest, + ) -> HttpResponse { + match self { + Router::PrefillDecode { pd_router } => { + pd_router.get_model_info(&pd_router.http_client, req).await + } + _ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"), + } + } } diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 0d6cf6910..a1fd50acd 100644 --- a/sgl-router/src/server.rs +++ 
b/sgl-router/src/server.rs @@ -1,12 +1,13 @@ use crate::logging::{self, LoggingConfig}; +use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::prometheus::{self, PrometheusConfig}; +use crate::request_adapter::ToPdRequest; use crate::router::PolicyConfig; use crate::router::Router; use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig}; use actix_web::{ error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder, }; -use bytes::Bytes; use futures_util::StreamExt; use reqwest::Client; use std::collections::HashMap; @@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level}; pub struct AppState { router: Arc, client: Client, + is_pd_mode: bool, // Add flag to track PD mode } impl AppState { @@ -28,9 +30,16 @@ impl AppState { client: Client, policy_config: PolicyConfig, ) -> Result { + // Check if this is PD mode from policy config + let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. }); + // Create router based on policy let router = Arc::new(Router::new(worker_urls, policy_config)?); - Ok(Self { router, client }) + Ok(Self { + router, + client, + is_pd_mode, + }) } } @@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result Error { - error::ErrorPayloadTooLarge("Payload too large") +fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error { + error!("JSON payload error: {:?}", err); + match &err { + error::JsonPayloadError::OverflowKnownLength { length, limit } => { + error!( + "Payload too large: {} bytes exceeds limit of {} bytes", + length, limit + ); + error::ErrorPayloadTooLarge(format!( + "Payload too large: {} bytes exceeds limit of {} bytes", + length, limit + )) + } + error::JsonPayloadError::Overflow { limit } => { + error!("Payload overflow: exceeds limit of {} bytes", limit); + error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit)) + } + _ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)), + } } #[get("/health")] @@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data) -> impl Responder { #[get("/health_generate")] async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { - data.router - .route_to_first(&data.client, "/health_generate", &req) - .await + // Check if we're in PD mode + if data.is_pd_mode { + // For PD mode, check health on all servers + data.router + .route_pd_health_generate(&data.client, &req) + .await + } else { + // Regular mode + data.router + .route_to_first(&data.client, "/health_generate", &req) + .await + } } #[get("/get_server_info")] async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { - data.router - .route_to_first(&data.client, "/get_server_info", &req) - .await + if data.is_pd_mode { + // For PD mode, aggregate info from both prefill and decode servers + data.router.get_pd_server_info(&data.client, &req).await + } else { + // Regular mode - return first server's info + data.router + .route_to_first(&data.client, "/get_server_info", &req) + .await + } } #[get("/v1/models")] async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { - data.router - .route_to_first(&data.client, "/v1/models", &req) - .await + if data.is_pd_mode { + // For PD mode, return models from the first prefill server + data.router.get_pd_models(&data.client, &req).await + } else { + // Regular mode + data.router + .route_to_first(&data.client, "/v1/models", &req) + .await + 
} } #[get("/get_model_info")] async fn get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { - data.router - .route_to_first(&data.client, "/get_model_info", &req) - .await + if data.is_pd_mode { + // For PD mode, get model info from the first prefill server + data.router.get_pd_model_info(&data.client, &req).await + } else { + data.router + .route_to_first(&data.client, "/get_model_info", &req) + .await + } } #[post("/generate")] -async fn generate(req: HttpRequest, body: Bytes, data: web::Data) -> impl Responder { - data.router - .route_generate_request(&data.client, &req, &body, "/generate") - .await +async fn generate( + req: HttpRequest, + body: web::Json, + state: web::Data, +) -> Result { + let client = &state.client; + let router = &state.router; + + // Use typed request directly for both PD and regular routing + if state.is_pd_mode { + // For PD mode, convert to PD request with bootstrap + let pd_request = body.into_inner().to_pd_request(); + + Ok(router + .route_pd_generate_typed(&client, &req, pd_request, "/generate") + .await) + } else { + // For regular mode, use typed request directly + let request = body.into_inner(); + Ok(router + .route_typed_request(&client, &req, &request, "/generate") + .await) + } } #[post("/v1/chat/completions")] async fn v1_chat_completions( req: HttpRequest, - body: Bytes, - data: web::Data, -) -> impl Responder { - data.router - .route_generate_request(&data.client, &req, &body, "/v1/chat/completions") - .await + body: web::Json, + state: web::Data, +) -> Result { + let client = &state.client; + let router = &state.router; + + // Use typed request directly for both PD and regular routing + if state.is_pd_mode { + // For PD mode, convert to PD request with bootstrap + let pd_request = body.into_inner().to_pd_request(); + + Ok(router + .route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions") + .await) + } else { + // For regular mode, use typed request directly + let request = body.into_inner(); + Ok(router + .route_typed_request(&client, &req, &request, "/v1/chat/completions") + .await) + } } #[post("/v1/completions")] async fn v1_completions( req: HttpRequest, - body: Bytes, - data: web::Data, -) -> impl Responder { - data.router - .route_generate_request(&data.client, &req, &body, "/v1/completions") - .await + body: web::Json, + state: web::Data, +) -> Result { + let client = &state.client; + let router = &state.router; + + // Use typed request directly for both PD and regular routing + if state.is_pd_mode { + // For PD mode, convert to PD request with bootstrap + let pd_request = body.into_inner().to_pd_request(); + + Ok(router + .route_pd_generate_typed(&client, &req, pd_request, "/v1/completions") + .await) + } else { + // For regular mode, use typed request directly + let request = body.into_inner(); + Ok(router + .route_typed_request(&client, &req, &request, "/v1/completions") + .await) + } } #[post("/add_worker")] @@ -153,6 +254,25 @@ async fn remove_worker( HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url)) } +#[post("/flush_cache")] +async fn flush_cache(req: HttpRequest, data: web::Data) -> impl Responder { + if data.is_pd_mode { + // For PD mode, flush cache on both prefill and decode servers + data.router.route_pd_flush_cache(&data.client).await + } else { + // Route to all workers for cache flushing + data.router + .route_to_all(&data.client, "/flush_cache", &req) + .await + } +} + +#[get("/get_loads")] +async fn get_loads(req: HttpRequest, data: web::Data) -> impl Responder { 
+ // Get loads from all workers + data.router.get_all_loads(&data.client, &req).await +} + pub struct ServerConfig { pub host: String, pub port: u16, @@ -163,6 +283,7 @@ pub struct ServerConfig { pub log_dir: Option, pub service_discovery_config: Option, pub prometheus_config: Option, + pub request_timeout_secs: u64, } pub async fn startup(config: ServerConfig) -> std::io::Result<()> { @@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { let client = Client::builder() .pool_idle_timeout(Some(Duration::from_secs(50))) + .timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout .build() .expect("Failed to create HTTP client"); @@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { .service(add_worker) .service(remove_worker) .service(list_workers) - // Default handler for unmatched routes. + .service(flush_cache) + .service(get_loads) .default_service(web::route().to(sink_handler)) }) .bind_auto_h2c((config.host, config.port))? diff --git a/sgl-router/tests/test_pd_routing.rs b/sgl-router/tests/test_pd_routing.rs new file mode 100644 index 000000000..6d9019b43 --- /dev/null +++ b/sgl-router/tests/test_pd_routing.rs @@ -0,0 +1,904 @@ +//! Comprehensive tests for PrefillDecode (PD) routing functionality +//! +//! This test suite covers: +//! - Phase 1: Basic PD router creation and configuration +//! - Phase 2: Bootstrap injection and request handling +//! - Phase 3: Cache-aware selection (when implemented) +//! +//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type. +//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode. + +#[cfg(test)] +mod test_pd_routing { + use rand::Rng; + use serde_json::json; + use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy}; + use sglang_router_rs::router::{PolicyConfig, Router}; + + // Test-only struct to help validate PD request parsing + #[derive(Debug)] + struct PDRequest { + pub is_stream: bool, + pub batch_size: Option, + } + + impl PDRequest { + // Extract PD-relevant info from JSON for testing + pub fn from_json(json: &serde_json::Value) -> Self { + let is_stream = json + .get("stream") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + // Detect batch size from text or input_ids + let batch_size = if let Some(text) = json.get("text") { + text.as_array().map(|arr| arr.len()) + } else if let Some(input_ids) = json.get("input_ids") { + input_ids.as_array().map(|arr| arr.len()) + } else { + None + }; + + PDRequest { + is_stream, + batch_size, + } + } + } + + // ======================================================================== + // Phase 1: Basic PD Components and Router Creation + // ======================================================================== + + #[test] + fn test_engine_info_creation() { + // Test EngineInfo creation for prefill servers + let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000)); + match prefill_engine.engine_type { + EngineType::Prefill => (), + _ => panic!("Expected Prefill engine type"), + } + assert_eq!(prefill_engine.url, "http://prefill:8080"); + assert_eq!(prefill_engine.bootstrap_port, Some(9000)); + assert_eq!(prefill_engine.get_hostname(), "prefill"); + + // Test EngineInfo creation for decode servers + let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string()); + match decode_engine.engine_type { + EngineType::Decode => (), + _ => panic!("Expected Decode engine 
type"), + } + assert_eq!(decode_engine.url, "http://decode:8080"); + assert_eq!(decode_engine.bootstrap_port, None); + assert_eq!(decode_engine.get_hostname(), "decode"); + + // Test API path generation + assert_eq!( + prefill_engine.api_path("/generate"), + "http://prefill:8080/generate" + ); + assert_eq!( + prefill_engine.api_path("health"), + "http://prefill:8080/health" + ); + assert_eq!( + decode_engine.api_path("/v1/chat/completions"), + "http://decode:8080/v1/chat/completions" + ); + } + + #[test] + fn test_pd_selection_policies() { + // Test all PD selection policy variants + // Note: These policies are only used when pd_disaggregated=true + let policies = vec![ + PDSelectionPolicy::Random, + PDSelectionPolicy::PowerOfTwo, + PDSelectionPolicy::CacheAware { + cache_threshold: 0.5, + balance_abs_threshold: 32, + balance_rel_threshold: 1.1, + }, + ]; + + for policy in policies { + // Verify each policy can be created and matched + match &policy { + PDSelectionPolicy::Random => { + assert!(matches!(policy, PDSelectionPolicy::Random)); + } + PDSelectionPolicy::PowerOfTwo => { + assert!(matches!(policy, PDSelectionPolicy::PowerOfTwo)); + } + PDSelectionPolicy::CacheAware { + cache_threshold, .. + } => { + assert!(*cache_threshold >= 0.0 && *cache_threshold <= 1.0); + } + } + } + } + + #[test] + fn test_pd_router_configuration() { + // Test PrefillDecodeConfig creation with various policies + // This config is used when pd_disaggregated=true + let configs = vec![ + PolicyConfig::PrefillDecodeConfig { + selection_policy: PDSelectionPolicy::Random, + prefill_urls: vec![ + ("http://prefill1:8080".to_string(), Some(9000)), + ("http://prefill2:8080".to_string(), None), + ], + decode_urls: vec![ + "http://decode1:8080".to_string(), + "http://decode2:8080".to_string(), + ], + timeout_secs: 10, + interval_secs: 1, + }, + PolicyConfig::PrefillDecodeConfig { + selection_policy: PDSelectionPolicy::PowerOfTwo, + prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))], + decode_urls: vec!["http://decode:8080".to_string()], + timeout_secs: 5, + interval_secs: 1, + }, + PolicyConfig::PrefillDecodeConfig { + selection_policy: PDSelectionPolicy::CacheAware { + cache_threshold: 0.7, + balance_abs_threshold: 20, + balance_rel_threshold: 1.2, + }, + prefill_urls: vec![ + ("http://p1:8080".to_string(), Some(9000)), + ("http://p2:8080".to_string(), Some(9001)), + ("http://p3:8080".to_string(), Some(9002)), + ], + decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()], + timeout_secs: 10, + interval_secs: 2, + }, + ]; + + for config in configs { + // Router creation will fail due to health checks, but config should be valid + let result = Router::new(vec![], config); + assert!(result.is_err()); + let error_msg = result.unwrap_err(); + // Error should be about health/timeout, not configuration + assert!( + error_msg.contains("healthy") || error_msg.contains("timeout"), + "Unexpected error: {}", + error_msg + ); + } + } + + // ======================================================================== + // Phase 2: Bootstrap Injection and Request Handling + // ======================================================================== + + #[test] + fn test_pd_request_from_json() { + // Test PDRequest parsing from single text request + let single_json = json!({ + "text": "Hello world", + "stream": false, + "temperature": 0.7, + "max_tokens": 100 + }); + + let pd_req = PDRequest::from_json(&single_json); + assert!(!pd_req.is_stream); + assert_eq!(pd_req.batch_size, None); + + // Test 
+    #[test]
+    fn test_pd_request_edge_cases() {
+        // Test empty request
+        let empty_json = json!({});
+        let pd_req = PDRequest::from_json(&empty_json);
+        assert!(!pd_req.is_stream);
+        assert_eq!(pd_req.batch_size, None);
+
+        // Test request with only stream field
+        let stream_only = json!({
+            "stream": true
+        });
+        let pd_req = PDRequest::from_json(&stream_only);
+        assert!(pd_req.is_stream);
+        assert_eq!(pd_req.batch_size, None);
+
+        // Test request with empty text array
+        let empty_batch = json!({
+            "text": []
+        });
+        let pd_req = PDRequest::from_json(&empty_batch);
+        assert_eq!(pd_req.batch_size, Some(0));
+
+        // Test request with non-array text (should be None)
+        let non_array_text = json!({
+            "text": "single string"
+        });
+        let pd_req = PDRequest::from_json(&non_array_text);
+        assert_eq!(pd_req.batch_size, None);
+    }
+
+    // ========================================================================
+    // Phase 2: Background Load Monitoring Tests
+    // ========================================================================
+
+    #[tokio::test]
+    async fn test_background_load_monitoring() {
+        use std::collections::HashMap;
+        use tokio::sync::watch;
+
+        // Create a watch channel for testing
+        let (tx, rx) = watch::channel(HashMap::new());
+
+        // Simulate load updates
+        let mut loads = HashMap::new();
+        loads.insert("http://prefill1:8080".to_string(), 10);
+        loads.insert("http://prefill2:8080".to_string(), 20);
+        loads.insert("http://decode1:8080".to_string(), 5);
+        loads.insert("http://decode2:8080".to_string(), 15);
+
+        // Send the loads
+        tx.send(loads.clone()).unwrap();
+
+        // Verify receiver gets the update
+        let received_loads = rx.borrow();
+        assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
+        assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
+        assert_eq!(received_loads.get("http://decode1:8080"), Some(&5));
+        assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
+    }
+
+    #[test]
+    fn test_power_of_two_load_selection() {
+        // Test the power-of-two selection logic with different load scenarios
+
+        // Scenario 1: Clear winner for both prefill and decode
+        let _loads = vec![
+            ("prefill1", 100),
+            ("prefill2", 10), // Should be selected
+            ("decode1", 50),
+            ("decode2", 5), // Should be selected
+        ];
+
+        // In actual implementation, the lower load should be selected
+        assert!(10 < 100);
+        assert!(5 < 50);
+
+        // Scenario 2: Equal loads (should select first)
+        let _equal_loads = vec![
+            ("prefill1", 20),
+            ("prefill2", 20), // Either could be selected
+            ("decode1", 30),
+            ("decode2", 30), // Either could be selected
+        ];
+
+        // When loads are equal, <= comparison means first is selected
+        assert!(20 <= 20);
+        assert!(30 <= 30);
+
+        // Scenario 3: Missing load data (should default to usize::MAX)
+        // This tests the unwrap_or(usize::MAX) behavior
+        let missing_load = usize::MAX;
+        assert!(10 < missing_load);
+        assert!(missing_load > 0);
+    }
+
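+    // Illustrative sketch only: the power-of-two-choices rule described by the
+    // assertions above, applied to the load map published by the watch channel.
+    // The helper pick_power_of_two is hypothetical, not the router's internals.
+    fn pick_power_of_two<'a>(
+        a: &'a str,
+        b: &'a str,
+        loads: &std::collections::HashMap<String, usize>,
+    ) -> &'a str {
+        // Workers without a load entry default to usize::MAX and lose
+        let load_a = loads.get(a).copied().unwrap_or(usize::MAX);
+        let load_b = loads.get(b).copied().unwrap_or(usize::MAX);
+        // <= keeps the first candidate on equal loads, matching the comment above
+        if load_a <= load_b {
+            a
+        } else {
+            b
+        }
+    }
+
+    #[test]
+    fn test_pick_power_of_two_sketch() {
+        let mut loads = std::collections::HashMap::new();
+        loads.insert("prefill1".to_string(), 100);
+        loads.insert("prefill2".to_string(), 10);
+        assert_eq!(pick_power_of_two("prefill1", "prefill2", &loads), "prefill2");
+        // Equal loads: first candidate wins
+        loads.insert("prefill1".to_string(), 10);
+        assert_eq!(pick_power_of_two("prefill1", "prefill2", &loads), "prefill1");
+        // Missing entries lose against any known load
+        assert_eq!(pick_power_of_two("unknown", "prefill2", &loads), "prefill2");
+    }
+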
+    #[test]
+    fn test_load_monitoring_configuration() {
+        // Test that load monitoring is only enabled for PowerOfTwo policy
+        let policies = vec![
+            (PDSelectionPolicy::Random, false),
+            (PDSelectionPolicy::PowerOfTwo, true),
+            (
+                PDSelectionPolicy::CacheAware {
+                    cache_threshold: 0.5,
+                    balance_abs_threshold: 32,
+                    balance_rel_threshold: 1.1,
+                },
+                false,
+            ),
+        ];
+
+        for (policy, should_monitor) in policies {
+            match policy {
+                PDSelectionPolicy::PowerOfTwo => assert!(should_monitor),
+                _ => assert!(!should_monitor),
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_watch_channel_behavior() {
+        use std::collections::HashMap;
+        use tokio::sync::watch;
+
+        // Test watch channel's broadcast behavior
+        let (tx, rx1) = watch::channel(HashMap::new());
+        let rx2 = rx1.clone();
+
+        // Initial state - empty map
+        assert!(rx1.borrow().is_empty());
+        assert!(rx2.borrow().is_empty());
+
+        // Update 1
+        let mut loads = HashMap::new();
+        loads.insert("worker1".to_string(), 10);
+        tx.send(loads.clone()).unwrap();
+
+        // Both receivers see the update
+        assert_eq!(rx1.borrow().get("worker1"), Some(&10));
+        assert_eq!(rx2.borrow().get("worker1"), Some(&10));
+
+        // Update 2 - overwrites previous
+        loads.insert("worker1".to_string(), 20);
+        loads.insert("worker2".to_string(), 30);
+        tx.send(loads).unwrap();
+
+        // Both receivers see the latest state
+        assert_eq!(rx1.borrow().get("worker1"), Some(&20));
+        assert_eq!(rx2.borrow().get("worker2"), Some(&30));
+    }
+
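+    // Illustrative sketch only: the shape of a background monitoring task that
+    // would periodically publish a fresh load map through the watch channel. The
+    // real loop would poll each worker's /get_loads endpoint; here the HTTP call
+    // is replaced by a locally built map.
+    #[tokio::test]
+    async fn test_load_monitor_loop_sketch() {
+        use std::collections::HashMap;
+        use tokio::sync::watch;
+
+        let (tx, rx) = watch::channel(HashMap::new());
+        let monitor = tokio::spawn(async move {
+            let mut interval = tokio::time::interval(std::time::Duration::from_millis(10));
+            for tick in 0..3usize {
+                interval.tick().await;
+                // In the router this map would come from GET /get_loads on each worker
+                let mut loads = HashMap::new();
+                loads.insert("http://decode1:8080".to_string(), tick);
+                if tx.send(loads).is_err() {
+                    break; // all receivers dropped, stop monitoring
+                }
+            }
+        });
+        monitor.await.unwrap();
+
+        // Receivers only ever observe the most recent snapshot
+        assert_eq!(rx.borrow().get("http://decode1:8080"), Some(&2));
+    }
+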
+    // ========================================================================
+    // Tests based on bench_one_batch_server.py patterns
+    // ========================================================================
+
+    #[test]
+    fn test_generate_request_formats() {
+        // Based on bench_one_batch_server.py request patterns
+
+        // Test 1: Batch request with input_ids (most common in benchmarks)
+        let batch_request = json!({
+            "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": 16,
+                "ignore_eos": true,
+            },
+            "return_logprob": false,
+            "stream": true
+        });
+
+        let pd_req = PDRequest::from_json(&batch_request);
+        assert!(pd_req.is_stream);
+        assert_eq!(pd_req.batch_size, Some(3));
+
+        // Test 2: Request with return_logprob (critical for PD)
+        let logprob_request = json!({
+            "input_ids": [[1, 2, 3]],
+            "sampling_params": {
+                "temperature": 0.7,
+                "max_new_tokens": 8,
+            },
+            "return_logprob": true,
+            "stream": false
+        });
+
+        assert_eq!(logprob_request["return_logprob"], true);
+        assert_eq!(logprob_request["stream"], false);
+
+        // Test 3: Large batch sizes from benchmark
+        let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
+        for bs in batch_sizes {
+            let request = json!({
+                "input_ids": vec![vec![1, 2, 3]; bs],
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 16,
+                },
+                "stream": true
+            });
+
+            let pd_req = PDRequest::from_json(&request);
+            assert_eq!(pd_req.batch_size, Some(bs));
+        }
+    }
+
+    #[test]
+    fn test_sampling_params_handling() {
+        // Test various sampling parameters from bench_one_batch_server.py
+        let sampling_params_variations = vec![
+            json!({
+                "temperature": 0.0,
+                "max_new_tokens": 8,
+                "ignore_eos": true
+            }),
+            json!({
+                "temperature": 0.7,
+                "max_new_tokens": 16,
+                "ignore_eos": false,
+                "top_p": 0.9,
+                "frequency_penalty": 0.5
+            }),
+            json!({
+                "temperature": 1.0,
+                "max_new_tokens": 64,
+                "json_schema": "$$ANY$$" // Structured output
+            }),
+        ];
+
+        for params in sampling_params_variations {
+            let request = json!({
+                "input_ids": [[1, 2, 3]],
+                "sampling_params": params.clone(),
+                "stream": false
+            });
+
+            // Verify params are preserved
+            assert_eq!(request["sampling_params"], params);
+        }
+    }
+
+    #[test]
+    fn test_streaming_response_parsing() {
+        // Test SSE format parsing from streaming responses
+        let sse_chunks = vec![
+            "data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
+            "data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
+            "data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
+            "data: [DONE]",
+        ];
+
+        for chunk in &sse_chunks[..3] {
+            assert!(chunk.starts_with("data: "));
+            let json_str = &chunk[6..]; // Skip "data: "
+            let parsed: serde_json::Value = serde_json::from_str(json_str).unwrap();
+            assert!(parsed["meta_info"]["completion_tokens"].is_u64());
+        }
+
+        // Test [DONE] detection
+        assert_eq!(sse_chunks[3], "data: [DONE]");
+    }
+
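+    // Illustrative sketch only: a minimal parser for the SSE lines exercised
+    // above. Returns None for the terminal "data: [DONE]" marker; parse_sse_line
+    // is a hypothetical test-local helper.
+    fn parse_sse_line(line: &str) -> Option<serde_json::Value> {
+        let payload = line.strip_prefix("data: ")?;
+        if payload == "[DONE]" {
+            return None;
+        }
+        serde_json::from_str(payload).ok()
+    }
+
+    #[test]
+    fn test_parse_sse_line_sketch() {
+        let chunk = "data: {\"text\":\"Hi\",\"meta_info\":{\"completion_tokens\":1}}";
+        let parsed = parse_sse_line(chunk).unwrap();
+        assert_eq!(parsed["meta_info"]["completion_tokens"], 1);
+        assert!(parse_sse_line("data: [DONE]").is_none());
+    }
+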
+    #[test]
+    fn test_ttft_calculation() {
+        // Test Time To First Token calculation pattern
+        let first_token_response = json!({
+            "text": "Hello",
+            "meta_info": {
+                "completion_tokens": 1,
+                "finish_reason": null
+            }
+        });
+
+        // TTFT is calculated when completion_tokens == 1
+        assert_eq!(first_token_response["meta_info"]["completion_tokens"], 1);
+        assert!(first_token_response["meta_info"]["finish_reason"].is_null());
+    }
+
+    #[test]
+    fn test_throughput_metrics() {
+        // Test throughput calculation patterns from bench_one_batch_server.py
+        let batch_size = 16;
+        let input_len = 1024;
+        let output_len = 16;
+        let ttft = 0.5; // seconds
+        let total_latency = 2.0; // seconds
+
+        // Input throughput = batch_size * input_len / ttft
+        let input_throughput = (batch_size as f64) * (input_len as f64) / ttft;
+        assert!((input_throughput - 32768.0).abs() < 0.01);
+
+        // Output throughput = batch_size * output_len / (latency - ttft)
+        let output_throughput = (batch_size as f64) * (output_len as f64) / (total_latency - ttft);
+        assert!((output_throughput - 170.67).abs() < 0.01);
+    }
+
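+    // Illustrative sketch only: the two benchmark formulas above folded into one
+    // helper. Input tokens are attributed to the TTFT window and output tokens to
+    // the remaining decode window; throughputs is a hypothetical helper name.
+    fn throughputs(
+        batch_size: usize,
+        input_len: usize,
+        output_len: usize,
+        ttft: f64,
+        total_latency: f64,
+    ) -> (f64, f64) {
+        let input_tput = (batch_size * input_len) as f64 / ttft;
+        let output_tput = (batch_size * output_len) as f64 / (total_latency - ttft);
+        (input_tput, output_tput)
+    }
+
+    #[test]
+    fn test_throughputs_sketch() {
+        // Same numbers as test_throughput_metrics: 16 x 1024 in, 16 out
+        let (input_tput, output_tput) = throughputs(16, 1024, 16, 0.5, 2.0);
+        assert!((input_tput - 32768.0).abs() < 0.01);
+        assert!((output_tput - 170.67).abs() < 0.01);
+    }
+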
+    #[test]
+    fn test_error_response_handling() {
+        // Test error response format from bench_one_batch_server.py
+        let error_response = json!({
+            "error": "Request has failed. Invalid input format."
+        });
+
+        assert!(error_response.get("error").is_some());
+        assert!(error_response["error"].as_str().unwrap().contains("failed"));
+    }
+
+    #[test]
+    fn test_structured_output_request() {
+        // Test structured output format (json_schema)
+        let structured_request = json!({
+            "text": "What is the capital of France? Answer in JSON.",
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": 64,
+                "json_schema": "$$ANY$$"
+            },
+            "stream": false
+        });
+
+        assert_eq!(
+            structured_request["sampling_params"]["json_schema"],
+            "$$ANY$$"
+        );
+    }
+
+    #[test]
+    fn test_bootstrap_injection_with_benchmark_requests() {
+        // Test bootstrap injection with actual benchmark request patterns
+        let mut benchmark_request = json!({
+            "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": 8,
+                "ignore_eos": true
+            },
+            "return_logprob": true,
+            "stream": true
+        });
+
+        // Simulate bootstrap injection
+        let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
+        let batch_size = 16;
+
+        benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
+        benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
+        benchmark_request["bootstrap_room"] =
+            json!((0..batch_size).map(|_| 12345u64).collect::<Vec<u64>>());
+
+        // Verify bootstrap fields match batch size
+        assert_eq!(
+            benchmark_request["bootstrap_host"]
+                .as_array()
+                .unwrap()
+                .len(),
+            batch_size
+        );
+        assert_eq!(
+            benchmark_request["bootstrap_port"]
+                .as_array()
+                .unwrap()
+                .len(),
+            batch_size
+        );
+        assert_eq!(
+            benchmark_request["bootstrap_room"]
+                .as_array()
+                .unwrap()
+                .len(),
+            batch_size
+        );
+
+        // Verify original fields are preserved
+        assert_eq!(benchmark_request["return_logprob"], true);
+        assert_eq!(benchmark_request["stream"], true);
+    }
+
+    #[test]
+    fn test_server_info_response_format() {
+        // Test server info format expected by bench_one_batch_server.py
+        let server_info = json!({
+            "internal_states": [{
+                "avg_spec_accept_length": 3.5,
+                "last_gen_throughput": 2048.5,
+                "load": 16
+            }],
+            "prefill": [
+                {"url": "http://prefill1:8080", "load": 10},
+                {"url": "http://prefill2:8080", "load": 20}
+            ],
+            "decode": [
+                {"url": "http://decode1:8080", "load": 5},
+                {"url": "http://decode2:8080", "load": 15}
+            ]
+        });
+
+        // Verify structure matches what benchmark expects
+        assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
+        assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
+        assert!(server_info["prefill"].is_array());
+        assert!(server_info["decode"].is_array());
+    }
+
+    // ========================================================================
+    // Comprehensive Endpoint Coverage Test
+    // ========================================================================
+
+    #[test]
+    fn test_pd_endpoints_coverage() {
+        // Document all endpoints from Python mini_lb.py and verify implementation status
+        let implemented_endpoints = vec![
+            ("/health", "GET", true),
+            ("/health_generate", "GET", true), // Note: Python uses POST, we use GET
+            ("/get_server_info", "GET", true),
+            ("/v1/models", "GET", true),
+            ("/get_model_info", "GET", true),
+            ("/generate", "POST", true),
+            ("/v1/chat/completions", "POST", true),
+            ("/v1/completions", "POST", true),
+            ("/flush_cache", "POST", true),
+            ("/get_loads", "GET", true),
+            ("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
+        ];
+
+        let implemented_count = implemented_endpoints
+            .iter()
+            .filter(|(_, _, impl_status)| *impl_status)
+            .count();
+        let total_count = implemented_endpoints.len();
+
+        // We've implemented 10 out of 11 endpoints (register is not needed for Phase 1/2)
+        assert_eq!(implemented_count, 10);
+        assert_eq!(total_count, 11);
+
+        // Document the missing endpoint
+        let missing: Vec<_> = implemented_endpoints
+            .iter()
+            .filter(|(_, _, impl_status)| !impl_status)
+            .map(|(endpoint, method, _)| format!("{} {}", method, endpoint))
+            .collect();
+
+        assert_eq!(missing, vec!["POST /register"]);
+    }
+
+    #[test]
+    fn test_large_batch_bootstrap_injection() {
+        // Test bootstrap injection performance with very large batches
+        // This simulates the bench_one_batch_server.py scenario
+        let large_batch_sizes = vec![1024, 4096, 8192];
+
+        for batch_size in large_batch_sizes {
+            let start = std::time::Instant::now();
+
+            // Simulate a large batch request
+            let mut large_batch_request = json!({
+                "input_ids": vec![vec![1, 2, 3, 4]; batch_size],
+                "sampling_params": {
+                    "temperature": 0.0,
+                    "max_new_tokens": 16,
+                },
+                "stream": true
+            });
+
+            // Simulate bootstrap injection
+            let prefill_info =
+                EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
+
+            large_batch_request["bootstrap_host"] =
+                json!(vec![prefill_info.get_hostname(); batch_size]);
+            large_batch_request["bootstrap_port"] =
+                json!(vec![prefill_info.bootstrap_port; batch_size]);
+            large_batch_request["bootstrap_room"] = json!((0..batch_size)
+                .map(|_| rand::thread_rng().gen::<u64>())
+                .collect::<Vec<u64>>());
+
+            let elapsed = start.elapsed();
+
+            // Verify bootstrap fields are correctly sized
+            assert_eq!(
+                large_batch_request["bootstrap_host"]
+                    .as_array()
+                    .unwrap()
+                    .len(),
+                batch_size
+            );
+            assert_eq!(
+                large_batch_request["bootstrap_port"]
+                    .as_array()
+                    .unwrap()
+                    .len(),
+                batch_size
+            );
+            assert_eq!(
+                large_batch_request["bootstrap_room"]
+                    .as_array()
+                    .unwrap()
+                    .len(),
+                batch_size
+            );
+
+            // Bootstrap injection should be reasonably fast even for large batches
+            println!(
+                "Bootstrap injection for batch_size {} took {:?}",
+                batch_size, elapsed
+            );
+            assert!(
+                elapsed.as_millis() < 1000,
+                "Bootstrap injection took too long for batch size {}",
+                batch_size
+            );
+        }
+    }
+
+    #[test]
+    fn test_payload_size_calculation() {
+        // Test payload size estimation for bench_one_batch_server.py scenarios
+        let test_cases = vec![
+            (1, 1024, 16),   // Small batch
+            (16, 1024, 16),  // Medium batch
+            (64, 1024, 16),  // Large batch
+            (8192, 4096, 5), // Benchmark scenario
+        ];
+
+        for (batch_size, input_len, _output_len) in test_cases {
+            // Estimate payload size (rough calculation)
+            // Each token is ~4 bytes (i32), plus JSON overhead
+            let tokens_size = batch_size * input_len * 4; // 4 bytes per token
+            let json_overhead = batch_size * 100; // ~100 bytes overhead per request
+            let total_size = tokens_size + json_overhead;
+
+            println!(
+                "Batch size: {}, Input len: {}, Estimated payload: {} MB",
+                batch_size,
+                input_len,
+                total_size / (1024 * 1024)
+            );
+
+            // For the benchmark case (8192, 4096), this should be ~134 MB
+            if batch_size == 8192 && input_len == 4096 {
+                assert!(
+                    total_size > 100 * 1024 * 1024,
+                    "Benchmark payload should be > 100MB"
+                );
+                assert!(
+                    total_size < 200 * 1024 * 1024,
+                    "Benchmark payload should be < 200MB"
+                );
+            }
+        }
+    }
+
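+    // Illustrative sketch only: ties the estimate above to the router's new
+    // 256MB max_payload_size default (set for large batches); the largest
+    // benchmark batch must fit under it. estimate_payload_bytes is a
+    // hypothetical helper.
+    fn estimate_payload_bytes(batch_size: usize, input_len: usize) -> usize {
+        // ~4 bytes per i32 token plus ~100 bytes of JSON overhead per request
+        batch_size * input_len * 4 + batch_size * 100
+    }
+
+    #[test]
+    fn test_payload_fits_default_limit_sketch() {
+        let default_max_payload = 256 * 1024 * 1024; // matches the RouterArgs default
+        assert!(estimate_payload_bytes(8192, 4096) < default_max_payload);
+    }
+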
+    #[test]
+    fn test_policy_type_to_pd_selection_policy_mapping() {
+        // Document the mapping from PolicyType to PDSelectionPolicy
+        // This mapping happens in lib.rs when pd_disaggregated=true
+
+        // PolicyType::Random -> PDSelectionPolicy::Random
+        // PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
+        // PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
+        // PolicyType::RoundRobin -> ERROR (not supported in PD mode)
+
+        // Test that PDSelectionPolicy doesn't include RoundRobin
+        let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
+        assert_eq!(
+            pd_policy_count, 3,
+            "PDSelectionPolicy should have exactly 3 variants"
+        );
+
+        // Verify that each PDSelectionPolicy variant can be created
+        let _random = PDSelectionPolicy::Random;
+        let _po2 = PDSelectionPolicy::PowerOfTwo;
+        let _cache_aware = PDSelectionPolicy::CacheAware {
+            cache_threshold: 0.5,
+            balance_abs_threshold: 32,
+            balance_rel_threshold: 1.1,
+        };
+    }
+}