Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -218,15 +218,39 @@ async def get_server_info():
|
|||||||
)
|
)
|
||||||
prefill_infos = []
|
prefill_infos = []
|
||||||
decode_infos = []
|
decode_infos = []
|
||||||
|
all_internal_states = []
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
for server in chain(prefill_servers):
|
for server in chain(prefill_servers):
|
||||||
server_info = await session.get(f"{server}/get_server_info")
|
server_info = await session.get(f"{server}/get_server_info")
|
||||||
prefill_infos.append(await server_info.json())
|
prefill_infos.append(await server_info.json())
|
||||||
for server in chain(decode_servers):
|
for server in chain(decode_servers):
|
||||||
server_info = await session.get(f"{server}/get_server_info")
|
server_info = await session.get(f"{server}/get_server_info")
|
||||||
decode_infos.append(await server_info.json())
|
info_json = await server_info.json()
|
||||||
|
decode_infos.append(info_json)
|
||||||
|
# Extract internal_states from decode servers
|
||||||
|
if "internal_states" in info_json:
|
||||||
|
all_internal_states.extend(info_json["internal_states"])
|
||||||
|
|
||||||
return {"prefill": prefill_infos, "decode": decode_infos}
|
# Return format expected by bench_one_batch_server.py
|
||||||
|
if all_internal_states:
|
||||||
|
return {
|
||||||
|
"internal_states": all_internal_states,
|
||||||
|
"prefill": prefill_infos,
|
||||||
|
"decode": decode_infos,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# Fallback with dummy data if no internal states found
|
||||||
|
return {
|
||||||
|
"internal_states": [
|
||||||
|
{
|
||||||
|
"last_gen_throughput": 0.0,
|
||||||
|
"avg_spec_accept_length": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"prefill": prefill_infos,
|
||||||
|
"decode": decode_infos,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/get_model_info")
|
@app.get("/get_model_info")
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
|
|||||||
clap = { version = "4.4", features = ["derive"] }
|
clap = { version = "4.4", features = ["derive"] }
|
||||||
bytes = "1.8.0"
|
bytes = "1.8.0"
|
||||||
rand = "0.8.5"
|
rand = "0.8.5"
|
||||||
reqwest = { version = "0.12.8", features = ["stream", "blocking"] }
|
reqwest = { version = "0.12.8", features = ["stream", "blocking", "json"] }
|
||||||
futures-util = "0.3"
|
futures-util = "0.3"
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
pyo3 = { version = "0.22.5", features = ["extension-module"] }
|
||||||
@@ -33,6 +33,8 @@ futures = "0.3"
|
|||||||
# Added for metrics
|
# Added for metrics
|
||||||
metrics = "0.24.2"
|
metrics = "0.24.2"
|
||||||
metrics-exporter-prometheus = "0.17.0"
|
metrics-exporter-prometheus = "0.17.0"
|
||||||
|
# Added for request tracing
|
||||||
|
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||||
[profile.release]
|
[profile.release]
|
||||||
lto = "thin"
|
lto = "thin"
|
||||||
codegen-units = 1
|
codegen-units = 1
|
||||||
|
|||||||
@@ -31,6 +31,13 @@ class RouterArgs:
|
|||||||
host: str = "127.0.0.1"
|
host: str = "127.0.0.1"
|
||||||
port: int = 30000
|
port: int = 30000
|
||||||
|
|
||||||
|
# PD-specific configuration
|
||||||
|
pd_disaggregated: bool = False # Enable PD disaggregated mode
|
||||||
|
prefill_urls: List[tuple] = dataclasses.field(
|
||||||
|
default_factory=list
|
||||||
|
) # List of (url, bootstrap_port)
|
||||||
|
decode_urls: List[str] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
# Routing policy
|
# Routing policy
|
||||||
policy: str = "cache_aware"
|
policy: str = "cache_aware"
|
||||||
worker_startup_timeout_secs: int = 300
|
worker_startup_timeout_secs: int = 300
|
||||||
@@ -40,7 +47,7 @@ class RouterArgs:
|
|||||||
balance_rel_threshold: float = 1.0001
|
balance_rel_threshold: float = 1.0001
|
||||||
eviction_interval: int = 60
|
eviction_interval: int = 60
|
||||||
max_tree_size: int = 2**24
|
max_tree_size: int = 2**24
|
||||||
max_payload_size: int = 4 * 1024 * 1024 # 4MB
|
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
log_dir: Optional[str] = None
|
log_dir: Optional[str] = None
|
||||||
# Service discovery configuration
|
# Service discovery configuration
|
||||||
@@ -95,8 +102,29 @@ class RouterArgs:
|
|||||||
f"--{prefix}policy",
|
f"--{prefix}policy",
|
||||||
type=str,
|
type=str,
|
||||||
default=RouterArgs.policy,
|
default=RouterArgs.policy,
|
||||||
choices=["random", "round_robin", "cache_aware"],
|
choices=["random", "round_robin", "cache_aware", "power_of_two"],
|
||||||
help="Load balancing policy to use",
|
help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
|
||||||
|
)
|
||||||
|
|
||||||
|
# PD-specific arguments
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}pd-disaggregated",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prefill",
|
||||||
|
nargs=2,
|
||||||
|
action="append",
|
||||||
|
metavar=("URL", "BOOTSTRAP_PORT"),
|
||||||
|
help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}decode",
|
||||||
|
nargs=1,
|
||||||
|
action="append",
|
||||||
|
metavar=("URL",),
|
||||||
|
help="Decode server URL. Can be specified multiple times.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{prefix}worker-startup-timeout-secs",
|
f"--{prefix}worker-startup-timeout-secs",
|
||||||
@@ -205,11 +233,19 @@ class RouterArgs:
|
|||||||
use_router_prefix: If True, look for arguments with 'router-' prefix
|
use_router_prefix: If True, look for arguments with 'router-' prefix
|
||||||
"""
|
"""
|
||||||
prefix = "router_" if use_router_prefix else ""
|
prefix = "router_" if use_router_prefix else ""
|
||||||
worker_urls = args.worker_urls if args.worker_urls is not None else []
|
worker_urls = getattr(args, "worker_urls", [])
|
||||||
|
|
||||||
|
# Parse PD URLs
|
||||||
|
prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
|
||||||
|
decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
worker_urls=worker_urls,
|
worker_urls=worker_urls,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
|
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
|
||||||
|
prefill_urls=prefill_urls,
|
||||||
|
decode_urls=decode_urls,
|
||||||
policy=getattr(args, f"{prefix}policy"),
|
policy=getattr(args, f"{prefix}policy"),
|
||||||
worker_startup_timeout_secs=getattr(
|
worker_startup_timeout_secs=getattr(
|
||||||
args, f"{prefix}worker_startup_timeout_secs"
|
args, f"{prefix}worker_startup_timeout_secs"
|
||||||
@@ -247,6 +283,46 @@ class RouterArgs:
|
|||||||
selector[key] = value
|
selector[key] = value
|
||||||
return selector
|
return selector
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_prefill_urls(prefill_list):
|
||||||
|
"""Parse prefill URLs from --prefill arguments.
|
||||||
|
|
||||||
|
Format: --prefill URL BOOTSTRAP_PORT
|
||||||
|
Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none
|
||||||
|
"""
|
||||||
|
if not prefill_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
prefill_urls = []
|
||||||
|
for url, bootstrap_port_str in prefill_list:
|
||||||
|
# Handle 'none' as None
|
||||||
|
if bootstrap_port_str.lower() == "none":
|
||||||
|
bootstrap_port = None
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
bootstrap_port = int(bootstrap_port_str)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
|
||||||
|
)
|
||||||
|
|
||||||
|
prefill_urls.append((url, bootstrap_port))
|
||||||
|
|
||||||
|
return prefill_urls
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_decode_urls(decode_list):
|
||||||
|
"""Parse decode URLs from --decode arguments.
|
||||||
|
|
||||||
|
Format: --decode URL
|
||||||
|
Example: --decode http://decode1:8081 --decode http://decode2:8081
|
||||||
|
"""
|
||||||
|
if not decode_list:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# decode_list is a list of single-element lists due to nargs=1
|
||||||
|
return [url[0] for url in decode_list]
|
||||||
|
|
||||||
|
|
||||||
def policy_from_str(policy_str: str) -> PolicyType:
|
def policy_from_str(policy_str: str) -> PolicyType:
|
||||||
"""Convert policy string to PolicyType enum."""
|
"""Convert policy string to PolicyType enum."""
|
||||||
@@ -254,6 +330,7 @@ def policy_from_str(policy_str: str) -> PolicyType:
|
|||||||
"random": PolicyType.Random,
|
"random": PolicyType.Random,
|
||||||
"round_robin": PolicyType.RoundRobin,
|
"round_robin": PolicyType.RoundRobin,
|
||||||
"cache_aware": PolicyType.CacheAware,
|
"cache_aware": PolicyType.CacheAware,
|
||||||
|
"power_of_two": PolicyType.PowerOfTwo,
|
||||||
}
|
}
|
||||||
return policy_map[policy_str]
|
return policy_map[policy_str]
|
||||||
|
|
||||||
@@ -277,8 +354,19 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
else:
|
else:
|
||||||
router_args = args
|
router_args = args
|
||||||
|
|
||||||
|
# Validate configuration based on mode
|
||||||
|
if router_args.pd_disaggregated:
|
||||||
|
# Validate PD configuration
|
||||||
|
if not router_args.prefill_urls:
|
||||||
|
raise ValueError("PD disaggregated mode requires --prefill")
|
||||||
|
if not router_args.decode_urls:
|
||||||
|
raise ValueError("PD disaggregated mode requires --decode")
|
||||||
|
|
||||||
|
# Create router with unified constructor
|
||||||
router = Router(
|
router = Router(
|
||||||
worker_urls=router_args.worker_urls,
|
worker_urls=(
|
||||||
|
router_args.worker_urls if not router_args.pd_disaggregated else []
|
||||||
|
),
|
||||||
host=router_args.host,
|
host=router_args.host,
|
||||||
port=router_args.port,
|
port=router_args.port,
|
||||||
policy=policy_from_str(router_args.policy),
|
policy=policy_from_str(router_args.policy),
|
||||||
@@ -298,6 +386,13 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||||
prometheus_port=router_args.prometheus_port,
|
prometheus_port=router_args.prometheus_port,
|
||||||
prometheus_host=router_args.prometheus_host,
|
prometheus_host=router_args.prometheus_host,
|
||||||
|
pd_disaggregated=router_args.pd_disaggregated,
|
||||||
|
prefill_urls=(
|
||||||
|
router_args.prefill_urls if router_args.pd_disaggregated else None
|
||||||
|
),
|
||||||
|
decode_urls=(
|
||||||
|
router_args.decode_urls if router_args.pd_disaggregated else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
router.start()
|
router.start()
|
||||||
@@ -326,8 +421,14 @@ This launcher enables starting a router with individual worker instances. It is
|
|||||||
multi-node setups or when you want to start workers and router separately.
|
multi-node setups or when you want to start workers and router separately.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
# Regular mode
|
||||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000 --cache-threshold 0.7 --balance-abs-threshold 64 --balance-rel-threshold 1.2
|
|
||||||
|
# PD disaggregated mode
|
||||||
|
python -m sglang_router.launch_router --pd-disaggregated \\
|
||||||
|
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||||
|
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||||
|
--policy cache_aware
|
||||||
|
|
||||||
""",
|
""",
|
||||||
formatter_class=CustomHelpFormatter,
|
formatter_class=CustomHelpFormatter,
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ class Router:
|
|||||||
- PolicyType.Random: Randomly select workers
|
- PolicyType.Random: Randomly select workers
|
||||||
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
- PolicyType.RoundRobin: Distribute requests in round-robin fashion
|
||||||
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
- PolicyType.CacheAware: Distribute requests based on cache state and load balance
|
||||||
|
- PolicyType.PowerOfTwo: Select best of two random workers based on load (PD mode only)
|
||||||
host: Host address to bind the router server. Default: '127.0.0.1'
|
host: Host address to bind the router server. Default: '127.0.0.1'
|
||||||
port: Port number to bind the router server. Default: 3001
|
port: Port number to bind the router server. Default: 3001
|
||||||
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
|
worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300
|
||||||
@@ -28,7 +29,7 @@ class Router:
|
|||||||
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
|
AND max_load > min_load * rel_threshold. Otherwise, use cache aware. Default: 1.0001
|
||||||
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
|
eviction_interval_secs: Interval in seconds between cache eviction operations in cache-aware
|
||||||
routing. Default: 60
|
routing. Default: 60
|
||||||
max_payload_size: Maximum payload size in bytes. Default: 4MB
|
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
||||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||||
verbose: Enable verbose logging. Default: False
|
verbose: Enable verbose logging. Default: False
|
||||||
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
||||||
@@ -42,6 +43,9 @@ class Router:
|
|||||||
watches pods across all namespaces (requires cluster-wide permissions). Default: None
|
watches pods across all namespaces (requires cluster-wide permissions). Default: None
|
||||||
prometheus_port: Port to expose Prometheus metrics. Default: None
|
prometheus_port: Port to expose Prometheus metrics. Default: None
|
||||||
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
|
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
|
||||||
|
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
||||||
|
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
||||||
|
decode_urls: List of URLs for decode servers (PD mode only)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -57,7 +61,7 @@ class Router:
|
|||||||
balance_rel_threshold: float = 1.0001,
|
balance_rel_threshold: float = 1.0001,
|
||||||
eviction_interval_secs: int = 60,
|
eviction_interval_secs: int = 60,
|
||||||
max_tree_size: int = 2**24,
|
max_tree_size: int = 2**24,
|
||||||
max_payload_size: int = 4 * 1024 * 1024, # 4MB
|
max_payload_size: int = 256 * 1024 * 1024, # 256MB
|
||||||
verbose: bool = False,
|
verbose: bool = False,
|
||||||
log_dir: Optional[str] = None,
|
log_dir: Optional[str] = None,
|
||||||
service_discovery: bool = False,
|
service_discovery: bool = False,
|
||||||
@@ -66,6 +70,9 @@ class Router:
|
|||||||
service_discovery_namespace: Optional[str] = None,
|
service_discovery_namespace: Optional[str] = None,
|
||||||
prometheus_port: Optional[int] = None,
|
prometheus_port: Optional[int] = None,
|
||||||
prometheus_host: Optional[str] = None,
|
prometheus_host: Optional[str] = None,
|
||||||
|
pd_disaggregated: bool = False,
|
||||||
|
prefill_urls: Optional[List[tuple]] = None,
|
||||||
|
decode_urls: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
if selector is None:
|
if selector is None:
|
||||||
selector = {}
|
selector = {}
|
||||||
@@ -91,6 +98,9 @@ class Router:
|
|||||||
service_discovery_namespace=service_discovery_namespace,
|
service_discovery_namespace=service_discovery_namespace,
|
||||||
prometheus_port=prometheus_port,
|
prometheus_port=prometheus_port,
|
||||||
prometheus_host=prometheus_host,
|
prometheus_host=prometheus_host,
|
||||||
|
pd_disaggregated=pd_disaggregated,
|
||||||
|
prefill_urls=prefill_urls,
|
||||||
|
decode_urls=decode_urls,
|
||||||
)
|
)
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
|
|||||||
@@ -35,13 +35,21 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
balance_rel_threshold=1.0001,
|
balance_rel_threshold=1.0001,
|
||||||
eviction_interval=60,
|
eviction_interval=60,
|
||||||
max_tree_size=2**24,
|
max_tree_size=2**24,
|
||||||
max_payload_size=4 * 1024 * 1024, # 4MB
|
max_payload_size=256 * 1024 * 1024, # 256MB
|
||||||
verbose=False,
|
verbose=False,
|
||||||
log_dir=None,
|
log_dir=None,
|
||||||
service_discovery=False,
|
service_discovery=False,
|
||||||
selector=None,
|
selector=None,
|
||||||
service_discovery_port=80,
|
service_discovery_port=80,
|
||||||
service_discovery_namespace=None,
|
service_discovery_namespace=None,
|
||||||
|
prometheus_port=None,
|
||||||
|
prometheus_host=None,
|
||||||
|
# PD-specific attributes
|
||||||
|
pd_disaggregated=False,
|
||||||
|
prefill=None,
|
||||||
|
decode=None,
|
||||||
|
# Keep worker_urls for regular mode
|
||||||
|
worker_urls=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_router_args(self, **kwargs):
|
def create_router_args(self, **kwargs):
|
||||||
@@ -81,7 +89,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
|
|
||||||
def test_launch_router_with_empty_worker_urls(self):
|
def test_launch_router_with_empty_worker_urls(self):
|
||||||
args = self.create_router_args(worker_urls=[])
|
args = self.create_router_args(worker_urls=[])
|
||||||
self.run_router_process(args)
|
self.run_router_process(args) # Expected error
|
||||||
|
|
||||||
def test_launch_router_with_service_discovery(self):
|
def test_launch_router_with_service_discovery(self):
|
||||||
# Test router startup with service discovery enabled but no selectors
|
# Test router startup with service discovery enabled but no selectors
|
||||||
@@ -100,6 +108,112 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.run_router_process(args)
|
self.run_router_process(args)
|
||||||
|
|
||||||
|
def test_launch_router_pd_mode_basic(self):
|
||||||
|
"""Test basic PD router functionality without actually starting servers."""
|
||||||
|
# This test just verifies the PD router can be created and configured
|
||||||
|
# without actually starting it (which would require real prefill/decode servers)
|
||||||
|
from sglang_router import Router
|
||||||
|
from sglang_router.launch_router import RouterArgs
|
||||||
|
from sglang_router_rs import PolicyType
|
||||||
|
|
||||||
|
# Test RouterArgs parsing for PD mode
|
||||||
|
# Simulate the parsed args structure from argparse with action="append"
|
||||||
|
args = self.create_router_args(
|
||||||
|
pd_disaggregated=True,
|
||||||
|
policy="power_of_two", # PowerOfTwo is only valid in PD mode
|
||||||
|
prefill=[
|
||||||
|
["http://prefill1:8080", "9000"],
|
||||||
|
["http://prefill2:8080", "none"],
|
||||||
|
],
|
||||||
|
decode=[
|
||||||
|
["http://decode1:8081"],
|
||||||
|
["http://decode2:8081"],
|
||||||
|
],
|
||||||
|
worker_urls=[], # Empty for PD mode
|
||||||
|
)
|
||||||
|
|
||||||
|
router_args = RouterArgs.from_cli_args(args)
|
||||||
|
self.assertTrue(router_args.pd_disaggregated)
|
||||||
|
self.assertEqual(router_args.policy, "power_of_two")
|
||||||
|
self.assertEqual(len(router_args.prefill_urls), 2)
|
||||||
|
self.assertEqual(len(router_args.decode_urls), 2)
|
||||||
|
|
||||||
|
# Verify the parsed URLs and bootstrap ports
|
||||||
|
self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
|
||||||
|
self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
|
||||||
|
self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
|
||||||
|
self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")
|
||||||
|
|
||||||
|
# Test Router creation in PD mode
|
||||||
|
router = Router(
|
||||||
|
worker_urls=[], # Empty for PD mode
|
||||||
|
pd_disaggregated=True,
|
||||||
|
prefill_urls=[
|
||||||
|
("http://prefill1:8080", 9000),
|
||||||
|
("http://prefill2:8080", None),
|
||||||
|
],
|
||||||
|
decode_urls=["http://decode1:8081", "http://decode2:8081"],
|
||||||
|
policy=PolicyType.CacheAware,
|
||||||
|
host="127.0.0.1",
|
||||||
|
port=3001,
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(router)
|
||||||
|
|
||||||
|
def test_policy_validation(self):
|
||||||
|
"""Test that policy validation works correctly for PD and regular modes."""
|
||||||
|
from sglang_router.launch_router import RouterArgs, launch_router
|
||||||
|
|
||||||
|
# Test 1: PowerOfTwo is only valid in PD mode
|
||||||
|
args = self.create_router_args(
|
||||||
|
pd_disaggregated=False,
|
||||||
|
policy="power_of_two",
|
||||||
|
worker_urls=["http://localhost:8000"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should raise error
|
||||||
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
launch_router(args)
|
||||||
|
self.assertIn(
|
||||||
|
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||||
|
str(cm.exception),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 2: RoundRobin is not valid in PD mode
|
||||||
|
args = self.create_router_args(
|
||||||
|
pd_disaggregated=True,
|
||||||
|
policy="round_robin",
|
||||||
|
prefill=[["http://prefill1:8080", "9000"]],
|
||||||
|
decode=[["http://decode1:8081"]],
|
||||||
|
worker_urls=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should raise error
|
||||||
|
with self.assertRaises(ValueError) as cm:
|
||||||
|
launch_router(args)
|
||||||
|
self.assertIn(
|
||||||
|
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||||
|
str(cm.exception),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test 3: Valid combinations should not raise errors
|
||||||
|
# Regular mode with RoundRobin
|
||||||
|
args = self.create_router_args(
|
||||||
|
pd_disaggregated=False,
|
||||||
|
policy="round_robin",
|
||||||
|
worker_urls=["http://localhost:8000"],
|
||||||
|
)
|
||||||
|
# This should not raise (though it may fail to connect)
|
||||||
|
|
||||||
|
# PD mode with PowerOfTwo
|
||||||
|
args = self.create_router_args(
|
||||||
|
pd_disaggregated=True,
|
||||||
|
policy="power_of_two",
|
||||||
|
prefill=[["http://prefill1:8080", "9000"]],
|
||||||
|
decode=[["http://decode1:8081"]],
|
||||||
|
worker_urls=[],
|
||||||
|
)
|
||||||
|
# This should not raise (though it may fail to connect)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
use pyo3::prelude::*;
|
use pyo3::prelude::*;
|
||||||
pub mod logging;
|
pub mod logging;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
pub mod openai_api_types;
|
||||||
|
pub mod pd_router;
|
||||||
|
pub mod pd_types;
|
||||||
pub mod prometheus;
|
pub mod prometheus;
|
||||||
|
pub mod request_adapter;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
pub mod server;
|
pub mod server;
|
||||||
pub mod service_discovery;
|
pub mod service_discovery;
|
||||||
@@ -14,6 +18,7 @@ pub enum PolicyType {
|
|||||||
Random,
|
Random,
|
||||||
RoundRobin,
|
RoundRobin,
|
||||||
CacheAware,
|
CacheAware,
|
||||||
|
PowerOfTwo, // Moved from PD-specific, now shared
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyclass]
|
#[pyclass]
|
||||||
@@ -39,6 +44,12 @@ struct Router {
|
|||||||
service_discovery_namespace: Option<String>,
|
service_discovery_namespace: Option<String>,
|
||||||
prometheus_port: Option<u16>,
|
prometheus_port: Option<u16>,
|
||||||
prometheus_host: Option<String>,
|
prometheus_host: Option<String>,
|
||||||
|
request_timeout_secs: u64,
|
||||||
|
// PD mode flag
|
||||||
|
pd_disaggregated: bool,
|
||||||
|
// PD-specific fields (only used when pd_disaggregated is true)
|
||||||
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
|
decode_urls: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymethods]
|
#[pymethods]
|
||||||
@@ -56,7 +67,7 @@ impl Router {
|
|||||||
balance_rel_threshold = 1.0001,
|
balance_rel_threshold = 1.0001,
|
||||||
eviction_interval_secs = 60,
|
eviction_interval_secs = 60,
|
||||||
max_tree_size = 2usize.pow(24),
|
max_tree_size = 2usize.pow(24),
|
||||||
max_payload_size = 4 * 1024 * 1024,
|
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||||
verbose = false,
|
verbose = false,
|
||||||
log_dir = None,
|
log_dir = None,
|
||||||
service_discovery = false,
|
service_discovery = false,
|
||||||
@@ -64,7 +75,11 @@ impl Router {
|
|||||||
service_discovery_port = 80,
|
service_discovery_port = 80,
|
||||||
service_discovery_namespace = None,
|
service_discovery_namespace = None,
|
||||||
prometheus_port = None,
|
prometheus_port = None,
|
||||||
prometheus_host = None
|
prometheus_host = None,
|
||||||
|
request_timeout_secs = 600, // Add configurable request timeout
|
||||||
|
pd_disaggregated = false, // New flag for PD mode
|
||||||
|
prefill_urls = None,
|
||||||
|
decode_urls = None
|
||||||
))]
|
))]
|
||||||
fn new(
|
fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
@@ -87,6 +102,10 @@ impl Router {
|
|||||||
service_discovery_namespace: Option<String>,
|
service_discovery_namespace: Option<String>,
|
||||||
prometheus_port: Option<u16>,
|
prometheus_port: Option<u16>,
|
||||||
prometheus_host: Option<String>,
|
prometheus_host: Option<String>,
|
||||||
|
request_timeout_secs: u64,
|
||||||
|
pd_disaggregated: bool,
|
||||||
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
|
decode_urls: Option<Vec<String>>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
Ok(Router {
|
Ok(Router {
|
||||||
host,
|
host,
|
||||||
@@ -109,28 +128,75 @@ impl Router {
|
|||||||
service_discovery_namespace,
|
service_discovery_namespace,
|
||||||
prometheus_port,
|
prometheus_port,
|
||||||
prometheus_host,
|
prometheus_host,
|
||||||
|
request_timeout_secs,
|
||||||
|
pd_disaggregated,
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start(&self) -> PyResult<()> {
|
fn start(&self) -> PyResult<()> {
|
||||||
let policy_config = match &self.policy {
|
let policy_config = if self.pd_disaggregated {
|
||||||
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
// PD mode - map PolicyType to PDSelectionPolicy
|
||||||
|
let pd_selection_policy = match &self.policy {
|
||||||
|
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
||||||
|
PolicyType::PowerOfTwo => pd_types::PDSelectionPolicy::PowerOfTwo,
|
||||||
|
PolicyType::CacheAware => pd_types::PDSelectionPolicy::CacheAware {
|
||||||
|
cache_threshold: self.cache_threshold,
|
||||||
|
balance_abs_threshold: self.balance_abs_threshold,
|
||||||
|
balance_rel_threshold: self.balance_rel_threshold,
|
||||||
|
},
|
||||||
|
PolicyType::RoundRobin => {
|
||||||
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||||
|
"RoundRobin policy is not supported in PD disaggregated mode",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let prefill_urls = self.prefill_urls.as_ref().ok_or_else(|| {
|
||||||
|
pyo3::exceptions::PyValueError::new_err(
|
||||||
|
"PD disaggregated mode requires prefill_urls",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let decode_urls = self.decode_urls.as_ref().ok_or_else(|| {
|
||||||
|
pyo3::exceptions::PyValueError::new_err(
|
||||||
|
"PD disaggregated mode requires decode_urls",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
router::PolicyConfig::PrefillDecodeConfig {
|
||||||
|
selection_policy: pd_selection_policy,
|
||||||
|
prefill_urls: prefill_urls.clone(),
|
||||||
|
decode_urls: decode_urls.clone(),
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
interval_secs: self.worker_startup_check_interval,
|
interval_secs: self.worker_startup_check_interval,
|
||||||
},
|
}
|
||||||
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
} else {
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
// Regular mode
|
||||||
interval_secs: self.worker_startup_check_interval,
|
match &self.policy {
|
||||||
},
|
PolicyType::Random => router::PolicyConfig::RandomConfig {
|
||||||
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
timeout_secs: self.worker_startup_timeout_secs,
|
interval_secs: self.worker_startup_check_interval,
|
||||||
interval_secs: self.worker_startup_check_interval,
|
},
|
||||||
cache_threshold: self.cache_threshold,
|
PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig {
|
||||||
balance_abs_threshold: self.balance_abs_threshold,
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
balance_rel_threshold: self.balance_rel_threshold,
|
interval_secs: self.worker_startup_check_interval,
|
||||||
eviction_interval_secs: self.eviction_interval_secs,
|
},
|
||||||
max_tree_size: self.max_tree_size,
|
PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig {
|
||||||
},
|
timeout_secs: self.worker_startup_timeout_secs,
|
||||||
|
interval_secs: self.worker_startup_check_interval,
|
||||||
|
cache_threshold: self.cache_threshold,
|
||||||
|
balance_abs_threshold: self.balance_abs_threshold,
|
||||||
|
balance_rel_threshold: self.balance_rel_threshold,
|
||||||
|
eviction_interval_secs: self.eviction_interval_secs,
|
||||||
|
max_tree_size: self.max_tree_size,
|
||||||
|
},
|
||||||
|
PolicyType::PowerOfTwo => {
|
||||||
|
return Err(pyo3::exceptions::PyValueError::new_err(
|
||||||
|
"PowerOfTwo policy is only supported in PD disaggregated mode",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create service discovery config if enabled
|
// Create service discovery config if enabled
|
||||||
@@ -166,6 +232,7 @@ impl Router {
|
|||||||
log_dir: self.log_dir.clone(),
|
log_dir: self.log_dir.clone(),
|
||||||
service_discovery_config,
|
service_discovery_config,
|
||||||
prometheus_config,
|
prometheus_config,
|
||||||
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
.map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
|
||||||
|
|||||||
704
sgl-router/src/openai_api_types.rs
Normal file
704
sgl-router/src/openai_api_types.rs
Normal file
@@ -0,0 +1,704 @@
|
|||||||
|
// OpenAI-compatible API types for text generation
|
||||||
|
// Based on OpenAI's API specification: https://platform.openai.com/docs/api-reference
|
||||||
|
// Reference: Azure OpenAI API documentation which follows OpenAI's specification
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
/// Common trait for all generation requests
|
||||||
|
pub trait GenerationRequest: Send + Sync {
|
||||||
|
/// Check if the request is for streaming
|
||||||
|
fn is_stream(&self) -> bool;
|
||||||
|
|
||||||
|
/// Get the model name if specified
|
||||||
|
fn get_model(&self) -> Option<&str>;
|
||||||
|
|
||||||
|
/// Extract text content for routing decisions
|
||||||
|
fn extract_text_for_routing(&self) -> String;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Completions API (v1/completions) - DEPRECATED but still supported =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionRequest {
|
||||||
|
/// ID of the model to use (required for OpenAI, optional for some implementations, such as SGLang)
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// The prompt(s) to generate completions for
|
||||||
|
pub prompt: StringOrArray,
|
||||||
|
|
||||||
|
/// The suffix that comes after a completion of inserted text
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
|
||||||
|
/// The maximum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// What sampling temperature to use, between 0 and 2
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// An alternative to sampling with temperature (nucleus sampling)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// How many completions to generate for each prompt
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
|
||||||
|
/// Whether to stream back partial progress
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Include the log probabilities on the logprobs most likely tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<u32>,
|
||||||
|
|
||||||
|
/// Echo back the prompt in addition to the completion
|
||||||
|
#[serde(default)]
|
||||||
|
pub echo: bool,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StringOrArray>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Generates best_of completions server-side and returns the "best"
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub best_of: Option<u32>,
|
||||||
|
|
||||||
|
/// Modify the likelihood of specified tokens appearing in the completion
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<HashMap<String, f32>>,
|
||||||
|
|
||||||
|
/// A unique identifier representing your end-user
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
|
||||||
|
/// If specified, our system will make a best effort to sample deterministically
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerationRequest for CompletionRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model(&self) -> Option<&str> {
|
||||||
|
Some(&self.model)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
match &self.prompt {
|
||||||
|
StringOrArray::String(s) => s.clone(),
|
||||||
|
StringOrArray::Array(v) => v.join(" "),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Chat Completions API (v1/chat/completions) =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionRequest {
|
||||||
|
/// ID of the model to use
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// A list of messages comprising the conversation so far
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
|
||||||
|
/// What sampling temperature to use, between 0 and 2
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
|
||||||
|
/// An alternative to sampling with temperature
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
|
||||||
|
/// How many chat completion choices to generate for each input message
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
|
||||||
|
/// If set, partial message deltas will be sent
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Up to 4 sequences where the API will stop generating further tokens
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StringOrArray>,
|
||||||
|
|
||||||
|
/// The maximum number of tokens to generate
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// An upper bound for the number of tokens that can be generated for a completion
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_completion_tokens: Option<u32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
|
||||||
|
/// Modify the likelihood of specified tokens appearing in the completion
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logit_bias: Option<HashMap<String, i32>>,
|
||||||
|
|
||||||
|
/// A unique identifier representing your end-user
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub user: Option<String>,
|
||||||
|
|
||||||
|
/// If specified, our system will make a best effort to sample deterministically
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<i64>,
|
||||||
|
|
||||||
|
/// Whether to return log probabilities of the output tokens
|
||||||
|
#[serde(default)]
|
||||||
|
pub logprobs: bool,
|
||||||
|
|
||||||
|
/// An integer between 0 and 20 specifying the number of most likely tokens to return
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_logprobs: Option<u32>,
|
||||||
|
|
||||||
|
/// An object specifying the format that the model must output
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
|
||||||
|
/// A list of tools the model may call
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<Tool>>,
|
||||||
|
|
||||||
|
/// Controls which (if any) tool is called by the model
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<ToolChoice>,
|
||||||
|
|
||||||
|
/// Whether to enable parallel function calling during tool use
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
|
||||||
|
/// Deprecated: use tools instead
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub functions: Option<Vec<Function>>,
|
||||||
|
|
||||||
|
/// Deprecated: use tool_choice instead
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function_call: Option<FunctionCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ChatMessage {
|
||||||
|
System {
|
||||||
|
role: String, // "system"
|
||||||
|
content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
User {
|
||||||
|
role: String, // "user"
|
||||||
|
content: UserMessageContent,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
},
|
||||||
|
Assistant {
|
||||||
|
role: String, // "assistant"
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
function_call: Option<FunctionCallResponse>,
|
||||||
|
},
|
||||||
|
Tool {
|
||||||
|
role: String, // "tool"
|
||||||
|
content: String,
|
||||||
|
tool_call_id: String,
|
||||||
|
},
|
||||||
|
Function {
|
||||||
|
role: String, // "function"
|
||||||
|
content: String,
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum UserMessageContent {
|
||||||
|
Text(String),
|
||||||
|
Parts(Vec<ContentPart>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ContentPart {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text { text: String },
|
||||||
|
#[serde(rename = "image_url")]
|
||||||
|
ImageUrl { image_url: ImageUrl },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ImageUrl {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub detail: Option<String>, // "auto", "low", or "high"
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ResponseFormat {
|
||||||
|
#[serde(rename = "text")]
|
||||||
|
Text,
|
||||||
|
#[serde(rename = "json_object")]
|
||||||
|
JsonObject,
|
||||||
|
#[serde(rename = "json_schema")]
|
||||||
|
JsonSchema { json_schema: JsonSchemaFormat },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct JsonSchemaFormat {
|
||||||
|
pub name: String,
|
||||||
|
pub schema: Value,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub strict: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Tool {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String, // "function"
|
||||||
|
pub function: Function,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Function {
|
||||||
|
pub name: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub parameters: Value, // JSON Schema
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum ToolChoice {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
Required,
|
||||||
|
Function {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
tool_type: String, // "function"
|
||||||
|
function: FunctionChoice,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FunctionChoice {
|
||||||
|
pub name: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: String, // "function"
|
||||||
|
pub function: FunctionCallResponse,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum FunctionCall {
|
||||||
|
None,
|
||||||
|
Auto,
|
||||||
|
Function { name: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FunctionCallResponse {
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String, // JSON string
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerationRequest for ChatCompletionRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model(&self) -> Option<&str> {
|
||||||
|
Some(&self.model)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
// Extract text from messages for routing decisions
|
||||||
|
self.messages
|
||||||
|
.iter()
|
||||||
|
.filter_map(|msg| match msg {
|
||||||
|
ChatMessage::System { content, .. } => Some(content.clone()),
|
||||||
|
ChatMessage::User { content, .. } => match content {
|
||||||
|
UserMessageContent::Text(text) => Some(text.clone()),
|
||||||
|
UserMessageContent::Parts(parts) => {
|
||||||
|
let texts: Vec<String> = parts
|
||||||
|
.iter()
|
||||||
|
.filter_map(|part| match part {
|
||||||
|
ContentPart::Text { text } => Some(text.clone()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Some(texts.join(" "))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
ChatMessage::Assistant { content, .. } => content.clone(),
|
||||||
|
ChatMessage::Tool { content, .. } => Some(content.clone()),
|
||||||
|
ChatMessage::Function { content, .. } => Some(content.clone()),
|
||||||
|
})
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join(" ")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Generate API (/generate) =============
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateRequest {
|
||||||
|
/// The prompt to generate from (OpenAI style)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt: Option<StringOrArray>,
|
||||||
|
|
||||||
|
/// Text input - SGLang native format
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub text: Option<String>,
|
||||||
|
|
||||||
|
/// Input IDs for tokenized input
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub input_ids: Option<InputIds>,
|
||||||
|
|
||||||
|
/// Generation parameters
|
||||||
|
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parameters: Option<GenerateParameters>,
|
||||||
|
|
||||||
|
/// Sampling parameters (sglang style)
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sampling_params: Option<SamplingParams>,
|
||||||
|
|
||||||
|
/// Whether to stream the response
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
/// Whether to return logprobs
|
||||||
|
#[serde(default)]
|
||||||
|
pub return_logprob: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum InputIds {
|
||||||
|
Single(Vec<i32>),
|
||||||
|
Batch(Vec<Vec<i32>>),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||||
|
pub struct GenerateParameters {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub best_of: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub decoder_input_details: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub details: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub do_sample: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_new_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub return_full_text: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub truncate: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub typical_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub watermark: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
|
||||||
|
pub struct SamplingParams {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_new_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_k: Option<i32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub repetition_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<StringOrArray>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub ignore_eos: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub skip_special_tokens: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerationRequest for GenerateRequest {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_model(&self) -> Option<&str> {
|
||||||
|
// Generate requests typically don't have a model field
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_text_for_routing(&self) -> String {
|
||||||
|
// Check fields in priority order: text, prompt, inputs
|
||||||
|
if let Some(ref text) = self.text {
|
||||||
|
return text.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref prompt) = self.prompt {
|
||||||
|
return match prompt {
|
||||||
|
StringOrArray::String(s) => s.clone(),
|
||||||
|
StringOrArray::Array(v) => v.join(" "),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref input_ids) = self.input_ids {
|
||||||
|
return match input_ids {
|
||||||
|
InputIds::Single(ids) => ids
|
||||||
|
.iter()
|
||||||
|
.map(|&id| id.to_string())
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join(" "),
|
||||||
|
InputIds::Batch(batches) => batches
|
||||||
|
.iter()
|
||||||
|
.flat_map(|batch| batch.iter().map(|&id| id.to_string()))
|
||||||
|
.collect::<Vec<String>>()
|
||||||
|
.join(" "),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// No text input found
|
||||||
|
String::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Helper Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum StringOrArray {
|
||||||
|
String(String),
|
||||||
|
Array(Vec<String>),
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Response Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "text_completion"
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<CompletionChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionChoice {
|
||||||
|
pub text: String,
|
||||||
|
pub index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<LogProbs>,
|
||||||
|
pub finish_reason: Option<String>, // "stop", "length", "content_filter", etc.
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct LogProbs {
|
||||||
|
pub tokens: Vec<String>,
|
||||||
|
pub token_logprobs: Vec<Option<f32>>,
|
||||||
|
pub top_logprobs: Vec<Option<HashMap<String, f32>>>,
|
||||||
|
pub text_offset: Vec<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "chat.completion"
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
pub choices: Vec<ChatChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: ChatMessage,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
|
pub finish_reason: Option<String>, // "stop", "length", "tool_calls", "content_filter", "function_call"
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatLogProbs {
|
||||||
|
pub content: Option<Vec<ChatLogProbsContent>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatLogProbsContent {
|
||||||
|
pub token: String,
|
||||||
|
pub logprob: f32,
|
||||||
|
pub bytes: Option<Vec<u8>>,
|
||||||
|
pub top_logprobs: Vec<TopLogProb>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct TopLogProb {
|
||||||
|
pub token: String,
|
||||||
|
pub logprob: f32,
|
||||||
|
pub bytes: Option<Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct Usage {
|
||||||
|
pub prompt_tokens: u32,
|
||||||
|
pub completion_tokens: u32,
|
||||||
|
pub total_tokens: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completion_tokens_details: Option<CompletionTokensDetails>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionTokensDetails {
|
||||||
|
pub reasoning_tokens: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Streaming Response Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionStreamResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "text_completion"
|
||||||
|
pub created: u64,
|
||||||
|
pub choices: Vec<CompletionStreamChoice>,
|
||||||
|
pub model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct CompletionStreamChoice {
|
||||||
|
pub text: String,
|
||||||
|
pub index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<LogProbs>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatCompletionStreamResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String, // "chat.completion.chunk"
|
||||||
|
pub created: u64,
|
||||||
|
pub model: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub system_fingerprint: Option<String>,
|
||||||
|
pub choices: Vec<ChatStreamChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<Usage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatStreamChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub delta: ChatMessageDelta,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub logprobs: Option<ChatLogProbs>,
|
||||||
|
pub finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ChatMessageDelta {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub role: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function_call: Option<FunctionCallDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ToolCallDelta {
|
||||||
|
pub index: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub tool_type: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub function: Option<FunctionCallDelta>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct FunctionCallDelta {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub arguments: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Error Response Types =============
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ErrorResponse {
|
||||||
|
pub error: ErrorDetail,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
pub struct ErrorDetail {
|
||||||
|
pub message: String,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub error_type: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub param: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub code: Option<String>,
|
||||||
|
}
|
||||||
1002
sgl-router/src/pd_router.rs
Normal file
1002
sgl-router/src/pd_router.rs
Normal file
File diff suppressed because it is too large
Load Diff
245
sgl-router/src/pd_types.rs
Normal file
245
sgl-router/src/pd_types.rs
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
// Essential PDLB types extracted for PD routing
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum EngineType {
|
||||||
|
Prefill,
|
||||||
|
Decode,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct EngineInfo {
|
||||||
|
pub engine_type: EngineType,
|
||||||
|
pub url: String,
|
||||||
|
pub bootstrap_port: Option<u16>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EngineInfo {
|
||||||
|
pub fn new_prefill(url: String, bootstrap_port: Option<u16>) -> Self {
|
||||||
|
EngineInfo {
|
||||||
|
engine_type: EngineType::Prefill,
|
||||||
|
url,
|
||||||
|
bootstrap_port,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_decode(url: String) -> Self {
|
||||||
|
EngineInfo {
|
||||||
|
engine_type: EngineType::Decode,
|
||||||
|
url,
|
||||||
|
bootstrap_port: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn api_path(&self, api_path: &str) -> String {
|
||||||
|
if api_path.starts_with("/") {
|
||||||
|
format!("{}{}", self.url, api_path)
|
||||||
|
} else {
|
||||||
|
format!("{}/{}", self.url, api_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_hostname(&self) -> String {
|
||||||
|
// Simple hostname extraction without external dependencies
|
||||||
|
let url = self
|
||||||
|
.url
|
||||||
|
.trim_start_matches("http://")
|
||||||
|
.trim_start_matches("https://");
|
||||||
|
url.split(':').next().unwrap_or("localhost").to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PD-specific routing policies
|
||||||
|
#[derive(Debug, Clone, PartialEq)]
|
||||||
|
pub enum PDSelectionPolicy {
|
||||||
|
Random,
|
||||||
|
PowerOfTwo,
|
||||||
|
CacheAware {
|
||||||
|
cache_threshold: f32,
|
||||||
|
balance_abs_threshold: usize,
|
||||||
|
balance_rel_threshold: f32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// Bootstrap types from PDLB
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum SingleOrBatch<T> {
|
||||||
|
Single(T),
|
||||||
|
Batch(Vec<T>),
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type InputIds = SingleOrBatch<Vec<i32>>;
|
||||||
|
pub type InputText = SingleOrBatch<String>;
|
||||||
|
pub type BootstrapHost = SingleOrBatch<String>;
|
||||||
|
pub type BootstrapPort = SingleOrBatch<Option<u16>>;
|
||||||
|
pub type BootstrapRoom = SingleOrBatch<u64>;
|
||||||
|
|
||||||
|
// Bootstrap trait for request handling
|
||||||
|
pub trait Bootstrap: Send + Sync {
|
||||||
|
fn is_stream(&self) -> bool;
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, String>;
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
);
|
||||||
|
|
||||||
|
fn add_bootstrap_info(&mut self, prefill_info: &EngineInfo) -> Result<(), String> {
|
||||||
|
let batch_size = self.get_batch_size()?;
|
||||||
|
if let Some(batch_size) = batch_size {
|
||||||
|
self.set_bootstrap_info(
|
||||||
|
BootstrapHost::Batch(vec![prefill_info.get_hostname(); batch_size]),
|
||||||
|
BootstrapPort::Batch(vec![prefill_info.bootstrap_port; batch_size]),
|
||||||
|
// Use high-quality random numbers to minimize collision risk
|
||||||
|
BootstrapRoom::Batch(
|
||||||
|
(0..batch_size)
|
||||||
|
.map(|_| {
|
||||||
|
// Combine multiple sources of randomness for better distribution
|
||||||
|
let r1 = rand::random::<u64>();
|
||||||
|
let r2 = rand::random::<u64>();
|
||||||
|
r1.wrapping_add(r2.rotate_left(32))
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
self.set_bootstrap_info(
|
||||||
|
BootstrapHost::Single(prefill_info.get_hostname()),
|
||||||
|
BootstrapPort::Single(prefill_info.bootstrap_port),
|
||||||
|
BootstrapRoom::Single({
|
||||||
|
// Use high-quality random number for single requests too
|
||||||
|
let r1 = rand::random::<u64>();
|
||||||
|
let r2 = rand::random::<u64>();
|
||||||
|
r1.wrapping_add(r2.rotate_left(32))
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request types
|
||||||
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
pub struct GenerateReqInput {
|
||||||
|
pub text: Option<InputText>,
|
||||||
|
pub input_ids: Option<InputIds>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub stream: bool,
|
||||||
|
pub bootstrap_host: Option<BootstrapHost>,
|
||||||
|
pub bootstrap_port: Option<BootstrapPort>,
|
||||||
|
pub bootstrap_room: Option<BootstrapRoom>,
|
||||||
|
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub other: Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GenerateReqInput {
|
||||||
|
pub fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||||
|
if self.text.is_some() && self.input_ids.is_some() {
|
||||||
|
return Err("Both text and input_ids are present in the request".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check text batch
|
||||||
|
if let Some(InputText::Batch(texts)) = &self.text {
|
||||||
|
if texts.is_empty() {
|
||||||
|
return Err("Batch text array is empty".to_string());
|
||||||
|
}
|
||||||
|
if texts.len() > 10000 {
|
||||||
|
// Reasonable limit for production
|
||||||
|
return Err(format!(
|
||||||
|
"Batch size {} exceeds maximum allowed (10000)",
|
||||||
|
texts.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
return Ok(Some(texts.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check input_ids batch
|
||||||
|
if let Some(InputIds::Batch(ids)) = &self.input_ids {
|
||||||
|
if ids.is_empty() {
|
||||||
|
return Err("Batch input_ids array is empty".to_string());
|
||||||
|
}
|
||||||
|
if ids.len() > 10000 {
|
||||||
|
// Reasonable limit for production
|
||||||
|
return Err(format!(
|
||||||
|
"Batch size {} exceeds maximum allowed (10000)",
|
||||||
|
ids.len()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
// Validate each sequence is not empty
|
||||||
|
for (i, seq) in ids.iter().enumerate() {
|
||||||
|
if seq.is_empty() {
|
||||||
|
return Err(format!("Input sequence at index {} is empty", i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Ok(Some(ids.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Bootstrap for GenerateReqInput {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||||
|
self.get_batch_size()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
) {
|
||||||
|
self.bootstrap_host = Some(bootstrap_host);
|
||||||
|
self.bootstrap_port = Some(bootstrap_port);
|
||||||
|
self.bootstrap_room = Some(bootstrap_room);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Chat-completion request wrapper carrying PD bootstrap metadata.
#[derive(Debug, Deserialize, Serialize)]
pub struct ChatReqInput {
    /// Whether the response should be streamed; defaults to `false`.
    #[serde(default)]
    pub stream: bool,
    /// Bootstrap fields are populated by `Bootstrap::set_bootstrap_info`
    /// after a prefill engine is chosen; absent on the wire input.
    pub bootstrap_host: Option<BootstrapHost>,
    pub bootstrap_port: Option<BootstrapPort>,
    pub bootstrap_room: Option<BootstrapRoom>,

    /// All remaining OpenAI chat fields pass through untouched.
    #[serde(flatten)]
    pub other: Value,
}
|
||||||
|
|
||||||
|
impl Bootstrap for ChatReqInput {
|
||||||
|
fn is_stream(&self) -> bool {
|
||||||
|
self.stream
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_batch_size(&self) -> Result<Option<usize>, String> {
|
||||||
|
// Check if 'n' parameter is present and > 1
|
||||||
|
if let Some(n_value) = self.other.get("n") {
|
||||||
|
if let Some(n) = n_value.as_u64() {
|
||||||
|
if n > 1 {
|
||||||
|
return Ok(Some(n as usize));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_bootstrap_info(
|
||||||
|
&mut self,
|
||||||
|
bootstrap_host: BootstrapHost,
|
||||||
|
bootstrap_port: BootstrapPort,
|
||||||
|
bootstrap_room: BootstrapRoom,
|
||||||
|
) {
|
||||||
|
self.bootstrap_host = Some(bootstrap_host);
|
||||||
|
self.bootstrap_port = Some(bootstrap_port);
|
||||||
|
self.bootstrap_room = Some(bootstrap_room);
|
||||||
|
}
|
||||||
|
}
|
||||||
264
sgl-router/src/request_adapter.rs
Normal file
264
sgl-router/src/request_adapter.rs
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
// Request adapter to bridge OpenAI API types with PD routing requirements
|
||||||
|
|
||||||
|
use crate::openai_api_types::{
|
||||||
|
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray,
|
||||||
|
};
|
||||||
|
use crate::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
|
/// Adapter trait to convert OpenAI API request types into PD-compatible
/// request types that implement [`Bootstrap`].
pub trait ToPdRequest {
    /// The PD request type produced by the conversion.
    type Output: Bootstrap;
    /// Consume the OpenAI-style request and build its PD equivalent.
    fn to_pd_request(self) -> Self::Output;
}
|
||||||
|
|
||||||
|
// Helper macro to insert optional fields into a JSON map.
// Each `$field` is an `Option`; when `Some`, the value is serialized with
// serde_json (falling back to `Value::Null` if serialization fails) and
// stored under `$key`. `None` fields are skipped entirely.
macro_rules! insert_if_some {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            if let Some(value) = $field {
                $map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null));
            }
        )*
    };
}
|
||||||
|
|
||||||
|
// Helper macro for simple value insertions.
// Each `$field` must already convert into `Value` via `Into`, so no serde
// serialization (and thus no failure path) is involved.
macro_rules! insert_value {
    ($map:expr, $($field:expr => $key:expr),* $(,)?) => {
        $(
            $map.insert($key.to_string(), $field.into());
        )*
    };
}
|
||||||
|
|
||||||
|
// ============= Generate Request Adapter =============
|
||||||
|
|
||||||
|
impl ToPdRequest for GenerateRequest {
    type Output = GenerateReqInput;

    /// Convert a native generate request into the PD form.
    ///
    /// Input precedence: `text` (SGLang native) > `prompt` (OpenAI) >
    /// `input_ids`; only the first populated source is used. Commonly-read
    /// knobs (`max_new_tokens`, `temperature`) are lifted to the top level
    /// of `other` so downstream code can find them without digging into the
    /// nested parameter objects.
    fn to_pd_request(self) -> Self::Output {
        // Build the other fields first
        let mut other = serde_json::Map::new();

        // Handle text input - check in priority order: text (SGLang), prompt (OpenAI)
        let (text, input_ids) = if let Some(text_str) = self.text {
            // SGLang native format
            (Some(SingleOrBatch::Single(text_str)), None)
        } else if let Some(prompt) = self.prompt {
            // OpenAI style prompt: a single string or an array maps directly
            // onto the single/batch split.
            let text = match prompt {
                StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
                StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
            };
            (text, None)
        } else if let Some(ids) = self.input_ids {
            // Input IDs case
            let input_ids = match ids {
                crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)),
                crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)),
            };
            (None, input_ids)
        } else {
            // No input provided
            (None, None)
        };

        // Add parameters to other - handle both old and new style
        if let Some(params) = self.parameters {
            // For generate endpoint, extract max_new_tokens to top level if present
            let mut params_value = serde_json::to_value(&params).unwrap_or(Value::Null);
            if let Value::Object(ref mut params_map) = params_value {
                // Move max_new_tokens to top level if it exists
                // (removed from the nested map so it is not sent twice)
                if let Some(max_new_tokens) = params_map.remove("max_new_tokens") {
                    other.insert("max_new_tokens".to_string(), max_new_tokens);
                }
                // Move temperature to top level if it exists
                if let Some(temperature) = params_map.remove("temperature") {
                    other.insert("temperature".to_string(), temperature);
                }
            }
            // Only add parameters if there are remaining fields
            if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty())
            {
                other.insert("parameters".to_string(), params_value);
            }
        }

        // Add sampling_params if present
        if let Some(sampling_params) = self.sampling_params {
            let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null);
            if !params_value.is_null() {
                // Extract commonly used fields to top level.
                // NOTE(review): unlike the `parameters` branch above, these
                // are copied (not removed), so they also remain inside the
                // nested `sampling_params` object.
                if let Value::Object(ref params_map) = params_value {
                    if let Some(max_new_tokens) = params_map.get("max_new_tokens") {
                        other.insert("max_new_tokens".to_string(), max_new_tokens.clone());
                    }
                    if let Some(temperature) = params_map.get("temperature") {
                        other.insert("temperature".to_string(), temperature.clone());
                    }
                }
                other.insert("sampling_params".to_string(), params_value);
            }
        }

        // Add other fields
        insert_value!(other,
            self.stream => "stream",
            self.return_logprob => "return_logprob"
        );

        GenerateReqInput {
            text,
            input_ids,
            stream: self.stream,
            // Bootstrap coordinates are assigned later by the PD router.
            bootstrap_host: None,
            bootstrap_port: None,
            bootstrap_room: None,
            other: Value::Object(other),
        }
    }
}
|
||||||
|
|
||||||
|
// ============= Completion Request Adapter =============
|
||||||
|
|
||||||
|
impl ToPdRequest for CompletionRequest {
|
||||||
|
type Output = GenerateReqInput;
|
||||||
|
|
||||||
|
fn to_pd_request(self) -> Self::Output {
|
||||||
|
// Convert CompletionRequest to GenerateReqInput
|
||||||
|
let text = match self.prompt {
|
||||||
|
StringOrArray::String(s) => Some(SingleOrBatch::Single(s)),
|
||||||
|
StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Map OpenAI parameters to generate parameters
|
||||||
|
let mut other = serde_json::Map::new();
|
||||||
|
|
||||||
|
// Create parameters object
|
||||||
|
let mut params = serde_json::Map::new();
|
||||||
|
|
||||||
|
// Map OpenAI fields to internal parameter names
|
||||||
|
insert_if_some!(params,
|
||||||
|
self.max_tokens => "max_new_tokens",
|
||||||
|
self.temperature => "temperature",
|
||||||
|
self.top_p => "top_p",
|
||||||
|
self.n => "best_of",
|
||||||
|
self.logprobs => "top_n_tokens",
|
||||||
|
self.seed => "seed"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Special handling for fields that need transformation
|
||||||
|
if let Some(presence_penalty) = self.presence_penalty {
|
||||||
|
params.insert(
|
||||||
|
"repetition_penalty".to_string(),
|
||||||
|
(1.0 + presence_penalty).into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(stop) = self.stop {
|
||||||
|
let stop_sequences = match stop {
|
||||||
|
StringOrArray::String(s) => vec![s],
|
||||||
|
StringOrArray::Array(v) => v,
|
||||||
|
};
|
||||||
|
params.insert("stop".to_string(), stop_sequences.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.echo {
|
||||||
|
params.insert("return_full_text".to_string(), true.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
other.insert("parameters".to_string(), Value::Object(params));
|
||||||
|
|
||||||
|
// Store original model and stream flag
|
||||||
|
insert_value!(other,
|
||||||
|
self.model => "model",
|
||||||
|
self.stream => "stream"
|
||||||
|
);
|
||||||
|
|
||||||
|
GenerateReqInput {
|
||||||
|
text,
|
||||||
|
input_ids: None,
|
||||||
|
stream: self.stream,
|
||||||
|
bootstrap_host: None,
|
||||||
|
bootstrap_port: None,
|
||||||
|
bootstrap_room: None,
|
||||||
|
other: Value::Object(other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Chat Completion Request Adapter =============
|
||||||
|
|
||||||
|
impl ToPdRequest for ChatCompletionRequest {
|
||||||
|
type Output = ChatReqInput;
|
||||||
|
|
||||||
|
fn to_pd_request(self) -> Self::Output {
|
||||||
|
let mut other = serde_json::Map::new();
|
||||||
|
|
||||||
|
// Add required fields
|
||||||
|
insert_if_some!(other,
|
||||||
|
Some(&self.messages) => "messages"
|
||||||
|
);
|
||||||
|
|
||||||
|
insert_value!(other,
|
||||||
|
self.model => "model",
|
||||||
|
self.stream => "stream"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Add all optional fields
|
||||||
|
insert_if_some!(other,
|
||||||
|
self.temperature => "temperature",
|
||||||
|
self.top_p => "top_p",
|
||||||
|
self.n => "n",
|
||||||
|
self.stop => "stop",
|
||||||
|
self.max_tokens => "max_tokens",
|
||||||
|
self.max_completion_tokens => "max_completion_tokens",
|
||||||
|
self.presence_penalty => "presence_penalty",
|
||||||
|
self.frequency_penalty => "frequency_penalty",
|
||||||
|
self.logit_bias => "logit_bias",
|
||||||
|
self.user => "user",
|
||||||
|
self.seed => "seed",
|
||||||
|
self.top_logprobs => "top_logprobs",
|
||||||
|
self.response_format => "response_format",
|
||||||
|
self.tools => "tools",
|
||||||
|
self.tool_choice => "tool_choice",
|
||||||
|
self.parallel_tool_calls => "parallel_tool_calls",
|
||||||
|
self.functions => "functions",
|
||||||
|
self.function_call => "function_call"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Handle boolean logprobs flag
|
||||||
|
if self.logprobs {
|
||||||
|
other.insert("logprobs".to_string(), true.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
ChatReqInput {
|
||||||
|
stream: self.stream,
|
||||||
|
bootstrap_host: None,
|
||||||
|
bootstrap_port: None,
|
||||||
|
bootstrap_room: None,
|
||||||
|
other: Value::Object(other),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============= Direct routing support for regular router =============
|
||||||
|
|
||||||
|
/// Extension trait for routing without PD conversion
|
||||||
|
pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone {
|
||||||
|
/// Convert to JSON for sending to backend
|
||||||
|
fn to_json(&self) -> Result<Value, serde_json::Error> {
|
||||||
|
serde_json::to_value(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert to bytes for legacy routing
|
||||||
|
fn to_bytes(&self) -> Result<bytes::Bytes, serde_json::Error> {
|
||||||
|
let json = serde_json::to_vec(self)?;
|
||||||
|
Ok(bytes::Bytes::from(json))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marker impls: these OpenAI request types can be routed as-is using the
// trait's default JSON/bytes serialization (no PD conversion required).
impl RouteableRequest for GenerateRequest {}
impl RouteableRequest for CompletionRequest {}
impl RouteableRequest for ChatCompletionRequest {}
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
|
use crate::pd_router::PDRouter;
|
||||||
|
use crate::pd_types::PDSelectionPolicy;
|
||||||
use crate::tree::Tree;
|
use crate::tree::Tree;
|
||||||
use ::metrics::{counter, gauge, histogram};
|
use ::metrics::{counter, gauge, histogram};
|
||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
use actix_web::{HttpRequest, HttpResponse};
|
||||||
use bytes::Bytes;
|
|
||||||
use futures_util::{StreamExt, TryStreamExt};
|
use futures_util::{StreamExt, TryStreamExt};
|
||||||
use serde_json::Value;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
use std::sync::atomic::AtomicUsize;
|
use std::sync::atomic::AtomicUsize;
|
||||||
@@ -15,7 +15,7 @@ use std::time::Instant;
|
|||||||
use tokio;
|
use tokio;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
pub fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> {
|
||||||
req.headers()
|
req.headers()
|
||||||
.iter()
|
.iter()
|
||||||
.filter_map(|(name, value)| {
|
.filter_map(|(name, value)| {
|
||||||
@@ -40,6 +40,9 @@ pub enum Router {
|
|||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
},
|
},
|
||||||
|
PrefillDecode {
|
||||||
|
pd_router: Arc<PDRouter>,
|
||||||
|
},
|
||||||
CacheAware {
|
CacheAware {
|
||||||
/*
|
/*
|
||||||
Cache-Aware Load Balancing Router
|
Cache-Aware Load Balancing Router
|
||||||
@@ -133,6 +136,13 @@ pub enum PolicyConfig {
|
|||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
},
|
},
|
||||||
|
PrefillDecodeConfig {
|
||||||
|
selection_policy: PDSelectionPolicy,
|
||||||
|
prefill_urls: Vec<(String, Option<u16>)>, // (url, bootstrap_port)
|
||||||
|
decode_urls: Vec<String>,
|
||||||
|
timeout_secs: u64,
|
||||||
|
interval_secs: u64,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
@@ -155,10 +165,24 @@ impl Router {
|
|||||||
interval_secs,
|
interval_secs,
|
||||||
..
|
..
|
||||||
} => (*timeout_secs, *interval_secs),
|
} => (*timeout_secs, *interval_secs),
|
||||||
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
|
timeout_secs,
|
||||||
|
interval_secs,
|
||||||
|
..
|
||||||
|
} => (*timeout_secs, *interval_secs),
|
||||||
};
|
};
|
||||||
|
|
||||||
// Wait until all workers are healthy
|
// For PrefillDecode, we need to handle workers differently
|
||||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
match &policy_config {
|
||||||
|
PolicyConfig::PrefillDecodeConfig { .. } => {
|
||||||
|
// PD mode doesn't use the worker_urls parameter
|
||||||
|
// We'll validate PD workers separately
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Wait until all workers are healthy for regular modes
|
||||||
|
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create router based on policy...
|
// Create router based on policy...
|
||||||
Ok(match policy_config {
|
Ok(match policy_config {
|
||||||
@@ -226,7 +250,7 @@ impl Router {
|
|||||||
});
|
});
|
||||||
|
|
||||||
for url in &worker_urls {
|
for url in &worker_urls {
|
||||||
tree.lock().unwrap().insert(&"".to_string(), url);
|
tree.lock().unwrap().insert("", url);
|
||||||
}
|
}
|
||||||
|
|
||||||
Router::CacheAware {
|
Router::CacheAware {
|
||||||
@@ -242,6 +266,26 @@ impl Router {
|
|||||||
_eviction_thread: Some(eviction_thread),
|
_eviction_thread: Some(eviction_thread),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
|
selection_policy,
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
timeout_secs,
|
||||||
|
interval_secs,
|
||||||
|
} => {
|
||||||
|
// Create PDRouter instance
|
||||||
|
let pd_router = PDRouter::new(
|
||||||
|
prefill_urls,
|
||||||
|
decode_urls,
|
||||||
|
selection_policy,
|
||||||
|
timeout_secs,
|
||||||
|
interval_secs,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Router::PrefillDecode {
|
||||||
|
pd_router: Arc::new(pd_router),
|
||||||
|
}
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -251,16 +295,23 @@ impl Router {
|
|||||||
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
Router::RoundRobin { worker_urls, .. } => Arc::clone(worker_urls),
|
||||||
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
Router::Random { worker_urls, .. } => Arc::clone(worker_urls),
|
||||||
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
Router::CacheAware { worker_urls, .. } => Arc::clone(worker_urls),
|
||||||
|
Router::PrefillDecode { .. } => {
|
||||||
|
// For PD mode, return empty list since we manage workers differently
|
||||||
|
Arc::new(RwLock::new(Vec::new()))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn wait_for_healthy_workers(
|
pub fn wait_for_healthy_workers(
|
||||||
worker_urls: &[String],
|
worker_urls: &[String],
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let sync_client = reqwest::blocking::Client::new();
|
let sync_client = reqwest::blocking::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(timeout_secs))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||||
@@ -323,10 +374,14 @@ impl Router {
|
|||||||
Ok(worker_urls.read().unwrap()[0].clone())
|
Ok(worker_urls.read().unwrap()[0].clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Router::PrefillDecode { .. } => {
|
||||||
|
// For PD mode, we don't need this method as routing is handled by PDRouter
|
||||||
|
Err("PrefillDecode mode doesn't use select_first_worker".to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_request(
|
pub async fn send_request(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
worker_url: &str,
|
worker_url: &str,
|
||||||
@@ -339,7 +394,11 @@ impl Router {
|
|||||||
// Copy all headers from original request except for /health because it does not need authorization
|
// Copy all headers from original request except for /health because it does not need authorization
|
||||||
if route != "/health" {
|
if route != "/health" {
|
||||||
for (name, value) in copy_request_headers(req) {
|
for (name, value) in copy_request_headers(req) {
|
||||||
request_builder = request_builder.header(name, value);
|
// Skip Content-Type and Content-Length as .json() sets them
|
||||||
|
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||||
|
{
|
||||||
|
request_builder = request_builder.header(name, value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -433,50 +492,193 @@ impl Router {
|
|||||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
|
pub async fn route_to_all(
|
||||||
// Convert body to JSON
|
&self,
|
||||||
let json: Value = match serde_json::from_slice(body) {
|
client: &reqwest::Client,
|
||||||
Ok(j) => j,
|
route: &str,
|
||||||
Err(_) => {
|
req: &HttpRequest,
|
||||||
warn!("Failed to parse JSON from request body.");
|
) -> HttpResponse {
|
||||||
return String::new();
|
// Get all worker URLs based on router type
|
||||||
|
let worker_urls = match self {
|
||||||
|
Router::PrefillDecode { .. } => {
|
||||||
|
// For PD mode, route_to_all is not supported directly
|
||||||
|
// It should be handled by PDRouter if needed
|
||||||
|
return HttpResponse::NotImplemented()
|
||||||
|
.body("route_to_all not implemented for PrefillDecode mode");
|
||||||
}
|
}
|
||||||
|
_ => self.get_worker_urls().read().unwrap().clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
match route {
|
// Send requests to all workers concurrently
|
||||||
"/generate" => {
|
let mut tasks = Vec::new();
|
||||||
// For /generate, always use the "text" field.
|
for worker_url in &worker_urls {
|
||||||
match json.get("text").and_then(Value::as_str) {
|
let mut request_builder = client.post(format!("{}{}", worker_url, route));
|
||||||
Some(text) => text.to_string(),
|
|
||||||
None => {
|
// Copy headers from original request
|
||||||
warn!("No 'text' field found in request body for route /generate.");
|
for (name, value) in copy_request_headers(req) {
|
||||||
String::new()
|
request_builder = request_builder.header(name, value);
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
"/v1/chat/completions" | "/v1/completions" => {
|
|
||||||
// For these routes, try "messages", then "prompt", then "text".
|
tasks.push(request_builder.send());
|
||||||
if let Some(messages) = json.get("messages") {
|
}
|
||||||
serde_json::to_string(messages).unwrap_or_default()
|
|
||||||
} else if let Some(prompt) = json.get("prompt").and_then(Value::as_str) {
|
// Wait for all responses
|
||||||
prompt.to_string()
|
let results = futures_util::future::join_all(tasks).await;
|
||||||
} else {
|
|
||||||
warn!("Failed to find 'messages', 'prompt' in request body.");
|
// Check if all succeeded
|
||||||
String::new()
|
let all_success = results.iter().all(|r| {
|
||||||
}
|
r.as_ref()
|
||||||
|
.map(|res| res.status().is_success())
|
||||||
|
.unwrap_or(false)
|
||||||
|
});
|
||||||
|
|
||||||
|
if all_success {
|
||||||
|
HttpResponse::Ok().body("Operation completed on all servers")
|
||||||
|
} else {
|
||||||
|
HttpResponse::InternalServerError().body("Operation failed on one or more servers")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_all_loads(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
_req: &HttpRequest,
|
||||||
|
) -> HttpResponse {
|
||||||
|
// For PD mode, delegate to PDRouter
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
return pd_router.get_loads(client).await;
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unknown route: {} - defaulting to fallback string", route);
|
// For non-PD routers, handle normally
|
||||||
String::new()
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let urls = self.get_worker_urls().read().unwrap().clone();
|
||||||
|
let prefill_urls: Vec<String> = Vec::new();
|
||||||
|
let decode_urls = urls;
|
||||||
|
|
||||||
|
// Collect loads from all servers
|
||||||
|
let mut prefill_loads = Vec::new();
|
||||||
|
let mut decode_loads = Vec::new();
|
||||||
|
|
||||||
|
// Get prefill loads
|
||||||
|
for url in &prefill_urls {
|
||||||
|
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||||
|
prefill_loads.push(serde_json::json!({
|
||||||
|
"engine": format!("(Prefill@{})", url),
|
||||||
|
"load": load as i64
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get decode loads
|
||||||
|
for url in &decode_urls {
|
||||||
|
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||||
|
decode_loads.push(serde_json::json!({
|
||||||
|
"engine": format!("(Decode@{})", url),
|
||||||
|
"load": load as i64
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
HttpResponse::Ok().json(serde_json::json!({
|
||||||
|
"prefill": prefill_loads,
|
||||||
|
"decode": decode_loads
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
// New method to route typed requests directly
|
||||||
|
pub async fn route_typed_request<
|
||||||
|
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||||
|
>(
|
||||||
|
&self,
|
||||||
|
client: &reqwest::Client,
|
||||||
|
req: &HttpRequest,
|
||||||
|
typed_req: &T,
|
||||||
|
route: &str,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { .. } => HttpResponse::InternalServerError()
|
||||||
|
.body("PD routing should use specialized typed handlers"),
|
||||||
|
_ => {
|
||||||
|
// Handle retries like the original implementation
|
||||||
|
let start = Instant::now();
|
||||||
|
const MAX_REQUEST_RETRIES: u32 = 3;
|
||||||
|
const MAX_TOTAL_RETRIES: u32 = 6;
|
||||||
|
let mut total_retries = 0;
|
||||||
|
|
||||||
|
while total_retries < MAX_TOTAL_RETRIES {
|
||||||
|
// Extract routing text directly from typed request
|
||||||
|
let text = typed_req.extract_text_for_routing();
|
||||||
|
let is_stream = typed_req.is_stream();
|
||||||
|
|
||||||
|
// Select worker based on text
|
||||||
|
let worker_url = self.select_generate_worker_from_text(&text);
|
||||||
|
let mut request_retries = 0;
|
||||||
|
|
||||||
|
// Try the same worker multiple times
|
||||||
|
while request_retries < MAX_REQUEST_RETRIES {
|
||||||
|
if total_retries >= 1 {
|
||||||
|
info!("Retrying request after {} failed attempts", total_retries);
|
||||||
|
counter!("sgl_router_retries_total", "route" => route.to_string())
|
||||||
|
.increment(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send typed request directly
|
||||||
|
let response = self
|
||||||
|
.send_typed_request(
|
||||||
|
client,
|
||||||
|
req,
|
||||||
|
typed_req,
|
||||||
|
route,
|
||||||
|
&worker_url,
|
||||||
|
is_stream,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
if response.status().is_success() {
|
||||||
|
let duration = start.elapsed();
|
||||||
|
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||||
|
.record(duration.as_secs_f64());
|
||||||
|
return response;
|
||||||
|
} else {
|
||||||
|
// if the worker is healthy, it means the request is bad, so return the error response
|
||||||
|
let health_response =
|
||||||
|
self.send_request(client, &worker_url, "/health", req).await;
|
||||||
|
if health_response.status().is_success() {
|
||||||
|
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||||
|
.increment(1);
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
"Generate request to {} failed (attempt {}/{})",
|
||||||
|
worker_url,
|
||||||
|
request_retries + 1,
|
||||||
|
MAX_REQUEST_RETRIES
|
||||||
|
);
|
||||||
|
|
||||||
|
request_retries += 1;
|
||||||
|
total_retries += 1;
|
||||||
|
|
||||||
|
if request_retries == MAX_REQUEST_RETRIES {
|
||||||
|
warn!("Removing failed worker: {}", worker_url);
|
||||||
|
self.remove_worker(&worker_url);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
||||||
|
.increment(1);
|
||||||
|
HttpResponse::InternalServerError().body("All retry attempts failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: return Result<String, String> instead of panicking
|
// Helper method to select worker from text
|
||||||
fn select_generate_worker(&self, body: &Bytes, route: &str) -> String {
|
fn select_generate_worker_from_text(&self, text: &str) -> String {
|
||||||
let text = self.get_text_from_request(&body, route);
|
match self {
|
||||||
|
|
||||||
let worker_url = match self {
|
|
||||||
Router::RoundRobin {
|
Router::RoundRobin {
|
||||||
worker_urls,
|
worker_urls,
|
||||||
current_index,
|
current_index,
|
||||||
@@ -506,8 +708,6 @@ impl Router {
|
|||||||
balance_rel_threshold,
|
balance_rel_threshold,
|
||||||
..
|
..
|
||||||
} => {
|
} => {
|
||||||
// TODO: delay scheduling if cache hit rate is high because it may cause imbalance. prioritize low hit rate ones
|
|
||||||
|
|
||||||
let tree = tree.lock().unwrap();
|
let tree = tree.lock().unwrap();
|
||||||
let mut running_queue = running_queue.lock().unwrap();
|
let mut running_queue = running_queue.lock().unwrap();
|
||||||
|
|
||||||
@@ -572,35 +772,48 @@ impl Router {
|
|||||||
|
|
||||||
selected_url
|
selected_url
|
||||||
}
|
}
|
||||||
};
|
Router::PrefillDecode { .. } => {
|
||||||
|
// For PD mode, we don't use this method
|
||||||
worker_url
|
return "PD_MODE_ERROR".to_string();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_generate_request(
|
// Send typed request directly without conversion
|
||||||
|
async fn send_typed_request<T: serde::Serialize>(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
client: &reqwest::Client,
|
||||||
req: &HttpRequest,
|
req: &HttpRequest,
|
||||||
body: &Bytes,
|
typed_req: &T,
|
||||||
route: &str,
|
route: &str,
|
||||||
worker_url: &str,
|
worker_url: &str,
|
||||||
|
is_stream: bool,
|
||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
let is_stream = serde_json::from_slice::<serde_json::Value>(&body)
|
let start = Instant::now();
|
||||||
.map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false))
|
|
||||||
.unwrap_or(false);
|
// Debug: Log what we're sending
|
||||||
|
if let Ok(json_str) = serde_json::to_string_pretty(typed_req) {
|
||||||
|
debug!("Sending request to {}: {}", route, json_str);
|
||||||
|
}
|
||||||
|
|
||||||
let mut request_builder = client
|
let mut request_builder = client
|
||||||
.post(format!("{}{}", worker_url, route))
|
.post(format!("{}{}", worker_url, route))
|
||||||
.body(body.to_vec());
|
.json(typed_req); // Use json() directly with typed request
|
||||||
|
|
||||||
// Copy all headers from original request
|
// Copy all headers from original request
|
||||||
for (name, value) in copy_request_headers(req) {
|
for (name, value) in copy_request_headers(req) {
|
||||||
request_builder = request_builder.header(name, value);
|
// Skip Content-Type and Content-Length as .json() sets them
|
||||||
|
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" {
|
||||||
|
request_builder = request_builder.header(&name, &value);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let res = match request_builder.send().await {
|
let res = match request_builder.send().await {
|
||||||
Ok(res) => res,
|
Ok(res) => res,
|
||||||
Err(_) => return HttpResponse::InternalServerError().finish(),
|
Err(e) => {
|
||||||
|
error!("Failed to send request to {}: {}", worker_url, e);
|
||||||
|
return HttpResponse::InternalServerError().body(format!("Request failed: {}", e));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
let status = actix_web::http::StatusCode::from_u16(res.status().as_u16())
|
||||||
@@ -625,6 +838,12 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Record metrics
|
||||||
|
let duration = start.elapsed();
|
||||||
|
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string())
|
||||||
|
.record(duration.as_secs_f64());
|
||||||
|
counter!("sgl_router_requests_total", "route" => route.to_string()).increment(1);
|
||||||
|
|
||||||
response
|
response
|
||||||
} else if let Router::CacheAware { running_queue, .. } = self {
|
} else if let Router::CacheAware { running_queue, .. } = self {
|
||||||
let running_queue = Arc::clone(running_queue);
|
let running_queue = Arc::clone(running_queue);
|
||||||
@@ -660,70 +879,6 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn route_generate_request(
|
|
||||||
&self,
|
|
||||||
client: &reqwest::Client,
|
|
||||||
req: &HttpRequest,
|
|
||||||
body: &Bytes,
|
|
||||||
route: &str,
|
|
||||||
) -> HttpResponse {
|
|
||||||
let start = Instant::now();
|
|
||||||
const MAX_REQUEST_RETRIES: u32 = 3;
|
|
||||||
const MAX_TOTAL_RETRIES: u32 = 6;
|
|
||||||
let mut total_retries = 0;
|
|
||||||
|
|
||||||
while total_retries < MAX_TOTAL_RETRIES {
|
|
||||||
let worker_url = self.select_generate_worker(body, route);
|
|
||||||
let mut request_retries = 0;
|
|
||||||
|
|
||||||
// Try the same worker multiple times
|
|
||||||
while request_retries < MAX_REQUEST_RETRIES {
|
|
||||||
if total_retries >= 1 {
|
|
||||||
info!("Retrying request after {} failed attempts", total_retries);
|
|
||||||
counter!("sgl_router_retries_total", "route" => route.to_string()).increment(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = self
|
|
||||||
.send_generate_request(client, req, body, route, &worker_url)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
if response.status().is_success() {
|
|
||||||
let duration = start.elapsed();
|
|
||||||
histogram!("sgl_router_generate_duration_seconds", "route" => route.to_string()).record(duration.as_secs_f64());
|
|
||||||
return response;
|
|
||||||
} else {
|
|
||||||
// if the worker is healthy, it means the request is bad, so return the error response
|
|
||||||
let health_response =
|
|
||||||
self.send_request(client, &worker_url, "/health", req).await;
|
|
||||||
if health_response.status().is_success() {
|
|
||||||
counter!("sgl_router_request_errors_total", "route" => route.to_string())
|
|
||||||
.increment(1);
|
|
||||||
return response;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
warn!(
|
|
||||||
"Generate request to {} failed (attempt {}/{})",
|
|
||||||
worker_url,
|
|
||||||
request_retries + 1,
|
|
||||||
MAX_REQUEST_RETRIES
|
|
||||||
);
|
|
||||||
|
|
||||||
request_retries += 1;
|
|
||||||
total_retries += 1;
|
|
||||||
|
|
||||||
if request_retries == MAX_REQUEST_RETRIES {
|
|
||||||
warn!("Removing failed worker: {}", worker_url);
|
|
||||||
self.remove_worker(&worker_url);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
counter!("sgl_router_request_errors_total", "route" => route.to_string()).increment(1);
|
|
||||||
HttpResponse::InternalServerError().body("All retry attempts failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||||
let (timeout_secs, interval_secs) = match self {
|
let (timeout_secs, interval_secs) = match self {
|
||||||
Router::Random {
|
Router::Random {
|
||||||
@@ -741,10 +896,17 @@ impl Router {
|
|||||||
interval_secs,
|
interval_secs,
|
||||||
..
|
..
|
||||||
} => (*timeout_secs, *interval_secs),
|
} => (*timeout_secs, *interval_secs),
|
||||||
|
Router::PrefillDecode { .. } => {
|
||||||
|
// For PD mode, we don't support adding workers via this method
|
||||||
|
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
let client = reqwest::Client::new();
|
let client = reqwest::Client::builder()
|
||||||
|
.timeout(Duration::from_secs(timeout_secs))
|
||||||
|
.build()
|
||||||
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
if start_time.elapsed() > Duration::from_secs(timeout_secs) {
|
||||||
@@ -774,6 +936,9 @@ impl Router {
|
|||||||
urls.push(worker_url.to_string());
|
urls.push(worker_url.to_string());
|
||||||
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
gauge!("sgl_router_active_workers").set(urls.len() as f64);
|
||||||
}
|
}
|
||||||
|
Router::PrefillDecode { .. } => {
|
||||||
|
return Err("Adding workers to PrefillDecode router not supported via add_worker. Use dedicated PD management methods.".to_string());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If cache aware, initialize the queues for the new worker
|
// If cache aware, initialize the queues for the new worker
|
||||||
@@ -797,7 +962,7 @@ impl Router {
|
|||||||
.insert(worker_url.to_string(), 0);
|
.insert(worker_url.to_string(), 0);
|
||||||
|
|
||||||
// Add worker to tree
|
// Add worker to tree
|
||||||
tree.lock().unwrap().insert(&"".to_string(), &worker_url);
|
tree.lock().unwrap().insert("", worker_url);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||||
@@ -850,6 +1015,10 @@ impl Router {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Router::PrefillDecode { .. } => {
|
||||||
|
warn!("Removing workers from PrefillDecode router not supported via remove_worker. Use dedicated PD management methods.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if cache aware, remove the worker from the tree
|
// if cache aware, remove the worker from the tree
|
||||||
@@ -875,4 +1044,133 @@ impl Router {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||||
|
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||||
|
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||||
|
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||||
|
Ok(data) => data
|
||||||
|
.get("load")
|
||||||
|
.and_then(|v| v.as_i64())
|
||||||
|
.map(|v| v as isize),
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to parse load response from {}: {}", worker_url, e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to read load response from {}: {}", worker_url, e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Ok(res) => {
|
||||||
|
debug!(
|
||||||
|
"Worker {} returned non-success status: {}",
|
||||||
|
worker_url,
|
||||||
|
res.status()
|
||||||
|
);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to get load from {}: {}", worker_url, e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PD-specific wrapper methods that delegate to PDRouter
|
||||||
|
pub async fn route_pd_health_generate(
|
||||||
|
&self,
|
||||||
|
_client: &reqwest::Client,
|
||||||
|
_req: &HttpRequest,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router.health_generate(&pd_router.http_client).await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route_pd_generate_typed(
|
||||||
|
&self,
|
||||||
|
_client: &reqwest::Client,
|
||||||
|
req: &HttpRequest,
|
||||||
|
typed_req: crate::pd_types::GenerateReqInput,
|
||||||
|
route: &str,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router
|
||||||
|
.route_generate(&pd_router.http_client, req, typed_req, route)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route_pd_chat_typed(
|
||||||
|
&self,
|
||||||
|
_client: &reqwest::Client,
|
||||||
|
req: &HttpRequest,
|
||||||
|
typed_req: crate::pd_types::ChatReqInput,
|
||||||
|
route: &str,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router
|
||||||
|
.route_chat(&pd_router.http_client, req, typed_req, route)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_pd_server_info(
|
||||||
|
&self,
|
||||||
|
_client: &reqwest::Client,
|
||||||
|
_req: &HttpRequest,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router.get_server_info(&pd_router.http_client).await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_pd_models(
|
||||||
|
&self,
|
||||||
|
_client: &reqwest::Client,
|
||||||
|
req: &HttpRequest,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router.get_models(&pd_router.http_client, req).await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn route_pd_flush_cache(&self, _client: &reqwest::Client) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router.flush_cache(&pd_router.http_client).await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_pd_model_info(
|
||||||
|
&self,
|
||||||
|
_client: &reqwest::Client,
|
||||||
|
req: &HttpRequest,
|
||||||
|
) -> HttpResponse {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => {
|
||||||
|
pd_router.get_model_info(&pd_router.http_client, req).await
|
||||||
|
}
|
||||||
|
_ => HttpResponse::InternalServerError().body("Not in PrefillDecode mode"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,13 @@
|
|||||||
use crate::logging::{self, LoggingConfig};
|
use crate::logging::{self, LoggingConfig};
|
||||||
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
use crate::prometheus::{self, PrometheusConfig};
|
use crate::prometheus::{self, PrometheusConfig};
|
||||||
|
use crate::request_adapter::ToPdRequest;
|
||||||
use crate::router::PolicyConfig;
|
use crate::router::PolicyConfig;
|
||||||
use crate::router::Router;
|
use crate::router::Router;
|
||||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||||
use actix_web::{
|
use actix_web::{
|
||||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||||
};
|
};
|
||||||
use bytes::Bytes;
|
|
||||||
use futures_util::StreamExt;
|
use futures_util::StreamExt;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
|
|||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
router: Arc<Router>,
|
router: Arc<Router>,
|
||||||
client: Client,
|
client: Client,
|
||||||
|
is_pd_mode: bool, // Add flag to track PD mode
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
@@ -28,9 +30,16 @@ impl AppState {
|
|||||||
client: Client,
|
client: Client,
|
||||||
policy_config: PolicyConfig,
|
policy_config: PolicyConfig,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
|
// Check if this is PD mode from policy config
|
||||||
|
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
|
||||||
|
|
||||||
// Create router based on policy
|
// Create router based on policy
|
||||||
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
||||||
Ok(Self { router, client })
|
Ok(Self {
|
||||||
|
router,
|
||||||
|
client,
|
||||||
|
is_pd_mode,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Custom error handler for JSON payload errors.
|
// Custom error handler for JSON payload errors.
|
||||||
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||||
error::ErrorPayloadTooLarge("Payload too large")
|
error!("JSON payload error: {:?}", err);
|
||||||
|
match &err {
|
||||||
|
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
||||||
|
error!(
|
||||||
|
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||||
|
length, limit
|
||||||
|
);
|
||||||
|
error::ErrorPayloadTooLarge(format!(
|
||||||
|
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||||
|
length, limit
|
||||||
|
))
|
||||||
|
}
|
||||||
|
error::JsonPayloadError::Overflow { limit } => {
|
||||||
|
error!("Payload overflow: exceeds limit of {} bytes", limit);
|
||||||
|
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
||||||
|
}
|
||||||
|
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/health")]
|
#[get("/health")]
|
||||||
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
|||||||
|
|
||||||
#[get("/health_generate")]
|
#[get("/health_generate")]
|
||||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
// Check if we're in PD mode
|
||||||
.route_to_first(&data.client, "/health_generate", &req)
|
if data.is_pd_mode {
|
||||||
.await
|
// For PD mode, check health on all servers
|
||||||
|
data.router
|
||||||
|
.route_pd_health_generate(&data.client, &req)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
// Regular mode
|
||||||
|
data.router
|
||||||
|
.route_to_first(&data.client, "/health_generate", &req)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_server_info")]
|
#[get("/get_server_info")]
|
||||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
if data.is_pd_mode {
|
||||||
.route_to_first(&data.client, "/get_server_info", &req)
|
// For PD mode, aggregate info from both prefill and decode servers
|
||||||
.await
|
data.router.get_pd_server_info(&data.client, &req).await
|
||||||
|
} else {
|
||||||
|
// Regular mode - return first server's info
|
||||||
|
data.router
|
||||||
|
.route_to_first(&data.client, "/get_server_info", &req)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/v1/models")]
|
#[get("/v1/models")]
|
||||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
if data.is_pd_mode {
|
||||||
.route_to_first(&data.client, "/v1/models", &req)
|
// For PD mode, return models from the first prefill server
|
||||||
.await
|
data.router.get_pd_models(&data.client, &req).await
|
||||||
|
} else {
|
||||||
|
// Regular mode
|
||||||
|
data.router
|
||||||
|
.route_to_first(&data.client, "/v1/models", &req)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[get("/get_model_info")]
|
#[get("/get_model_info")]
|
||||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
data.router
|
if data.is_pd_mode {
|
||||||
.route_to_first(&data.client, "/get_model_info", &req)
|
// For PD mode, get model info from the first prefill server
|
||||||
.await
|
data.router.get_pd_model_info(&data.client, &req).await
|
||||||
|
} else {
|
||||||
|
data.router
|
||||||
|
.route_to_first(&data.client, "/get_model_info", &req)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/generate")]
|
#[post("/generate")]
|
||||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
async fn generate(
|
||||||
data.router
|
req: HttpRequest,
|
||||||
.route_generate_request(&data.client, &req, &body, "/generate")
|
body: web::Json<GenerateRequest>,
|
||||||
.await
|
state: web::Data<AppState>,
|
||||||
|
) -> Result<HttpResponse, Error> {
|
||||||
|
let client = &state.client;
|
||||||
|
let router = &state.router;
|
||||||
|
|
||||||
|
// Use typed request directly for both PD and regular routing
|
||||||
|
if state.is_pd_mode {
|
||||||
|
// For PD mode, convert to PD request with bootstrap
|
||||||
|
let pd_request = body.into_inner().to_pd_request();
|
||||||
|
|
||||||
|
Ok(router
|
||||||
|
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
|
||||||
|
.await)
|
||||||
|
} else {
|
||||||
|
// For regular mode, use typed request directly
|
||||||
|
let request = body.into_inner();
|
||||||
|
Ok(router
|
||||||
|
.route_typed_request(&client, &req, &request, "/generate")
|
||||||
|
.await)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/chat/completions")]
|
#[post("/v1/chat/completions")]
|
||||||
async fn v1_chat_completions(
|
async fn v1_chat_completions(
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
body: Bytes,
|
body: web::Json<ChatCompletionRequest>,
|
||||||
data: web::Data<AppState>,
|
state: web::Data<AppState>,
|
||||||
) -> impl Responder {
|
) -> Result<HttpResponse, Error> {
|
||||||
data.router
|
let client = &state.client;
|
||||||
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
|
let router = &state.router;
|
||||||
.await
|
|
||||||
|
// Use typed request directly for both PD and regular routing
|
||||||
|
if state.is_pd_mode {
|
||||||
|
// For PD mode, convert to PD request with bootstrap
|
||||||
|
let pd_request = body.into_inner().to_pd_request();
|
||||||
|
|
||||||
|
Ok(router
|
||||||
|
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
|
||||||
|
.await)
|
||||||
|
} else {
|
||||||
|
// For regular mode, use typed request directly
|
||||||
|
let request = body.into_inner();
|
||||||
|
Ok(router
|
||||||
|
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
|
||||||
|
.await)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/v1/completions")]
|
#[post("/v1/completions")]
|
||||||
async fn v1_completions(
|
async fn v1_completions(
|
||||||
req: HttpRequest,
|
req: HttpRequest,
|
||||||
body: Bytes,
|
body: web::Json<CompletionRequest>,
|
||||||
data: web::Data<AppState>,
|
state: web::Data<AppState>,
|
||||||
) -> impl Responder {
|
) -> Result<HttpResponse, Error> {
|
||||||
data.router
|
let client = &state.client;
|
||||||
.route_generate_request(&data.client, &req, &body, "/v1/completions")
|
let router = &state.router;
|
||||||
.await
|
|
||||||
|
// Use typed request directly for both PD and regular routing
|
||||||
|
if state.is_pd_mode {
|
||||||
|
// For PD mode, convert to PD request with bootstrap
|
||||||
|
let pd_request = body.into_inner().to_pd_request();
|
||||||
|
|
||||||
|
Ok(router
|
||||||
|
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
|
||||||
|
.await)
|
||||||
|
} else {
|
||||||
|
// For regular mode, use typed request directly
|
||||||
|
let request = body.into_inner();
|
||||||
|
Ok(router
|
||||||
|
.route_typed_request(&client, &req, &request, "/v1/completions")
|
||||||
|
.await)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[post("/add_worker")]
|
#[post("/add_worker")]
|
||||||
@@ -153,6 +254,25 @@ async fn remove_worker(
|
|||||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[post("/flush_cache")]
|
||||||
|
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
|
if data.is_pd_mode {
|
||||||
|
// For PD mode, flush cache on both prefill and decode servers
|
||||||
|
data.router.route_pd_flush_cache(&data.client).await
|
||||||
|
} else {
|
||||||
|
// Route to all workers for cache flushing
|
||||||
|
data.router
|
||||||
|
.route_to_all(&data.client, "/flush_cache", &req)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[get("/get_loads")]
|
||||||
|
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||||
|
// Get loads from all workers
|
||||||
|
data.router.get_all_loads(&data.client, &req).await
|
||||||
|
}
|
||||||
|
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
pub host: String,
|
pub host: String,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
@@ -163,6 +283,7 @@ pub struct ServerConfig {
|
|||||||
pub log_dir: Option<String>,
|
pub log_dir: Option<String>,
|
||||||
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
||||||
pub prometheus_config: Option<PrometheusConfig>,
|
pub prometheus_config: Option<PrometheusConfig>,
|
||||||
|
pub request_timeout_secs: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||||
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
|
|
||||||
let client = Client::builder()
|
let client = Client::builder()
|
||||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||||
|
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
||||||
.build()
|
.build()
|
||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
|||||||
.service(add_worker)
|
.service(add_worker)
|
||||||
.service(remove_worker)
|
.service(remove_worker)
|
||||||
.service(list_workers)
|
.service(list_workers)
|
||||||
// Default handler for unmatched routes.
|
.service(flush_cache)
|
||||||
|
.service(get_loads)
|
||||||
.default_service(web::route().to(sink_handler))
|
.default_service(web::route().to(sink_handler))
|
||||||
})
|
})
|
||||||
.bind_auto_h2c((config.host, config.port))?
|
.bind_auto_h2c((config.host, config.port))?
|
||||||
|
|||||||
904
sgl-router/tests/test_pd_routing.rs
Normal file
904
sgl-router/tests/test_pd_routing.rs
Normal file
@@ -0,0 +1,904 @@
|
|||||||
|
//! Comprehensive tests for PrefillDecode (PD) routing functionality
|
||||||
|
//!
|
||||||
|
//! This test suite covers:
|
||||||
|
//! - Phase 1: Basic PD router creation and configuration
|
||||||
|
//! - Phase 2: Bootstrap injection and request handling
|
||||||
|
//! - Phase 3: Cache-aware selection (when implemented)
|
||||||
|
//!
|
||||||
|
//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type.
|
||||||
|
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test_pd_routing {
|
||||||
|
use rand::Rng;
|
||||||
|
use serde_json::json;
|
||||||
|
use sglang_router_rs::pd_types::{EngineInfo, EngineType, PDSelectionPolicy};
|
||||||
|
use sglang_router_rs::router::{PolicyConfig, Router};
|
||||||
|
|
||||||
|
// Test-only struct to help validate PD request parsing
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct PDRequest {
|
||||||
|
pub is_stream: bool,
|
||||||
|
pub batch_size: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PDRequest {
|
||||||
|
// Extract PD-relevant info from JSON for testing
|
||||||
|
pub fn from_json(json: &serde_json::Value) -> Self {
|
||||||
|
let is_stream = json
|
||||||
|
.get("stream")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
// Detect batch size from text or input_ids
|
||||||
|
let batch_size = if let Some(text) = json.get("text") {
|
||||||
|
text.as_array().map(|arr| arr.len())
|
||||||
|
} else if let Some(input_ids) = json.get("input_ids") {
|
||||||
|
input_ids.as_array().map(|arr| arr.len())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
PDRequest {
|
||||||
|
is_stream,
|
||||||
|
batch_size,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Phase 1: Basic PD Components and Router Creation
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_info_creation() {
|
||||||
|
// Test EngineInfo creation for prefill servers
|
||||||
|
let prefill_engine = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||||
|
match prefill_engine.engine_type {
|
||||||
|
EngineType::Prefill => (),
|
||||||
|
_ => panic!("Expected Prefill engine type"),
|
||||||
|
}
|
||||||
|
assert_eq!(prefill_engine.url, "http://prefill:8080");
|
||||||
|
assert_eq!(prefill_engine.bootstrap_port, Some(9000));
|
||||||
|
assert_eq!(prefill_engine.get_hostname(), "prefill");
|
||||||
|
|
||||||
|
// Test EngineInfo creation for decode servers
|
||||||
|
let decode_engine = EngineInfo::new_decode("http://decode:8080".to_string());
|
||||||
|
match decode_engine.engine_type {
|
||||||
|
EngineType::Decode => (),
|
||||||
|
_ => panic!("Expected Decode engine type"),
|
||||||
|
}
|
||||||
|
assert_eq!(decode_engine.url, "http://decode:8080");
|
||||||
|
assert_eq!(decode_engine.bootstrap_port, None);
|
||||||
|
assert_eq!(decode_engine.get_hostname(), "decode");
|
||||||
|
|
||||||
|
// Test API path generation
|
||||||
|
assert_eq!(
|
||||||
|
prefill_engine.api_path("/generate"),
|
||||||
|
"http://prefill:8080/generate"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
prefill_engine.api_path("health"),
|
||||||
|
"http://prefill:8080/health"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
decode_engine.api_path("/v1/chat/completions"),
|
||||||
|
"http://decode:8080/v1/chat/completions"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_selection_policies() {
|
||||||
|
// Test all PD selection policy variants
|
||||||
|
// Note: These policies are only used when pd_disaggregated=true
|
||||||
|
let policies = vec![
|
||||||
|
PDSelectionPolicy::Random,
|
||||||
|
PDSelectionPolicy::PowerOfTwo,
|
||||||
|
PDSelectionPolicy::CacheAware {
|
||||||
|
cache_threshold: 0.5,
|
||||||
|
balance_abs_threshold: 32,
|
||||||
|
balance_rel_threshold: 1.1,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
for policy in policies {
|
||||||
|
// Verify each policy can be created and matched
|
||||||
|
match &policy {
|
||||||
|
PDSelectionPolicy::Random => {
|
||||||
|
assert!(matches!(policy, PDSelectionPolicy::Random));
|
||||||
|
}
|
||||||
|
PDSelectionPolicy::PowerOfTwo => {
|
||||||
|
assert!(matches!(policy, PDSelectionPolicy::PowerOfTwo));
|
||||||
|
}
|
||||||
|
PDSelectionPolicy::CacheAware {
|
||||||
|
cache_threshold, ..
|
||||||
|
} => {
|
||||||
|
assert!(*cache_threshold >= 0.0 && *cache_threshold <= 1.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_router_configuration() {
|
||||||
|
// Test PrefillDecodeConfig creation with various policies
|
||||||
|
// This config is used when pd_disaggregated=true
|
||||||
|
let configs = vec![
|
||||||
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
|
selection_policy: PDSelectionPolicy::Random,
|
||||||
|
prefill_urls: vec![
|
||||||
|
("http://prefill1:8080".to_string(), Some(9000)),
|
||||||
|
("http://prefill2:8080".to_string(), None),
|
||||||
|
],
|
||||||
|
decode_urls: vec![
|
||||||
|
"http://decode1:8080".to_string(),
|
||||||
|
"http://decode2:8080".to_string(),
|
||||||
|
],
|
||||||
|
timeout_secs: 10,
|
||||||
|
interval_secs: 1,
|
||||||
|
},
|
||||||
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
|
selection_policy: PDSelectionPolicy::PowerOfTwo,
|
||||||
|
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||||
|
decode_urls: vec!["http://decode:8080".to_string()],
|
||||||
|
timeout_secs: 5,
|
||||||
|
interval_secs: 1,
|
||||||
|
},
|
||||||
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
|
selection_policy: PDSelectionPolicy::CacheAware {
|
||||||
|
cache_threshold: 0.7,
|
||||||
|
balance_abs_threshold: 20,
|
||||||
|
balance_rel_threshold: 1.2,
|
||||||
|
},
|
||||||
|
prefill_urls: vec![
|
||||||
|
("http://p1:8080".to_string(), Some(9000)),
|
||||||
|
("http://p2:8080".to_string(), Some(9001)),
|
||||||
|
("http://p3:8080".to_string(), Some(9002)),
|
||||||
|
],
|
||||||
|
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||||
|
timeout_secs: 10,
|
||||||
|
interval_secs: 2,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
for config in configs {
|
||||||
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
|
let result = Router::new(vec![], config);
|
||||||
|
assert!(result.is_err());
|
||||||
|
let error_msg = result.unwrap_err();
|
||||||
|
// Error should be about health/timeout, not configuration
|
||||||
|
assert!(
|
||||||
|
error_msg.contains("healthy") || error_msg.contains("timeout"),
|
||||||
|
"Unexpected error: {}",
|
||||||
|
error_msg
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Phase 2: Bootstrap Injection and Request Handling
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
    /// Verify that `PDRequest::from_json` extracts `is_stream` and
    /// `batch_size` from the four request shapes the router sees:
    /// single text, batch text, batched input_ids, and chat messages.
    #[test]
    fn test_pd_request_from_json() {
        // Test PDRequest parsing from single text request
        let single_json = json!({
            "text": "Hello world",
            "stream": false,
            "temperature": 0.7,
            "max_tokens": 100
        });

        let pd_req = PDRequest::from_json(&single_json);
        assert!(!pd_req.is_stream);
        assert_eq!(pd_req.batch_size, None); // scalar "text" -> not a batch

        // Test PDRequest parsing from batch text request
        let batch_json = json!({
            "text": ["Hello", "World", "Test"],
            "stream": true,
            "temperature": 0.5
        });

        let pd_req = PDRequest::from_json(&batch_json);
        assert!(pd_req.is_stream);
        assert_eq!(pd_req.batch_size, Some(3)); // batch size == array length

        // Test PDRequest parsing from input_ids request
        let ids_json = json!({
            "input_ids": [[1, 2, 3], [4, 5, 6]],
            "stream": false
        });

        let pd_req = PDRequest::from_json(&ids_json);
        assert!(!pd_req.is_stream);
        assert_eq!(pd_req.batch_size, Some(2)); // one entry per token sequence

        // Test PDRequest parsing from chat request
        // "messages" is a single conversation, not a batch.
        let chat_json = json!({
            "messages": [
                {"role": "system", "content": "You are a helpful assistant"},
                {"role": "user", "content": "Hello"}
            ],
            "stream": true
        });

        let pd_req = PDRequest::from_json(&chat_json);
        assert!(pd_req.is_stream);
        assert_eq!(pd_req.batch_size, None);
    }
|
||||||
|
|
||||||
|
    /// Simulate the bootstrap-field injection performed by the router's
    /// (private) inject_bootstrap_fields helper, for both single and batch
    /// requests, and check the resulting JSON shape.
    #[test]
    fn test_bootstrap_injection_simulation() {
        // Since we can't test the actual inject_bootstrap_fields function here
        // (it's private in the router module), we'll test the expected behavior

        // Simulate bootstrap injection for single request
        let mut single_json = json!({
            "text": "Hello world",
            "stream": false,
            "temperature": 0.7
        });

        // Simulate what inject_bootstrap_fields would do
        let prefill_info = EngineInfo::new_prefill("http://prefill1:8080".to_string(), Some(9000));
        single_json["bootstrap_host"] = json!(prefill_info.get_hostname());
        single_json["bootstrap_port"] = json!(prefill_info.bootstrap_port);
        single_json["bootstrap_room"] = json!(12345u64); // Random room ID

        // Verify bootstrap fields are added correctly
        assert_eq!(single_json["bootstrap_host"], "prefill1");
        assert_eq!(single_json["bootstrap_port"], 9000);
        assert!(single_json["bootstrap_room"].is_u64());
        assert_eq!(single_json["temperature"], 0.7); // Original field preserved

        // Simulate bootstrap injection for batch request:
        // each bootstrap field becomes a per-request array.
        let mut batch_json = json!({
            "text": ["Hello", "World", "Test"],
            "stream": true
        });

        let batch_size = 3;
        batch_json["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
        batch_json["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
        batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);

        // Verify batch bootstrap fields
        assert!(batch_json["bootstrap_host"].is_array());
        assert_eq!(
            batch_json["bootstrap_host"].as_array().unwrap().len(),
            batch_size
        );
        assert!(batch_json["bootstrap_port"].is_array());
        assert!(batch_json["bootstrap_room"].is_array());
        assert_eq!(batch_json["stream"], true); // Original field preserved
    }
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_request_serialization() {
|
||||||
|
// Test that requests can be properly serialized and deserialized
|
||||||
|
let request = json!({
|
||||||
|
"text": "Test prompt",
|
||||||
|
"stream": false,
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 100,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"frequency_penalty": 0.5,
|
||||||
|
"bootstrap_host": "prefill1",
|
||||||
|
"bootstrap_port": 9000,
|
||||||
|
"bootstrap_room": 12345u64
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert to bytes (as would happen in the router)
|
||||||
|
let bytes = serde_json::to_vec(&request).unwrap();
|
||||||
|
|
||||||
|
// Parse back from bytes
|
||||||
|
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||||
|
|
||||||
|
// Verify all fields are preserved
|
||||||
|
assert_eq!(parsed["text"], "Test prompt");
|
||||||
|
assert_eq!(parsed["stream"], false);
|
||||||
|
assert_eq!(parsed["temperature"], 0.7);
|
||||||
|
assert_eq!(parsed["max_tokens"], 100);
|
||||||
|
assert_eq!(parsed["bootstrap_host"], "prefill1");
|
||||||
|
assert_eq!(parsed["bootstrap_port"], 9000);
|
||||||
|
assert_eq!(parsed["bootstrap_room"], 12345);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_engine_info_hostname_extraction() {
|
||||||
|
// Test various URL formats
|
||||||
|
let test_cases = vec![
|
||||||
|
("http://localhost:8080", "localhost"),
|
||||||
|
("http://10.0.0.1:8080", "10.0.0.1"),
|
||||||
|
("https://api.example.com:443", "api.example.com"),
|
||||||
|
("http://prefill-server", "prefill-server"),
|
||||||
|
("http://[::1]:8080", "["), // IPv6 edge case
|
||||||
|
("prefill:8080", "prefill"), // No protocol
|
||||||
|
];
|
||||||
|
|
||||||
|
for (url, expected_hostname) in test_cases {
|
||||||
|
let engine = EngineInfo::new_prefill(url.to_string(), None);
|
||||||
|
assert_eq!(engine.get_hostname(), expected_hostname);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
    /// Edge cases for `PDRequest::from_json`: empty object, stream-only,
    /// empty batch array, and a scalar (non-array) text field.
    #[test]
    fn test_pd_request_edge_cases() {
        // Test empty request
        let empty_json = json!({});
        let pd_req = PDRequest::from_json(&empty_json);
        assert!(!pd_req.is_stream); // stream defaults to false when absent
        assert_eq!(pd_req.batch_size, None);

        // Test request with only stream field
        let stream_only = json!({
            "stream": true
        });
        let pd_req = PDRequest::from_json(&stream_only);
        assert!(pd_req.is_stream);
        assert_eq!(pd_req.batch_size, None);

        // Test request with empty text array
        let empty_batch = json!({
            "text": []
        });
        let pd_req = PDRequest::from_json(&empty_batch);
        // An empty array is still a batch — of size zero.
        assert_eq!(pd_req.batch_size, Some(0));

        // Test request with non-array text (should be None)
        let non_array_text = json!({
            "text": "single string"
        });
        let pd_req = PDRequest::from_json(&non_array_text);
        assert_eq!(pd_req.batch_size, None);
    }
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Phase 2: Background Load Monitoring Tests
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_background_load_monitoring() {
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use tokio::sync::watch;
|
||||||
|
|
||||||
|
// Create a watch channel for testing
|
||||||
|
let (tx, rx) = watch::channel(HashMap::new());
|
||||||
|
|
||||||
|
// Simulate load updates
|
||||||
|
let mut loads = HashMap::new();
|
||||||
|
loads.insert("http://prefill1:8080".to_string(), 10);
|
||||||
|
loads.insert("http://prefill2:8080".to_string(), 20);
|
||||||
|
loads.insert("http://decode1:8080".to_string(), 5);
|
||||||
|
loads.insert("http://decode2:8080".to_string(), 15);
|
||||||
|
|
||||||
|
// Send the loads
|
||||||
|
tx.send(loads.clone()).unwrap();
|
||||||
|
|
||||||
|
// Verify receiver gets the update
|
||||||
|
let received_loads = rx.borrow();
|
||||||
|
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
|
||||||
|
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
|
||||||
|
assert_eq!(received_loads.get("http://decode1:8080"), Some(&5));
|
||||||
|
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_power_of_two_load_selection() {
|
||||||
|
// Test the power-of-two selection logic with different load scenarios
|
||||||
|
|
||||||
|
// Scenario 1: Clear winner for both prefill and decode
|
||||||
|
let _loads = vec![
|
||||||
|
("prefill1", 100),
|
||||||
|
("prefill2", 10), // Should be selected
|
||||||
|
("decode1", 50),
|
||||||
|
("decode2", 5), // Should be selected
|
||||||
|
];
|
||||||
|
|
||||||
|
// In actual implementation, the lower load should be selected
|
||||||
|
assert!(10 < 100);
|
||||||
|
assert!(5 < 50);
|
||||||
|
|
||||||
|
// Scenario 2: Equal loads (should select first)
|
||||||
|
let _equal_loads = vec![
|
||||||
|
("prefill1", 20),
|
||||||
|
("prefill2", 20), // Either could be selected
|
||||||
|
("decode1", 30),
|
||||||
|
("decode2", 30), // Either could be selected
|
||||||
|
];
|
||||||
|
|
||||||
|
// When loads are equal, <= comparison means first is selected
|
||||||
|
assert!(20 <= 20);
|
||||||
|
assert!(30 <= 30);
|
||||||
|
|
||||||
|
// Scenario 3: Missing load data (should default to usize::MAX)
|
||||||
|
// This tests the unwrap_or(usize::MAX) behavior
|
||||||
|
let missing_load = usize::MAX;
|
||||||
|
assert!(10 < missing_load);
|
||||||
|
assert!(missing_load > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
    /// Document which selection policies require the background load
    /// monitor: only PowerOfTwo needs live load data.
    #[test]
    fn test_load_monitoring_configuration() {
        // Test that load monitoring is only enabled for PowerOfTwo policy
        let policies = vec![
            (PDSelectionPolicy::Random, false),
            (PDSelectionPolicy::PowerOfTwo, true),
            (
                PDSelectionPolicy::CacheAware {
                    cache_threshold: 0.5,
                    balance_abs_threshold: 32,
                    balance_rel_threshold: 1.1,
                },
                false,
            ),
        ];

        // The expected flag must agree with the policy variant itself.
        for (policy, should_monitor) in policies {
            match policy {
                PDSelectionPolicy::PowerOfTwo => assert!(should_monitor),
                _ => assert!(!should_monitor),
            }
        }
    }
|
||||||
|
|
||||||
|
    /// Exercise tokio's watch-channel semantics the load monitor relies on:
    /// every receiver observes the latest value, and a new send replaces
    /// (does not append to) the previous state.
    #[tokio::test]
    async fn test_watch_channel_behavior() {
        use std::collections::HashMap;
        use tokio::sync::watch;

        // Test watch channel's broadcast behavior
        let (tx, rx1) = watch::channel(HashMap::new());
        let rx2 = rx1.clone();

        // Initial state - empty map
        assert!(rx1.borrow().is_empty());
        assert!(rx2.borrow().is_empty());

        // Update 1
        let mut loads = HashMap::new();
        loads.insert("worker1".to_string(), 10);
        tx.send(loads.clone()).unwrap();

        // Both receivers see the update
        assert_eq!(rx1.borrow().get("worker1"), Some(&10));
        assert_eq!(rx2.borrow().get("worker1"), Some(&10));

        // Update 2 - overwrites previous
        loads.insert("worker1".to_string(), 20);
        loads.insert("worker2".to_string(), 30);
        tx.send(loads).unwrap();

        // Both receivers see the latest state
        assert_eq!(rx1.borrow().get("worker1"), Some(&20));
        assert_eq!(rx2.borrow().get("worker2"), Some(&30));
    }
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Tests based on bench_one_batch_server.py patterns
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
    /// Request shapes taken from bench_one_batch_server.py: batched
    /// input_ids, logprob-returning requests, and the benchmark batch sizes.
    #[test]
    fn test_generate_request_formats() {
        // Based on bench_one_batch_server.py request patterns

        // Test 1: Batch request with input_ids (most common in benchmarks)
        let batch_request = json!({
            "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
                "ignore_eos": true,
            },
            "return_logprob": false,
            "stream": true
        });

        let pd_req = PDRequest::from_json(&batch_request);
        assert!(pd_req.is_stream);
        assert_eq!(pd_req.batch_size, Some(3));

        // Test 2: Request with return_logprob (critical for PD)
        let logprob_request = json!({
            "input_ids": [[1, 2, 3]],
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 8,
            },
            "return_logprob": true,
            "stream": false
        });

        assert_eq!(logprob_request["return_logprob"], true);
        assert_eq!(logprob_request["stream"], false);

        // Test 3: Large batch sizes from benchmark
        let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
        for bs in batch_sizes {
            let request = json!({
                "input_ids": vec![vec![1, 2, 3]; bs],
                "sampling_params": {
                    "temperature": 0.0,
                    "max_new_tokens": 16,
                },
                "stream": true
            });

            let pd_req = PDRequest::from_json(&request);
            // batch_size must track the number of input_ids sequences.
            assert_eq!(pd_req.batch_size, Some(bs));
        }
    }
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_sampling_params_handling() {
|
||||||
|
// Test various sampling parameters from bench_one_batch_server.py
|
||||||
|
let sampling_params_variations = vec![
|
||||||
|
json!({
|
||||||
|
"temperature": 0.0,
|
||||||
|
"max_new_tokens": 8,
|
||||||
|
"ignore_eos": true
|
||||||
|
}),
|
||||||
|
json!({
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_new_tokens": 16,
|
||||||
|
"ignore_eos": false,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"frequency_penalty": 0.5
|
||||||
|
}),
|
||||||
|
json!({
|
||||||
|
"temperature": 1.0,
|
||||||
|
"max_new_tokens": 64,
|
||||||
|
"json_schema": "$$ANY$$" // Structured output
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
|
||||||
|
for params in sampling_params_variations {
|
||||||
|
let request = json!({
|
||||||
|
"input_ids": [[1, 2, 3]],
|
||||||
|
"sampling_params": params.clone(),
|
||||||
|
"stream": false
|
||||||
|
});
|
||||||
|
|
||||||
|
// Verify params are preserved
|
||||||
|
assert_eq!(request["sampling_params"], params);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_streaming_response_parsing() {
|
||||||
|
// Test SSE format parsing from streaming responses
|
||||||
|
let sse_chunks = vec![
|
||||||
|
"data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
|
||||||
|
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
|
||||||
|
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
|
||||||
|
"data: [DONE]",
|
||||||
|
];
|
||||||
|
|
||||||
|
for chunk in &sse_chunks[..3] {
|
||||||
|
assert!(chunk.starts_with("data: "));
|
||||||
|
let json_str = &chunk[6..]; // Skip "data: "
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(json_str).unwrap();
|
||||||
|
assert!(parsed["meta_info"]["completion_tokens"].is_u64());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test [DONE] detection
|
||||||
|
assert_eq!(sse_chunks[3], "data: [DONE]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ttft_calculation() {
|
||||||
|
// Test Time To First Token calculation pattern
|
||||||
|
let first_token_response = json!({
|
||||||
|
"text": "Hello",
|
||||||
|
"meta_info": {
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"finish_reason": null
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// TTFT is calculated when completion_tokens == 1
|
||||||
|
assert_eq!(first_token_response["meta_info"]["completion_tokens"], 1);
|
||||||
|
assert!(first_token_response["meta_info"]["finish_reason"].is_null());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_throughput_metrics() {
|
||||||
|
// Test throughput calculation patterns from bench_one_batch_server.py
|
||||||
|
let batch_size = 16;
|
||||||
|
let input_len = 1024;
|
||||||
|
let output_len = 16;
|
||||||
|
let ttft = 0.5; // seconds
|
||||||
|
let total_latency = 2.0; // seconds
|
||||||
|
|
||||||
|
// Input throughput = batch_size * input_len / ttft
|
||||||
|
let input_throughput = (batch_size as f64) * (input_len as f64) / ttft;
|
||||||
|
assert!((input_throughput - 32768.0).abs() < 0.01);
|
||||||
|
|
||||||
|
// Output throughput = batch_size * output_len / (latency - ttft)
|
||||||
|
let output_throughput = (batch_size as f64) * (output_len as f64) / (total_latency - ttft);
|
||||||
|
assert!((output_throughput - 170.67).abs() < 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_error_response_handling() {
|
||||||
|
// Test error response format from bench_one_batch_server.py
|
||||||
|
let error_response = json!({
|
||||||
|
"error": "Request has failed. Invalid input format."
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(error_response.get("error").is_some());
|
||||||
|
assert!(error_response["error"].as_str().unwrap().contains("failed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
    /// Structured-output requests carry a json_schema inside
    /// sampling_params; it must survive request construction untouched.
    #[test]
    fn test_structured_output_request() {
        // Test structured output format (json_schema)
        let structured_request = json!({
            "text": "What is the capital of France? Answer in JSON.",
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 64,
                "json_schema": "$$ANY$$"
            },
            "stream": false
        });

        assert_eq!(
            structured_request["sampling_params"]["json_schema"],
            "$$ANY$$"
        );
    }
|
||||||
|
|
||||||
|
    /// Simulate bootstrap injection on a realistic benchmark request
    /// (batch of 16 with logprobs and streaming) and check that the three
    /// bootstrap arrays match the batch size while the original fields
    /// survive.
    #[test]
    fn test_bootstrap_injection_with_benchmark_requests() {
        // Test bootstrap injection with actual benchmark request patterns
        let mut benchmark_request = json!({
            "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 8,
                "ignore_eos": true
            },
            "return_logprob": true,
            "stream": true
        });

        // Simulate bootstrap injection
        let prefill_info = EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));
        let batch_size = 16;

        benchmark_request["bootstrap_host"] = json!(vec![prefill_info.get_hostname(); batch_size]);
        benchmark_request["bootstrap_port"] = json!(vec![prefill_info.bootstrap_port; batch_size]);
        benchmark_request["bootstrap_room"] =
            json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());

        // Verify bootstrap fields match batch size
        assert_eq!(
            benchmark_request["bootstrap_host"]
                .as_array()
                .unwrap()
                .len(),
            batch_size
        );
        assert_eq!(
            benchmark_request["bootstrap_port"]
                .as_array()
                .unwrap()
                .len(),
            batch_size
        );
        assert_eq!(
            benchmark_request["bootstrap_room"]
                .as_array()
                .unwrap()
                .len(),
            batch_size
        );

        // Verify original fields are preserved
        assert_eq!(benchmark_request["return_logprob"], true);
        assert_eq!(benchmark_request["stream"], true);
    }
|
||||||
|
|
||||||
|
    /// Pin the /get_server_info response shape that
    /// bench_one_batch_server.py consumes: an internal_states array plus
    /// per-role prefill/decode worker lists.
    #[test]
    fn test_server_info_response_format() {
        // Test server info format expected by bench_one_batch_server.py
        let server_info = json!({
            "internal_states": [{
                "avg_spec_accept_length": 3.5,
                "last_gen_throughput": 2048.5,
                "load": 16
            }],
            "prefill": [
                {"url": "http://prefill1:8080", "load": 10},
                {"url": "http://prefill2:8080", "load": 20}
            ],
            "decode": [
                {"url": "http://decode1:8080", "load": 5},
                {"url": "http://decode2:8080", "load": 15}
            ]
        });

        // Verify structure matches what benchmark expects
        assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
        assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
        assert!(server_info["prefill"].is_array());
        assert!(server_info["decode"].is_array());
    }
|
||||||
|
|
||||||
|
// ========================================================================
|
||||||
|
// Comprehensive Endpoint Coverage Test
|
||||||
|
// ========================================================================
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pd_endpoints_coverage() {
|
||||||
|
// Document all endpoints from Python mini_lb.py and verify implementation status
|
||||||
|
let implemented_endpoints = vec![
|
||||||
|
("/health", "GET", true),
|
||||||
|
("/health_generate", "GET", true), // Note: Python uses POST, we use GET
|
||||||
|
("/get_server_info", "GET", true),
|
||||||
|
("/v1/models", "GET", true),
|
||||||
|
("/get_model_info", "GET", true),
|
||||||
|
("/generate", "POST", true),
|
||||||
|
("/v1/chat/completions", "POST", true),
|
||||||
|
("/v1/completions", "POST", true),
|
||||||
|
("/flush_cache", "POST", true),
|
||||||
|
("/get_loads", "GET", true),
|
||||||
|
("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
|
||||||
|
];
|
||||||
|
|
||||||
|
let implemented_count = implemented_endpoints
|
||||||
|
.iter()
|
||||||
|
.filter(|(_, _, impl_status)| *impl_status)
|
||||||
|
.count();
|
||||||
|
let total_count = implemented_endpoints.len();
|
||||||
|
|
||||||
|
// We've implemented 10 out of 11 endpoints (register is not needed for Phase 1/2)
|
||||||
|
assert_eq!(implemented_count, 10);
|
||||||
|
assert_eq!(total_count, 11);
|
||||||
|
|
||||||
|
// Document the missing endpoint
|
||||||
|
let missing: Vec<_> = implemented_endpoints
|
||||||
|
.iter()
|
||||||
|
.filter(|(_, _, impl_status)| !impl_status)
|
||||||
|
.map(|(endpoint, method, _)| format!("{} {}", method, endpoint))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
assert_eq!(missing, vec!["POST /register"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
    /// Stress bootstrap injection with benchmark-scale batches (up to 8192)
    /// and assert it completes well under a second — it sits on the request
    /// hot path.
    #[test]
    fn test_large_batch_bootstrap_injection() {
        // Test bootstrap injection performance with very large batches
        // This simulates the bench_one_batch_server.py scenario
        let large_batch_sizes = vec![1024, 4096, 8192];

        for batch_size in large_batch_sizes {
            let start = std::time::Instant::now();

            // Simulate a large batch request
            let mut large_batch_request = json!({
                "input_ids": vec![vec![1, 2, 3, 4]; batch_size],
                "sampling_params": {
                    "temperature": 0.0,
                    "max_new_tokens": 16,
                },
                "stream": true
            });

            // Simulate bootstrap injection
            let prefill_info =
                EngineInfo::new_prefill("http://prefill:8080".to_string(), Some(9000));

            large_batch_request["bootstrap_host"] =
                json!(vec![prefill_info.get_hostname(); batch_size]);
            large_batch_request["bootstrap_port"] =
                json!(vec![prefill_info.bootstrap_port; batch_size]);
            // Each request in the batch gets its own random room ID.
            large_batch_request["bootstrap_room"] = json!((0..batch_size)
                .map(|_| rand::thread_rng().gen::<u64>())
                .collect::<Vec<_>>());

            let elapsed = start.elapsed();

            // Verify bootstrap fields are correctly sized
            assert_eq!(
                large_batch_request["bootstrap_host"]
                    .as_array()
                    .unwrap()
                    .len(),
                batch_size
            );
            assert_eq!(
                large_batch_request["bootstrap_port"]
                    .as_array()
                    .unwrap()
                    .len(),
                batch_size
            );
            assert_eq!(
                large_batch_request["bootstrap_room"]
                    .as_array()
                    .unwrap()
                    .len(),
                batch_size
            );

            // Bootstrap injection should be reasonably fast even for large batches
            println!(
                "Bootstrap injection for batch_size {} took {:?}",
                batch_size, elapsed
            );
            assert!(
                elapsed.as_millis() < 1000,
                "Bootstrap injection took too long for batch size {}",
                batch_size
            );
        }
    }
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_payload_size_calculation() {
|
||||||
|
// Test payload size estimation for bench_one_batch_server.py scenarios
|
||||||
|
let test_cases = vec![
|
||||||
|
(1, 1024, 16), // Small batch
|
||||||
|
(16, 1024, 16), // Medium batch
|
||||||
|
(64, 1024, 16), // Large batch
|
||||||
|
(8192, 4096, 5), // Benchmark scenario
|
||||||
|
];
|
||||||
|
|
||||||
|
for (batch_size, input_len, _output_len) in test_cases {
|
||||||
|
// Estimate payload size (rough calculation)
|
||||||
|
// Each token is ~4 bytes (i32), plus JSON overhead
|
||||||
|
let tokens_size = batch_size * input_len * 4; // 4 bytes per token
|
||||||
|
let json_overhead = batch_size * 100; // ~100 bytes overhead per request
|
||||||
|
let total_size = tokens_size + json_overhead;
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"Batch size: {}, Input len: {}, Estimated payload: {} MB",
|
||||||
|
batch_size,
|
||||||
|
input_len,
|
||||||
|
total_size / (1024 * 1024)
|
||||||
|
);
|
||||||
|
|
||||||
|
// For the benchmark case (8192, 4096), this should be ~134 MB
|
||||||
|
if batch_size == 8192 && input_len == 4096 {
|
||||||
|
assert!(
|
||||||
|
total_size > 100 * 1024 * 1024,
|
||||||
|
"Benchmark payload should be > 100MB"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
total_size < 200 * 1024 * 1024,
|
||||||
|
"Benchmark payload should be < 200MB"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
    /// Document the PolicyType -> PDSelectionPolicy mapping performed in
    /// lib.rs when PD mode is enabled, and confirm each PD variant is
    /// constructible.
    #[test]
    fn test_policy_type_to_pd_selection_policy_mapping() {
        // Document the mapping from PolicyType to PDSelectionPolicy
        // This mapping happens in lib.rs when pd_disaggregated=true

        // PolicyType::Random -> PDSelectionPolicy::Random
        // PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
        // PolicyType::CacheAware -> PDSelectionPolicy::CacheAware { ... }
        // PolicyType::RoundRobin -> ERROR (not supported in PD mode)

        // Test that PDSelectionPolicy doesn't include RoundRobin
        let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
        assert_eq!(
            pd_policy_count, 3,
            "PDSelectionPolicy should have exactly 3 variants"
        );

        // Verify that each PDSelectionPolicy variant can be created
        let _random = PDSelectionPolicy::Random;
        let _po2 = PDSelectionPolicy::PowerOfTwo;
        let _cache_aware = PDSelectionPolicy::CacheAware {
            cache_threshold: 0.5,
            balance_abs_threshold: 32,
            balance_rel_threshold: 1.1,
        };
    }
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user