[misc] Add PD service discovery support in router (#7361)
This commit is contained in:
@@ -35,6 +35,7 @@ metrics = "0.24.2"
|
|||||||
metrics-exporter-prometheus = "0.17.0"
|
metrics-exporter-prometheus = "0.17.0"
|
||||||
# Added for request tracing
|
# Added for request tracing
|
||||||
uuid = { version = "1.10", features = ["v4", "serde"] }
|
uuid = { version = "1.10", features = ["v4", "serde"] }
|
||||||
|
thiserror = "2.0.12"
|
||||||
[profile.release]
|
[profile.release]
|
||||||
lto = "thin"
|
lto = "thin"
|
||||||
codegen-units = 1
|
codegen-units = 1
|
||||||
|
|||||||
@@ -95,38 +95,217 @@ python -m sglang_router.launch_router \
|
|||||||
|
|
||||||
### Kubernetes Service Discovery
|
### Kubernetes Service Discovery
|
||||||
|
|
||||||
SGL Router supports automatic service discovery for worker nodes in Kubernetes environments. When enabled, the router will automatically:
|
SGL Router supports automatic service discovery for worker nodes in Kubernetes environments. This feature works with both regular (single-server) routing and PD (Prefill-Decode) routing modes. When enabled, the router will automatically:
|
||||||
|
|
||||||
- Discover and add worker pods with matching labels
|
- Discover and add worker pods with matching labels
|
||||||
- Remove unhealthy or deleted worker pods
|
- Remove unhealthy or deleted worker pods
|
||||||
- Dynamically adjust the worker pool based on pod health and availability
|
- Dynamically adjust the worker pool based on pod health and availability
|
||||||
|
- For PD mode: distinguish between prefill and decode servers based on labels
|
||||||
|
|
||||||
#### Command Line Usage
|
#### Regular Mode Service Discovery
|
||||||
|
|
||||||
|
For traditional single-server routing:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python -m sglang_router.launch_router \
|
python -m sglang_router.launch_router \
|
||||||
--service-discovery \
|
--service-discovery \
|
||||||
--selector app=sglang-worker role=inference \
|
--selector app=sglang-worker role=inference \
|
||||||
--service-discovery-port 8000 \
|
|
||||||
--service-discovery-namespace default
|
--service-discovery-namespace default
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### PD Mode Service Discovery
|
||||||
|
|
||||||
|
For PD (Prefill-Decode) disaggregated routing, service discovery can automatically discover and classify pods as either prefill or decode servers based on their labels:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--pd-disaggregation \
|
||||||
|
--policy cache_aware \
|
||||||
|
--service-discovery \
|
||||||
|
--prefill-selector app=sglang component=prefill \
|
||||||
|
--decode-selector app=sglang component=decode \
|
||||||
|
--service-discovery-namespace sglang-system
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also specify initial prefill and decode servers and let service discovery add more:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--pd-disaggregation \
|
||||||
|
--policy cache_aware \
|
||||||
|
--prefill http://prefill-1:8000 8001 \
|
||||||
|
--decode http://decode-1:8000 \
|
||||||
|
--service-discovery \
|
||||||
|
--prefill-selector app=sglang component=prefill \
|
||||||
|
--decode-selector app=sglang component=decode \
|
||||||
|
--service-discovery-namespace sglang-system
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Kubernetes Pod Configuration for PD Mode
|
||||||
|
|
||||||
|
When using PD service discovery, your Kubernetes pods need specific labels to be classified as prefill or decode servers:
|
||||||
|
|
||||||
|
**Prefill Server Pod:**
|
||||||
|
```yaml
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Pod
|
||||||
|
metadata:
|
||||||
|
name: sglang-prefill-1
|
||||||
|
labels:
|
||||||
|
app: sglang
|
||||||
|
component: prefill
|
||||||
|
annotations:
|
||||||
|
sglang.ai/bootstrap-port: "9001" # Optional: Bootstrap port for Mooncake prefill coordination
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: sglang
|
||||||
|
image: lmsys/sglang:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8000 # Main API port
|
||||||
|
- containerPort: 9001 # Optional: Bootstrap coordination port
|
||||||
|
# ... rest of configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
**Decode Server Pod:**
|
||||||
|
```yaml
|
||||||
|
apiVersion: v1
|
||||||
|
kind: Pod
|
||||||
|
metadata:
|
||||||
|
name: sglang-decode-1
|
||||||
|
labels:
|
||||||
|
app: sglang
|
||||||
|
component: decode
|
||||||
|
spec:
|
||||||
|
containers:
|
||||||
|
- name: sglang
|
||||||
|
image: lmsys/sglang:latest
|
||||||
|
ports:
|
||||||
|
- containerPort: 8000 # Main API port
|
||||||
|
# ... rest of configuration
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Requirements:**
|
||||||
|
- Prefill pods must have labels matching your `--prefill-selector`
|
||||||
|
- Decode pods must have labels matching your `--decode-selector`
|
||||||
|
- Prefill pods can optionally declare a bootstrap port through the `sglang.ai/bootstrap-port` annotation (the bootstrap port defaults to None when the annotation is not present)
|
||||||
|
|
||||||
#### Service Discovery Arguments
|
#### Service Discovery Arguments
|
||||||
|
|
||||||
|
**General Arguments:**
|
||||||
- `--service-discovery`: Enable Kubernetes service discovery feature
|
- `--service-discovery`: Enable Kubernetes service discovery feature
|
||||||
- `--selector`: One or more label key-value pairs for pod selection (format: key1=value1 key2=value2)
|
- `--service-discovery-port`: Port to use when generating worker URLs (default: 80)
|
||||||
- `--service-discovery-port`: Port to use when generating worker URLs (default: 80)
|
|
||||||
- `--service-discovery-namespace`: Optional. Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)
|
- `--service-discovery-namespace`: Optional. Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)
|
||||||
|
- `--selector`: One or more label key-value pairs for pod selection in regular mode (format: key1=value1 key2=value2)
|
||||||
|
|
||||||
|
**PD Mode Arguments:**
|
||||||
|
- `--pd-disaggregation`: Enable PD (Prefill-Decode) disaggregated mode
|
||||||
|
- `--prefill`: Specify initial prefill server URL and bootstrap port (format: URL BOOTSTRAP_PORT, can be used multiple times)
|
||||||
|
- `--decode`: Specify initial decode server URL (can be used multiple times)
|
||||||
|
- `--prefill-selector`: Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)
|
||||||
|
- `--decode-selector`: Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)
|
||||||
|
- `--policy`: Routing policy (cache_aware, random, power_of_two - note: power_of_two only works in PD mode, and round_robin is only valid in regular mode)
|
||||||
|
|
||||||
|
**Notes:**
|
||||||
|
- The bootstrap port annotation key is fixed to `sglang.ai/bootstrap-port` (used for Mooncake deployments) and is not exposed as a CLI option
|
||||||
|
- Advanced cache tuning parameters use sensible defaults and are not exposed via CLI
|
||||||
|
|
||||||
#### RBAC Requirements
|
#### RBAC Requirements
|
||||||
|
|
||||||
When using service discovery, you must configure proper Kubernetes RBAC permissions:
|
When using service discovery, you must configure proper Kubernetes RBAC permissions:
|
||||||
|
|
||||||
- **If using namespace-scoped discovery** (with `--service-discovery-namespace`):
|
**Namespace-scoped (recommended):**
|
||||||
Set up a ServiceAccount, Role, and RoleBinding
|
```yaml
|
||||||
|
apiVersion: v1
|
||||||
|
kind: ServiceAccount
|
||||||
|
metadata:
|
||||||
|
name: sglang-router
|
||||||
|
namespace: sglang-system
|
||||||
|
---
|
||||||
|
apiVersion: rbac.authorization.k8s.io/v1
|
||||||
|
kind: Role
|
||||||
|
metadata:
|
||||||
|
namespace: sglang-system
|
||||||
|
name: sglang-router
|
||||||
|
rules:
|
||||||
|
- apiGroups: [""]
|
||||||
|
resources: ["pods"]
|
||||||
|
verbs: ["get", "list", "watch"]
|
||||||
|
---
|
||||||
|
apiVersion: rbac.authorization.k8s.io/v1
|
||||||
|
kind: RoleBinding
|
||||||
|
metadata:
|
||||||
|
name: sglang-router
|
||||||
|
namespace: sglang-system
|
||||||
|
subjects:
|
||||||
|
- kind: ServiceAccount
|
||||||
|
name: sglang-router
|
||||||
|
namespace: sglang-system
|
||||||
|
roleRef:
|
||||||
|
kind: Role
|
||||||
|
name: sglang-router
|
||||||
|
apiGroup: rbac.authorization.k8s.io
|
||||||
|
```
|
||||||
|
|
||||||
- **If watching all namespaces** (without specifying namespace):
|
**Cluster-wide (if watching all namespaces):**
|
||||||
Set up a ServiceAccount, ClusterRole, and ClusterRoleBinding with permissions to list/watch pods at the cluster level
|
```yaml
|
||||||
|
apiVersion: v1
|
||||||
|
kind: ServiceAccount
|
||||||
|
metadata:
|
||||||
|
name: sglang-router
|
||||||
|
namespace: sglang-system
|
||||||
|
---
|
||||||
|
apiVersion: rbac.authorization.k8s.io/v1
|
||||||
|
kind: ClusterRole
|
||||||
|
metadata:
|
||||||
|
name: sglang-router
|
||||||
|
rules:
|
||||||
|
- apiGroups: [""]
|
||||||
|
resources: ["pods"]
|
||||||
|
verbs: ["get", "list", "watch"]
|
||||||
|
---
|
||||||
|
apiVersion: rbac.authorization.k8s.io/v1
|
||||||
|
kind: ClusterRoleBinding
|
||||||
|
metadata:
|
||||||
|
name: sglang-router
|
||||||
|
subjects:
|
||||||
|
- kind: ServiceAccount
|
||||||
|
name: sglang-router
|
||||||
|
namespace: sglang-system
|
||||||
|
roleRef:
|
||||||
|
kind: ClusterRole
|
||||||
|
name: sglang-router
|
||||||
|
apiGroup: rbac.authorization.k8s.io
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Complete Example: PD Mode with Service Discovery
|
||||||
|
|
||||||
|
Here's a complete example of running SGLang Router with PD mode and service discovery:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start the router with PD mode and automatic prefill/decode discovery
|
||||||
|
python -m sglang_router.launch_router \
|
||||||
|
--pd-disaggregation \
|
||||||
|
--policy cache_aware \
|
||||||
|
--service-discovery \
|
||||||
|
--prefill-selector app=sglang component=prefill environment=production \
|
||||||
|
--decode-selector app=sglang component=decode environment=production \
|
||||||
|
--service-discovery-namespace production \
|
||||||
|
--host 0.0.0.0 \
|
||||||
|
--port 8080 \
|
||||||
|
--prometheus-host 0.0.0.0 \
|
||||||
|
--prometheus-port 9090
|
||||||
|
```
|
||||||
|
|
||||||
|
This setup will:
|
||||||
|
1. Enable PD (Prefill-Decode) disaggregated routing mode with automatic pod classification
|
||||||
|
2. Watch for pods in the `production` namespace
|
||||||
|
3. Automatically add prefill servers with labels `app=sglang`, `component=prefill`, `environment=production`
|
||||||
|
4. Automatically add decode servers with labels `app=sglang`, `component=decode`, `environment=production`
|
||||||
|
5. Extract bootstrap ports from the `sglang.ai/bootstrap-port` annotation on prefill pods
|
||||||
|
6. Use cache-aware load balancing for optimal performance
|
||||||
|
7. Expose the router API on port 8080 and metrics on port 9090
|
||||||
|
|
||||||
|
**Note:** In PD mode with service discovery, pods MUST match either the prefill or decode selector to be added. Pods that don't match either selector are ignored.
|
||||||
|
|
||||||
### Troubleshooting
|
### Troubleshooting
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ class RouterArgs:
|
|||||||
port: int = 30000
|
port: int = 30000
|
||||||
|
|
||||||
# PD-specific configuration
|
# PD-specific configuration
|
||||||
pd_disaggregated: bool = False # Enable PD disaggregated mode
|
pd_disaggregation: bool = False # Enable PD disaggregated mode
|
||||||
prefill_urls: List[tuple] = dataclasses.field(
|
prefill_urls: List[tuple] = dataclasses.field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
) # List of (url, bootstrap_port)
|
) # List of (url, bootstrap_port)
|
||||||
@@ -55,6 +55,10 @@ class RouterArgs:
|
|||||||
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||||
service_discovery_port: int = 80
|
service_discovery_port: int = 80
|
||||||
service_discovery_namespace: Optional[str] = None
|
service_discovery_namespace: Optional[str] = None
|
||||||
|
# PD service discovery configuration
|
||||||
|
prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||||
|
decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
|
||||||
|
bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"
|
||||||
# Prometheus configuration
|
# Prometheus configuration
|
||||||
prometheus_port: Optional[int] = None
|
prometheus_port: Optional[int] = None
|
||||||
prometheus_host: Optional[str] = None
|
prometheus_host: Optional[str] = None
|
||||||
@@ -108,7 +112,7 @@ class RouterArgs:
|
|||||||
|
|
||||||
# PD-specific arguments
|
# PD-specific arguments
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{prefix}pd-disaggregated",
|
f"--{prefix}pd-disaggregation",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable PD (Prefill-Decode) disaggregated mode",
|
help="Enable PD (Prefill-Decode) disaggregated mode",
|
||||||
)
|
)
|
||||||
@@ -207,6 +211,18 @@ class RouterArgs:
|
|||||||
type=str,
|
type=str,
|
||||||
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}prefill-selector",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}decode-selector",
|
||||||
|
type=str,
|
||||||
|
nargs="+",
|
||||||
|
help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
|
||||||
|
)
|
||||||
# Prometheus configuration
|
# Prometheus configuration
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{prefix}prometheus-port",
|
f"--{prefix}prometheus-port",
|
||||||
@@ -243,7 +259,7 @@ class RouterArgs:
|
|||||||
worker_urls=worker_urls,
|
worker_urls=worker_urls,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
pd_disaggregated=getattr(args, f"{prefix}pd_disaggregated", False),
|
pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
|
||||||
prefill_urls=prefill_urls,
|
prefill_urls=prefill_urls,
|
||||||
decode_urls=decode_urls,
|
decode_urls=decode_urls,
|
||||||
policy=getattr(args, f"{prefix}policy"),
|
policy=getattr(args, f"{prefix}policy"),
|
||||||
@@ -267,6 +283,13 @@ class RouterArgs:
|
|||||||
service_discovery_namespace=getattr(
|
service_discovery_namespace=getattr(
|
||||||
args, f"{prefix}service_discovery_namespace", None
|
args, f"{prefix}service_discovery_namespace", None
|
||||||
),
|
),
|
||||||
|
prefill_selector=cls._parse_selector(
|
||||||
|
getattr(args, f"{prefix}prefill_selector", None)
|
||||||
|
),
|
||||||
|
decode_selector=cls._parse_selector(
|
||||||
|
getattr(args, f"{prefix}decode_selector", None)
|
||||||
|
),
|
||||||
|
bootstrap_port_annotation="sglang.ai/bootstrap-port", # Mooncake-specific annotation
|
||||||
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
|
||||||
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
|
||||||
)
|
)
|
||||||
@@ -355,17 +378,20 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
router_args = args
|
router_args = args
|
||||||
|
|
||||||
# Validate configuration based on mode
|
# Validate configuration based on mode
|
||||||
if router_args.pd_disaggregated:
|
if router_args.pd_disaggregation:
|
||||||
# Validate PD configuration
|
# Validate PD configuration - skip URL requirements if using service discovery
|
||||||
|
if not router_args.service_discovery:
|
||||||
if not router_args.prefill_urls:
|
if not router_args.prefill_urls:
|
||||||
raise ValueError("PD disaggregated mode requires --prefill")
|
raise ValueError("PD disaggregation mode requires --prefill")
|
||||||
if not router_args.decode_urls:
|
if not router_args.decode_urls:
|
||||||
raise ValueError("PD disaggregated mode requires --decode")
|
raise ValueError("PD disaggregation mode requires --decode")
|
||||||
|
|
||||||
# Create router with unified constructor
|
# Create router with unified constructor
|
||||||
router = Router(
|
router = Router(
|
||||||
worker_urls=(
|
worker_urls=(
|
||||||
router_args.worker_urls if not router_args.pd_disaggregated else []
|
[]
|
||||||
|
if router_args.service_discovery or router_args.pd_disaggregation
|
||||||
|
else router_args.worker_urls
|
||||||
),
|
),
|
||||||
host=router_args.host,
|
host=router_args.host,
|
||||||
port=router_args.port,
|
port=router_args.port,
|
||||||
@@ -384,14 +410,16 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
selector=router_args.selector,
|
selector=router_args.selector,
|
||||||
service_discovery_port=router_args.service_discovery_port,
|
service_discovery_port=router_args.service_discovery_port,
|
||||||
service_discovery_namespace=router_args.service_discovery_namespace,
|
service_discovery_namespace=router_args.service_discovery_namespace,
|
||||||
|
prefill_selector=router_args.prefill_selector,
|
||||||
|
decode_selector=router_args.decode_selector,
|
||||||
prometheus_port=router_args.prometheus_port,
|
prometheus_port=router_args.prometheus_port,
|
||||||
prometheus_host=router_args.prometheus_host,
|
prometheus_host=router_args.prometheus_host,
|
||||||
pd_disaggregated=router_args.pd_disaggregated,
|
pd_disaggregation=router_args.pd_disaggregation,
|
||||||
prefill_urls=(
|
prefill_urls=(
|
||||||
router_args.prefill_urls if router_args.pd_disaggregated else None
|
router_args.prefill_urls if router_args.pd_disaggregation else None
|
||||||
),
|
),
|
||||||
decode_urls=(
|
decode_urls=(
|
||||||
router_args.decode_urls if router_args.pd_disaggregated else None
|
router_args.decode_urls if router_args.pd_disaggregation else None
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -425,7 +453,7 @@ Examples:
|
|||||||
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000
|
||||||
|
|
||||||
# PD disaggregated mode
|
# PD disaggregated mode
|
||||||
python -m sglang_router.launch_router --pd-disaggregated \\
|
python -m sglang_router.launch_router --pd-disaggregation \\
|
||||||
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
--prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
|
||||||
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
--decode http://decode1:8001 --decode http://decode2:8001 \\
|
||||||
--policy cache_aware
|
--policy cache_aware
|
||||||
|
|||||||
@@ -41,9 +41,13 @@ class Router:
|
|||||||
worker URLs using this port. Default: 80
|
worker URLs using this port. Default: 80
|
||||||
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
|
service_discovery_namespace: Kubernetes namespace to watch for pods. If not provided,
|
||||||
watches pods across all namespaces (requires cluster-wide permissions). Default: None
|
watches pods across all namespaces (requires cluster-wide permissions). Default: None
|
||||||
|
prefill_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
|
||||||
|
for prefill servers (PD mode only). Default: {}
|
||||||
|
decode_selector: Dictionary mapping of label keys to values for Kubernetes pod selection
|
||||||
|
for decode servers (PD mode only). Default: {}
|
||||||
prometheus_port: Port to expose Prometheus metrics. Default: None
|
prometheus_port: Port to expose Prometheus metrics. Default: None
|
||||||
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
|
prometheus_host: Host address to bind the Prometheus metrics server. Default: None
|
||||||
pd_disaggregated: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
pd_disaggregation: Enable PD (Prefill-Decode) disaggregated mode. Default: False
|
||||||
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
prefill_urls: List of (url, bootstrap_port) tuples for prefill servers (PD mode only)
|
||||||
decode_urls: List of URLs for decode servers (PD mode only)
|
decode_urls: List of URLs for decode servers (PD mode only)
|
||||||
"""
|
"""
|
||||||
@@ -68,14 +72,20 @@ class Router:
|
|||||||
selector: Dict[str, str] = None,
|
selector: Dict[str, str] = None,
|
||||||
service_discovery_port: int = 80,
|
service_discovery_port: int = 80,
|
||||||
service_discovery_namespace: Optional[str] = None,
|
service_discovery_namespace: Optional[str] = None,
|
||||||
|
prefill_selector: Dict[str, str] = None,
|
||||||
|
decode_selector: Dict[str, str] = None,
|
||||||
prometheus_port: Optional[int] = None,
|
prometheus_port: Optional[int] = None,
|
||||||
prometheus_host: Optional[str] = None,
|
prometheus_host: Optional[str] = None,
|
||||||
pd_disaggregated: bool = False,
|
pd_disaggregation: bool = False,
|
||||||
prefill_urls: Optional[List[tuple]] = None,
|
prefill_urls: Optional[List[tuple]] = None,
|
||||||
decode_urls: Optional[List[str]] = None,
|
decode_urls: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
if selector is None:
|
if selector is None:
|
||||||
selector = {}
|
selector = {}
|
||||||
|
if prefill_selector is None:
|
||||||
|
prefill_selector = {}
|
||||||
|
if decode_selector is None:
|
||||||
|
decode_selector = {}
|
||||||
|
|
||||||
self._router = _Router(
|
self._router = _Router(
|
||||||
worker_urls=worker_urls,
|
worker_urls=worker_urls,
|
||||||
@@ -96,9 +106,11 @@ class Router:
|
|||||||
selector=selector,
|
selector=selector,
|
||||||
service_discovery_port=service_discovery_port,
|
service_discovery_port=service_discovery_port,
|
||||||
service_discovery_namespace=service_discovery_namespace,
|
service_discovery_namespace=service_discovery_namespace,
|
||||||
|
prefill_selector=prefill_selector,
|
||||||
|
decode_selector=decode_selector,
|
||||||
prometheus_port=prometheus_port,
|
prometheus_port=prometheus_port,
|
||||||
prometheus_host=prometheus_host,
|
prometheus_host=prometheus_host,
|
||||||
pd_disaggregated=pd_disaggregated,
|
pd_disaggregation=pd_disaggregation,
|
||||||
prefill_urls=prefill_urls,
|
prefill_urls=prefill_urls,
|
||||||
decode_urls=decode_urls,
|
decode_urls=decode_urls,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
prometheus_port=None,
|
prometheus_port=None,
|
||||||
prometheus_host=None,
|
prometheus_host=None,
|
||||||
# PD-specific attributes
|
# PD-specific attributes
|
||||||
pd_disaggregated=False,
|
pd_disaggregation=False,
|
||||||
prefill=None,
|
prefill=None,
|
||||||
decode=None,
|
decode=None,
|
||||||
# Keep worker_urls for regular mode
|
# Keep worker_urls for regular mode
|
||||||
@@ -119,7 +119,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
# Test RouterArgs parsing for PD mode
|
# Test RouterArgs parsing for PD mode
|
||||||
# Simulate the parsed args structure from argparse with action="append"
|
# Simulate the parsed args structure from argparse with action="append"
|
||||||
args = self.create_router_args(
|
args = self.create_router_args(
|
||||||
pd_disaggregated=True,
|
pd_disaggregation=True,
|
||||||
policy="power_of_two", # PowerOfTwo is only valid in PD mode
|
policy="power_of_two", # PowerOfTwo is only valid in PD mode
|
||||||
prefill=[
|
prefill=[
|
||||||
["http://prefill1:8080", "9000"],
|
["http://prefill1:8080", "9000"],
|
||||||
@@ -133,7 +133,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
router_args = RouterArgs.from_cli_args(args)
|
router_args = RouterArgs.from_cli_args(args)
|
||||||
self.assertTrue(router_args.pd_disaggregated)
|
self.assertTrue(router_args.pd_disaggregation)
|
||||||
self.assertEqual(router_args.policy, "power_of_two")
|
self.assertEqual(router_args.policy, "power_of_two")
|
||||||
self.assertEqual(len(router_args.prefill_urls), 2)
|
self.assertEqual(len(router_args.prefill_urls), 2)
|
||||||
self.assertEqual(len(router_args.decode_urls), 2)
|
self.assertEqual(len(router_args.decode_urls), 2)
|
||||||
@@ -147,7 +147,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
# Test Router creation in PD mode
|
# Test Router creation in PD mode
|
||||||
router = Router(
|
router = Router(
|
||||||
worker_urls=[], # Empty for PD mode
|
worker_urls=[], # Empty for PD mode
|
||||||
pd_disaggregated=True,
|
pd_disaggregation=True,
|
||||||
prefill_urls=[
|
prefill_urls=[
|
||||||
("http://prefill1:8080", 9000),
|
("http://prefill1:8080", 9000),
|
||||||
("http://prefill2:8080", None),
|
("http://prefill2:8080", None),
|
||||||
@@ -165,7 +165,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
|
|
||||||
# Test 1: PowerOfTwo is only valid in PD mode
|
# Test 1: PowerOfTwo is only valid in PD mode
|
||||||
args = self.create_router_args(
|
args = self.create_router_args(
|
||||||
pd_disaggregated=False,
|
pd_disaggregation=False,
|
||||||
policy="power_of_two",
|
policy="power_of_two",
|
||||||
worker_urls=["http://localhost:8000"],
|
worker_urls=["http://localhost:8000"],
|
||||||
)
|
)
|
||||||
@@ -180,7 +180,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
|
|
||||||
# Test 2: RoundRobin is not valid in PD mode
|
# Test 2: RoundRobin is not valid in PD mode
|
||||||
args = self.create_router_args(
|
args = self.create_router_args(
|
||||||
pd_disaggregated=True,
|
pd_disaggregation=True,
|
||||||
policy="round_robin",
|
policy="round_robin",
|
||||||
prefill=[["http://prefill1:8080", "9000"]],
|
prefill=[["http://prefill1:8080", "9000"]],
|
||||||
decode=[["http://decode1:8081"]],
|
decode=[["http://decode1:8081"]],
|
||||||
@@ -198,7 +198,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
# Test 3: Valid combinations should not raise errors
|
# Test 3: Valid combinations should not raise errors
|
||||||
# Regular mode with RoundRobin
|
# Regular mode with RoundRobin
|
||||||
args = self.create_router_args(
|
args = self.create_router_args(
|
||||||
pd_disaggregated=False,
|
pd_disaggregation=False,
|
||||||
policy="round_robin",
|
policy="round_robin",
|
||||||
worker_urls=["http://localhost:8000"],
|
worker_urls=["http://localhost:8000"],
|
||||||
)
|
)
|
||||||
@@ -206,7 +206,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
|
|
||||||
# PD mode with PowerOfTwo
|
# PD mode with PowerOfTwo
|
||||||
args = self.create_router_args(
|
args = self.create_router_args(
|
||||||
pd_disaggregated=True,
|
pd_disaggregation=True,
|
||||||
policy="power_of_two",
|
policy="power_of_two",
|
||||||
prefill=[["http://prefill1:8080", "9000"]],
|
prefill=[["http://prefill1:8080", "9000"]],
|
||||||
decode=[["http://decode1:8081"]],
|
decode=[["http://decode1:8081"]],
|
||||||
@@ -214,6 +214,79 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
# This should not raise (though it may fail to connect)
|
# This should not raise (though it may fail to connect)
|
||||||
|
|
||||||
|
def test_pd_service_discovery_args_parsing(self):
|
||||||
|
"""Test PD service discovery CLI argument parsing."""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from sglang_router.launch_router import RouterArgs
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
RouterArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"--pd-disaggregation",
|
||||||
|
"--service-discovery",
|
||||||
|
"--prefill-selector",
|
||||||
|
"app=sglang",
|
||||||
|
"component=prefill",
|
||||||
|
"--decode-selector",
|
||||||
|
"app=sglang",
|
||||||
|
"component=decode",
|
||||||
|
"--service-discovery-port",
|
||||||
|
"8000",
|
||||||
|
"--service-discovery-namespace",
|
||||||
|
"production",
|
||||||
|
"--policy",
|
||||||
|
"cache_aware",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
router_args = RouterArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
self.assertTrue(router_args.pd_disaggregation)
|
||||||
|
self.assertTrue(router_args.service_discovery)
|
||||||
|
self.assertEqual(
|
||||||
|
router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
router_args.decode_selector, {"app": "sglang", "component": "decode"}
|
||||||
|
)
|
||||||
|
self.assertEqual(router_args.service_discovery_port, 8000)
|
||||||
|
self.assertEqual(router_args.service_discovery_namespace, "production")
|
||||||
|
|
||||||
|
def test_regular_service_discovery_args_parsing(self):
|
||||||
|
"""Test regular mode service discovery CLI argument parsing."""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
from sglang_router.launch_router import RouterArgs
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
RouterArgs.add_cli_args(parser)
|
||||||
|
|
||||||
|
args = parser.parse_args(
|
||||||
|
[
|
||||||
|
"--service-discovery",
|
||||||
|
"--selector",
|
||||||
|
"app=sglang-worker",
|
||||||
|
"environment=staging",
|
||||||
|
"--service-discovery-port",
|
||||||
|
"8000",
|
||||||
|
"--policy",
|
||||||
|
"round_robin",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
router_args = RouterArgs.from_cli_args(args)
|
||||||
|
|
||||||
|
self.assertFalse(router_args.pd_disaggregation)
|
||||||
|
self.assertTrue(router_args.service_discovery)
|
||||||
|
self.assertEqual(
|
||||||
|
router_args.selector, {"app": "sglang-worker", "environment": "staging"}
|
||||||
|
)
|
||||||
|
self.assertEqual(router_args.prefill_selector, {})
|
||||||
|
self.assertEqual(router_args.decode_selector, {})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -42,12 +42,16 @@ struct Router {
|
|||||||
selector: HashMap<String, String>,
|
selector: HashMap<String, String>,
|
||||||
service_discovery_port: u16,
|
service_discovery_port: u16,
|
||||||
service_discovery_namespace: Option<String>,
|
service_discovery_namespace: Option<String>,
|
||||||
|
// PD service discovery fields
|
||||||
|
prefill_selector: HashMap<String, String>,
|
||||||
|
decode_selector: HashMap<String, String>,
|
||||||
|
bootstrap_port_annotation: String,
|
||||||
prometheus_port: Option<u16>,
|
prometheus_port: Option<u16>,
|
||||||
prometheus_host: Option<String>,
|
prometheus_host: Option<String>,
|
||||||
request_timeout_secs: u64,
|
request_timeout_secs: u64,
|
||||||
// PD mode flag
|
// PD mode flag
|
||||||
pd_disaggregated: bool,
|
pd_disaggregation: bool,
|
||||||
// PD-specific fields (only used when pd_disaggregated is true)
|
// PD-specific fields (only used when pd_disaggregation is true)
|
||||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
@@ -74,10 +78,13 @@ impl Router {
|
|||||||
selector = HashMap::new(),
|
selector = HashMap::new(),
|
||||||
service_discovery_port = 80,
|
service_discovery_port = 80,
|
||||||
service_discovery_namespace = None,
|
service_discovery_namespace = None,
|
||||||
|
prefill_selector = HashMap::new(),
|
||||||
|
decode_selector = HashMap::new(),
|
||||||
|
bootstrap_port_annotation = String::from("sglang.ai/bootstrap-port"),
|
||||||
prometheus_port = None,
|
prometheus_port = None,
|
||||||
prometheus_host = None,
|
prometheus_host = None,
|
||||||
request_timeout_secs = 600, // Add configurable request timeout
|
request_timeout_secs = 600, // Add configurable request timeout
|
||||||
pd_disaggregated = false, // New flag for PD mode
|
pd_disaggregation = false, // New flag for PD mode
|
||||||
prefill_urls = None,
|
prefill_urls = None,
|
||||||
decode_urls = None
|
decode_urls = None
|
||||||
))]
|
))]
|
||||||
@@ -100,10 +107,13 @@ impl Router {
|
|||||||
selector: HashMap<String, String>,
|
selector: HashMap<String, String>,
|
||||||
service_discovery_port: u16,
|
service_discovery_port: u16,
|
||||||
service_discovery_namespace: Option<String>,
|
service_discovery_namespace: Option<String>,
|
||||||
|
prefill_selector: HashMap<String, String>,
|
||||||
|
decode_selector: HashMap<String, String>,
|
||||||
|
bootstrap_port_annotation: String,
|
||||||
prometheus_port: Option<u16>,
|
prometheus_port: Option<u16>,
|
||||||
prometheus_host: Option<String>,
|
prometheus_host: Option<String>,
|
||||||
request_timeout_secs: u64,
|
request_timeout_secs: u64,
|
||||||
pd_disaggregated: bool,
|
pd_disaggregation: bool,
|
||||||
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
prefill_urls: Option<Vec<(String, Option<u16>)>>,
|
||||||
decode_urls: Option<Vec<String>>,
|
decode_urls: Option<Vec<String>>,
|
||||||
) -> PyResult<Self> {
|
) -> PyResult<Self> {
|
||||||
@@ -126,17 +136,20 @@ impl Router {
|
|||||||
selector,
|
selector,
|
||||||
service_discovery_port,
|
service_discovery_port,
|
||||||
service_discovery_namespace,
|
service_discovery_namespace,
|
||||||
|
prefill_selector,
|
||||||
|
decode_selector,
|
||||||
|
bootstrap_port_annotation,
|
||||||
prometheus_port,
|
prometheus_port,
|
||||||
prometheus_host,
|
prometheus_host,
|
||||||
request_timeout_secs,
|
request_timeout_secs,
|
||||||
pd_disaggregated,
|
pd_disaggregation,
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
decode_urls,
|
decode_urls,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn start(&self) -> PyResult<()> {
|
fn start(&self) -> PyResult<()> {
|
||||||
let policy_config = if self.pd_disaggregated {
|
let policy_config = if self.pd_disaggregation {
|
||||||
// PD mode - map PolicyType to PDSelectionPolicy
|
// PD mode - map PolicyType to PDSelectionPolicy
|
||||||
let pd_selection_policy = match &self.policy {
|
let pd_selection_policy = match &self.policy {
|
||||||
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
PolicyType::Random => pd_types::PDSelectionPolicy::Random,
|
||||||
@@ -207,6 +220,11 @@ impl Router {
|
|||||||
check_interval: std::time::Duration::from_secs(60),
|
check_interval: std::time::Duration::from_secs(60),
|
||||||
port: self.service_discovery_port,
|
port: self.service_discovery_port,
|
||||||
namespace: self.service_discovery_namespace.clone(),
|
namespace: self.service_discovery_namespace.clone(),
|
||||||
|
// PD mode configuration
|
||||||
|
pd_mode: self.pd_disaggregation,
|
||||||
|
prefill_selector: self.prefill_selector.clone(),
|
||||||
|
decode_selector: self.decode_selector.clone(),
|
||||||
|
bootstrap_port_annotation: self.bootstrap_port_annotation.clone(),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
// PD (Prefill-Decode) Router Implementation
|
// PD (Prefill-Decode) Router Implementation
|
||||||
// This module handles routing for disaggregated prefill-decode systems
|
// This module handles routing for disaggregated prefill-decode systems
|
||||||
|
|
||||||
use crate::pd_types::{Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDSelectionPolicy};
|
use crate::pd_types::{
|
||||||
|
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||||
|
};
|
||||||
use crate::tree::Tree;
|
use crate::tree::Tree;
|
||||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||||
use actix_web::{HttpRequest, HttpResponse};
|
use actix_web::{HttpRequest, HttpResponse};
|
||||||
@@ -65,12 +67,145 @@ impl Drop for LoadGuard<'_> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl PDRouter {
|
impl PDRouter {
|
||||||
// TODO: Add methods for dynamic worker management to support /register endpoint:
|
// Dynamic worker management methods for service discovery
|
||||||
// - add_prefill_server(url: String, bootstrap_port: Option<u16>)
|
pub async fn add_prefill_server(
|
||||||
// - add_decode_server(url: String)
|
&self,
|
||||||
// - remove_prefill_server(url: &str)
|
url: String,
|
||||||
// - remove_decode_server(url: &str)
|
bootstrap_port: Option<u16>,
|
||||||
// These methods will be used when service discovery is implemented for PD mode
|
) -> Result<String, PDRouterError> {
|
||||||
|
// Create EngineInfo for the new prefill server
|
||||||
|
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
|
||||||
|
|
||||||
|
// Wait for the new server to be healthy
|
||||||
|
crate::router::Router::wait_for_healthy_workers(
|
||||||
|
&[url.clone()],
|
||||||
|
self.timeout_secs,
|
||||||
|
self.interval_secs,
|
||||||
|
)
|
||||||
|
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||||
|
|
||||||
|
// Add to prefill workers list
|
||||||
|
let mut workers = self
|
||||||
|
.prefill_workers
|
||||||
|
.write()
|
||||||
|
.map_err(|_| PDRouterError::LockError {
|
||||||
|
operation: "prefill_workers write".to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Check if already exists
|
||||||
|
if workers.iter().any(|w| w.url == url) {
|
||||||
|
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||||
|
}
|
||||||
|
|
||||||
|
workers.push(engine_info);
|
||||||
|
|
||||||
|
// Initialize load tracking
|
||||||
|
self.load_tracking
|
||||||
|
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||||
|
|
||||||
|
// Add to cache tree if using cache-aware policy
|
||||||
|
if let Some(ref tree) = self.prefill_tree {
|
||||||
|
tree.lock().unwrap().insert("", &url);
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Added prefill server: {}", url);
|
||||||
|
Ok(format!("Successfully added prefill server: {}", url))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||||
|
// Create EngineInfo for the new decode server
|
||||||
|
let engine_info = EngineInfo::new_decode(url.clone());
|
||||||
|
|
||||||
|
// Wait for the new server to be healthy
|
||||||
|
crate::router::Router::wait_for_healthy_workers(
|
||||||
|
&[url.clone()],
|
||||||
|
self.timeout_secs,
|
||||||
|
self.interval_secs,
|
||||||
|
)
|
||||||
|
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||||
|
|
||||||
|
// Add to decode workers list
|
||||||
|
let mut workers = self
|
||||||
|
.decode_workers
|
||||||
|
.write()
|
||||||
|
.map_err(|_| PDRouterError::LockError {
|
||||||
|
operation: "decode_workers write".to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Check if already exists
|
||||||
|
if workers.iter().any(|w| w.url == url) {
|
||||||
|
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||||
|
}
|
||||||
|
|
||||||
|
workers.push(engine_info);
|
||||||
|
|
||||||
|
// Initialize load tracking
|
||||||
|
self.load_tracking
|
||||||
|
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||||
|
|
||||||
|
info!("Added decode server: {}", url);
|
||||||
|
Ok(format!("Successfully added decode server: {}", url))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_prefill_server(&self, url: &str) -> Result<String, PDRouterError> {
|
||||||
|
let mut workers = self
|
||||||
|
.prefill_workers
|
||||||
|
.write()
|
||||||
|
.map_err(|_| PDRouterError::LockError {
|
||||||
|
operation: "prefill_workers write".to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Find and remove the server
|
||||||
|
let initial_len = workers.len();
|
||||||
|
workers.retain(|w| w.url != url);
|
||||||
|
|
||||||
|
if workers.len() == initial_len {
|
||||||
|
return Err(PDRouterError::WorkerNotFound {
|
||||||
|
url: url.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from load tracking
|
||||||
|
self.load_tracking.remove(url);
|
||||||
|
|
||||||
|
// Remove from cache tree if using cache-aware policy
|
||||||
|
if let Some(ref tree) = self.prefill_tree {
|
||||||
|
// Note: Tree doesn't have a remove method, so we rebuild it
|
||||||
|
let mut tree_guard = tree.lock().unwrap();
|
||||||
|
*tree_guard = Tree::new();
|
||||||
|
for worker in workers.iter() {
|
||||||
|
tree_guard.insert("", &worker.url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Removed prefill server: {}", url);
|
||||||
|
Ok(format!("Successfully removed prefill server: {}", url))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_decode_server(&self, url: &str) -> Result<String, PDRouterError> {
|
||||||
|
let mut workers = self
|
||||||
|
.decode_workers
|
||||||
|
.write()
|
||||||
|
.map_err(|_| PDRouterError::LockError {
|
||||||
|
operation: "decode_workers write".to_string(),
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Find and remove the server
|
||||||
|
let initial_len = workers.len();
|
||||||
|
workers.retain(|w| w.url != url);
|
||||||
|
|
||||||
|
if workers.len() == initial_len {
|
||||||
|
return Err(PDRouterError::WorkerNotFound {
|
||||||
|
url: url.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from load tracking
|
||||||
|
self.load_tracking.remove(url);
|
||||||
|
|
||||||
|
info!("Removed decode server: {}", url);
|
||||||
|
Ok(format!("Successfully removed decode server: {}", url))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
prefill_urls: Vec<(String, Option<u16>)>,
|
prefill_urls: Vec<(String, Option<u16>)>,
|
||||||
|
|||||||
@@ -3,6 +3,31 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
|
// Custom error type for PD router operations
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum PDRouterError {
|
||||||
|
#[error("Worker already exists: {url}")]
|
||||||
|
WorkerAlreadyExists { url: String },
|
||||||
|
|
||||||
|
#[error("Worker not found: {url}")]
|
||||||
|
WorkerNotFound { url: String },
|
||||||
|
|
||||||
|
#[error("Lock acquisition failed: {operation}")]
|
||||||
|
LockError { operation: String },
|
||||||
|
|
||||||
|
#[error("Health check failed for worker: {url}")]
|
||||||
|
HealthCheckFailed { url: String },
|
||||||
|
|
||||||
|
#[error("Invalid worker configuration: {reason}")]
|
||||||
|
InvalidConfiguration { reason: String },
|
||||||
|
|
||||||
|
#[error("Network error: {message}")]
|
||||||
|
NetworkError { message: String },
|
||||||
|
|
||||||
|
#[error("Timeout waiting for worker: {url}")]
|
||||||
|
Timeout { url: String },
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum EngineType {
|
pub enum EngineType {
|
||||||
Prefill,
|
Prefill,
|
||||||
|
|||||||
@@ -1045,6 +1045,55 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Add a worker with PD mode support
|
||||||
|
pub async fn add_pd_worker(
|
||||||
|
&self,
|
||||||
|
worker_url: &str,
|
||||||
|
pod_type: crate::service_discovery::PodType,
|
||||||
|
bootstrap_port: Option<u16>,
|
||||||
|
) -> Result<String, String> {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => match pod_type {
|
||||||
|
crate::service_discovery::PodType::Prefill => pd_router
|
||||||
|
.add_prefill_server(worker_url.to_string(), bootstrap_port)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
crate::service_discovery::PodType::Decode => pd_router
|
||||||
|
.add_decode_server(worker_url.to_string())
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
crate::service_discovery::PodType::Regular => {
|
||||||
|
Err("Regular pod type not supported in PD mode".to_string())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => Err("add_pd_worker only supported in PD mode".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a worker with PD mode support
|
||||||
|
pub async fn remove_pd_worker(
|
||||||
|
&self,
|
||||||
|
worker_url: &str,
|
||||||
|
pod_type: crate::service_discovery::PodType,
|
||||||
|
) -> Result<String, String> {
|
||||||
|
match self {
|
||||||
|
Router::PrefillDecode { pd_router } => match pod_type {
|
||||||
|
crate::service_discovery::PodType::Prefill => pd_router
|
||||||
|
.remove_prefill_server(worker_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
crate::service_discovery::PodType::Decode => pd_router
|
||||||
|
.remove_decode_server(worker_url)
|
||||||
|
.await
|
||||||
|
.map_err(|e| e.to_string()),
|
||||||
|
crate::service_discovery::PodType::Regular => {
|
||||||
|
Err("Regular pod type not supported in PD mode".to_string())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => Err("remove_pd_worker only supported in PD mode".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||||
@@ -1174,3 +1223,108 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::service_discovery::PodType;

    // TODO: the PD-mode counterparts below are not implemented yet.
    // PDRouter::new requires health checks which fail in tests, so they need
    // a mock server or a different test setup:
    //   - test_router_get_worker_urls_pd_mode (for PD mode, get_worker_urls
    //     returns empty list)
    //   - test_add_pd_worker_with_pd_router_regular_type
    //   - test_remove_pd_worker_with_pd_router_regular_type
    //   - test_select_first_worker_pd_mode

    /// Builds a `Router::Random` seeded with two fixed worker URLs.
    fn create_test_regular_router() -> Router {
        let seed_workers = vec![
            "http://worker1:8080".to_string(),
            "http://worker2:8080".to_string(),
        ];
        Router::Random {
            worker_urls: Arc::new(RwLock::new(seed_workers)),
            timeout_secs: 5,
            interval_secs: 1,
        }
    }

    #[test]
    fn test_router_get_worker_urls_regular() {
        let router = create_test_regular_router();
        let worker_urls = router.get_worker_urls();
        let urls = worker_urls.read().unwrap();

        assert_eq!(urls.len(), 2);
        for expected in ["http://worker1:8080", "http://worker2:8080"] {
            assert!(urls.contains(&expected.to_string()));
        }
    }

    #[tokio::test]
    async fn test_add_pd_worker_with_regular_router() {
        let router = create_test_regular_router();

        let outcome = router
            .add_pd_worker("http://new-worker:8080", PodType::Prefill, Some(8081))
            .await;

        // A non-PD router must refuse PD worker registration.
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("add_pd_worker only supported in PD mode"));
    }

    #[tokio::test]
    async fn test_remove_pd_worker_with_regular_router() {
        let router = create_test_regular_router();

        let outcome = router
            .remove_pd_worker("http://worker:8080", PodType::Decode)
            .await;

        // A non-PD router must refuse PD worker removal.
        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("remove_pd_worker only supported in PD mode"));
    }

    #[test]
    fn test_select_first_worker_regular() {
        let router = create_test_regular_router();
        let chosen = router.select_first_worker();

        assert!(chosen.is_ok());
        let first = chosen.unwrap();
        assert_eq!(first, "http://worker1:8080");
    }

    #[test]
    fn test_wait_for_healthy_workers_empty_list() {
        let outcome = Router::wait_for_healthy_workers(&[], 1, 1);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_wait_for_healthy_workers_invalid_urls() {
        // This test will timeout quickly since the URLs are invalid
        let unreachable = ["http://nonexistent:8080".to_string()];
        let outcome = Router::wait_for_healthy_workers(&unreachable, 1, 1);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("Timeout"));
    }
}
|
||||||
|
|||||||
@@ -24,6 +24,12 @@ pub struct ServiceDiscoveryConfig {
|
|||||||
pub check_interval: Duration,
|
pub check_interval: Duration,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
pub namespace: Option<String>,
|
pub namespace: Option<String>,
|
||||||
|
// PD mode specific configuration
|
||||||
|
pub pd_mode: bool,
|
||||||
|
pub prefill_selector: HashMap<String, String>,
|
||||||
|
pub decode_selector: HashMap<String, String>,
|
||||||
|
// Bootstrap port annotation specific to mooncake implementation
|
||||||
|
pub bootstrap_port_annotation: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ServiceDiscoveryConfig {
|
impl Default for ServiceDiscoveryConfig {
|
||||||
@@ -32,12 +38,25 @@ impl Default for ServiceDiscoveryConfig {
|
|||||||
enabled: false,
|
enabled: false,
|
||||||
selector: HashMap::new(),
|
selector: HashMap::new(),
|
||||||
check_interval: Duration::from_secs(60),
|
check_interval: Duration::from_secs(60),
|
||||||
port: 80, // Default port to connect to pods
|
port: 8000, // Standard port for modern services
|
||||||
namespace: None, // None means watch all namespaces
|
namespace: None, // None means watch all namespaces
|
||||||
|
// PD mode defaults
|
||||||
|
pd_mode: false,
|
||||||
|
prefill_selector: HashMap::new(),
|
||||||
|
decode_selector: HashMap::new(),
|
||||||
|
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Role of a discovered pod in PD (prefill-decode) service discovery.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PodType {
    /// Pod whose labels match the configured prefill selector.
    Prefill,
    /// Pod whose labels match the configured decode selector.
    Decode,
    /// Pod that is neither prefill nor decode, or any pod when PD mode is off.
    Regular,
}
|
||||||
|
|
||||||
/// Represents a Kubernetes pod's information used for worker management
|
/// Represents a Kubernetes pod's information used for worker management
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
pub struct PodInfo {
|
pub struct PodInfo {
|
||||||
@@ -45,10 +64,47 @@ pub struct PodInfo {
|
|||||||
pub ip: String,
|
pub ip: String,
|
||||||
pub status: String,
|
pub status: String,
|
||||||
pub is_ready: bool,
|
pub is_ready: bool,
|
||||||
|
pub pod_type: Option<PodType>,
|
||||||
|
pub bootstrap_port: Option<u16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PodInfo {
|
impl PodInfo {
|
||||||
pub fn from_pod(pod: &Pod) -> Option<Self> {
|
/// Check if a pod matches any of the given selectors
|
||||||
|
fn matches_selector(pod: &Pod, selector: &HashMap<String, String>) -> bool {
|
||||||
|
if selector.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
pod.metadata.labels.as_ref().map_or(false, |labels| {
|
||||||
|
selector
|
||||||
|
.iter()
|
||||||
|
.all(|(k, v)| labels.get(k).map_or(false, |label_value| label_value == v))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a pod should be included in service discovery
|
||||||
|
pub fn should_include(pod: &Pod, config: &ServiceDiscoveryConfig) -> bool {
|
||||||
|
if config.pd_mode {
|
||||||
|
// In PD mode, at least one selector must be non-empty
|
||||||
|
if config.prefill_selector.is_empty() && config.decode_selector.is_empty() {
|
||||||
|
warn!("PD mode enabled but both prefill_selector and decode_selector are empty");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// In PD mode, pod must match either prefill or decode selector
|
||||||
|
Self::matches_selector(pod, &config.prefill_selector)
|
||||||
|
|| Self::matches_selector(pod, &config.decode_selector)
|
||||||
|
} else {
|
||||||
|
// In regular mode, pod must match the general selector
|
||||||
|
if config.selector.is_empty() {
|
||||||
|
warn!("Regular mode enabled but selector is empty");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
Self::matches_selector(pod, &config.selector)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Unified PodInfo creation with optional PD configuration
|
||||||
|
pub fn from_pod(pod: &Pod, config: Option<&ServiceDiscoveryConfig>) -> Option<Self> {
|
||||||
let name = pod.metadata.name.clone()?;
|
let name = pod.metadata.name.clone()?;
|
||||||
let status = pod.status.clone()?;
|
let status = pod.status.clone()?;
|
||||||
let pod_ip = status.pod_ip?;
|
let pod_ip = status.pod_ip?;
|
||||||
@@ -63,11 +119,47 @@ impl PodInfo {
|
|||||||
|
|
||||||
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
|
let pod_status = status.phase.unwrap_or_else(|| "Unknown".to_string());
|
||||||
|
|
||||||
|
// Determine pod type based on labels if config is provided and in PD mode
|
||||||
|
let pod_type = if let Some(config) = config {
|
||||||
|
if config.pd_mode {
|
||||||
|
// Use simplified helper methods for cleaner logic
|
||||||
|
if Self::matches_selector(pod, &config.prefill_selector) {
|
||||||
|
Some(PodType::Prefill)
|
||||||
|
} else if Self::matches_selector(pod, &config.decode_selector) {
|
||||||
|
Some(PodType::Decode)
|
||||||
|
} else {
|
||||||
|
Some(PodType::Regular)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Some(PodType::Regular)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// No config provided, default to None (for backwards compatibility)
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
// Extract bootstrap port from annotations for prefill pods
|
||||||
|
let bootstrap_port = if matches!(pod_type, Some(PodType::Prefill)) {
|
||||||
|
if let Some(config) = config {
|
||||||
|
pod.metadata
|
||||||
|
.annotations
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|annotations| annotations.get(&config.bootstrap_port_annotation))
|
||||||
|
.and_then(|port_str| port_str.parse::<u16>().ok())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
Some(PodInfo {
|
Some(PodInfo {
|
||||||
name,
|
name,
|
||||||
ip: pod_ip,
|
ip: pod_ip,
|
||||||
status: pod_status,
|
status: pod_status,
|
||||||
is_ready,
|
is_ready,
|
||||||
|
pod_type,
|
||||||
|
bootstrap_port,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +192,27 @@ pub async fn start_service_discovery(
|
|||||||
// Initialize Kubernetes client
|
// Initialize Kubernetes client
|
||||||
let client = Client::try_default().await?;
|
let client = Client::try_default().await?;
|
||||||
|
|
||||||
// Construct label selector string from map
|
// Log the appropriate selectors based on mode
|
||||||
|
if config.pd_mode {
|
||||||
|
let prefill_selector = config
|
||||||
|
.prefill_selector
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!("{}={}", k, v))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(",");
|
||||||
|
|
||||||
|
let decode_selector = config
|
||||||
|
.decode_selector
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| format!("{}={}", k, v))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(",");
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Starting Kubernetes service discovery in PD mode with prefill_selector: '{}', decode_selector: '{}'",
|
||||||
|
prefill_selector, decode_selector
|
||||||
|
);
|
||||||
|
} else {
|
||||||
let label_selector = config
|
let label_selector = config
|
||||||
.selector
|
.selector
|
||||||
.iter()
|
.iter()
|
||||||
@@ -109,9 +221,10 @@ pub async fn start_service_discovery(
|
|||||||
.join(",");
|
.join(",");
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Starting Kubernetes service discovery with selector: {}",
|
"Starting Kubernetes service discovery with selector: '{}'",
|
||||||
label_selector
|
label_selector
|
||||||
);
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Create the task that will run in the background
|
// Create the task that will run in the background
|
||||||
let handle = task::spawn(async move {
|
let handle = task::spawn(async move {
|
||||||
@@ -127,33 +240,30 @@ pub async fn start_service_discovery(
|
|||||||
|
|
||||||
info!("Kubernetes service discovery initialized successfully");
|
info!("Kubernetes service discovery initialized successfully");
|
||||||
|
|
||||||
// Create an Arc for the selector map
|
// Create Arcs for configuration data
|
||||||
let selector = Arc::new(config.selector);
|
let config_arc = Arc::new(config.clone());
|
||||||
let port = config.port;
|
let port = config.port;
|
||||||
|
|
||||||
|
let mut retry_delay = Duration::from_secs(1);
|
||||||
|
const MAX_RETRY_DELAY: Duration = Duration::from_secs(300); // 5 minutes max
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
// Create a watcher with the proper parameters according to the kube-rs API
|
// Create a watcher with the proper parameters according to the kube-rs API
|
||||||
let watcher_config = Config::default();
|
let watcher_config = Config::default();
|
||||||
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
|
let watcher_stream = watcher(pods.clone(), watcher_config).applied_objects();
|
||||||
|
|
||||||
// Clone Arcs for the closures
|
// Clone Arcs for the closures
|
||||||
let selector_clone = Arc::clone(&selector);
|
let config_clone = Arc::clone(&config_arc);
|
||||||
let tracked_pods_clone = Arc::clone(&tracked_pods);
|
let tracked_pods_clone = Arc::clone(&tracked_pods);
|
||||||
|
|
||||||
// Apply label selector filter separately since we can't do it directly with the watcher anymore
|
// Simplified label selector filter using helper method
|
||||||
let filtered_stream = watcher_stream.filter_map(move |obj_res| {
|
let filtered_stream = watcher_stream.filter_map(move |obj_res| {
|
||||||
let selector_inner = Arc::clone(&selector_clone);
|
let config_inner = Arc::clone(&config_clone);
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
match obj_res {
|
match obj_res {
|
||||||
Ok(pod) => {
|
Ok(pod) => {
|
||||||
// Only process pods matching our label selector
|
if PodInfo::should_include(&pod, &config_inner) {
|
||||||
if pod.metadata.labels.as_ref().map_or(false, |labels| {
|
|
||||||
// Check if the pod has all the labels from our selector
|
|
||||||
selector_inner.iter().all(|(k, v)| {
|
|
||||||
labels.get(k).map_or(false, |label_value| label_value == v)
|
|
||||||
})
|
|
||||||
}) {
|
|
||||||
Some(Ok(pod))
|
Some(Ok(pod))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -167,24 +277,35 @@ pub async fn start_service_discovery(
|
|||||||
// Clone again for the next closure
|
// Clone again for the next closure
|
||||||
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
|
let tracked_pods_clone2 = Arc::clone(&tracked_pods_clone);
|
||||||
let router_clone = Arc::clone(&router);
|
let router_clone = Arc::clone(&router);
|
||||||
|
let config_clone2 = Arc::clone(&config_arc);
|
||||||
|
|
||||||
match filtered_stream
|
match filtered_stream
|
||||||
.try_for_each(move |pod| {
|
.try_for_each(move |pod| {
|
||||||
let tracked_pods_inner = Arc::clone(&tracked_pods_clone2);
|
let tracked_pods_inner = Arc::clone(&tracked_pods_clone2);
|
||||||
let router_inner = Arc::clone(&router_clone);
|
let router_inner = Arc::clone(&router_clone);
|
||||||
|
let config_inner = Arc::clone(&config_clone2);
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
if let Some(pod_info) = PodInfo::from_pod(&pod) {
|
let pod_info = PodInfo::from_pod(&pod, Some(&config_inner));
|
||||||
|
|
||||||
|
if let Some(pod_info) = pod_info {
|
||||||
if pod.metadata.deletion_timestamp.is_some() {
|
if pod.metadata.deletion_timestamp.is_some() {
|
||||||
handle_pod_deletion(
|
handle_pod_deletion(
|
||||||
&pod_info,
|
&pod_info,
|
||||||
tracked_pods_inner,
|
tracked_pods_inner,
|
||||||
router_inner,
|
router_inner,
|
||||||
port,
|
port,
|
||||||
|
config_inner.pd_mode,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
} else {
|
} else {
|
||||||
handle_pod_event(&pod_info, tracked_pods_inner, router_inner, port)
|
handle_pod_event(
|
||||||
|
&pod_info,
|
||||||
|
tracked_pods_inner,
|
||||||
|
router_inner,
|
||||||
|
port,
|
||||||
|
config_inner.pd_mode,
|
||||||
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -193,20 +314,29 @@ pub async fn start_service_discovery(
|
|||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(_) => {}
|
Ok(_) => {
|
||||||
|
// Reset retry delay on success
|
||||||
|
retry_delay = Duration::from_secs(1);
|
||||||
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!("Error in Kubernetes watcher: {}", err);
|
error!("Error in Kubernetes watcher: {}", err);
|
||||||
// Wait a bit before retrying
|
warn!(
|
||||||
time::sleep(Duration::from_secs(5)).await;
|
"Retrying in {} seconds with exponential backoff",
|
||||||
|
retry_delay.as_secs()
|
||||||
|
);
|
||||||
|
time::sleep(retry_delay).await;
|
||||||
|
|
||||||
|
// Exponential backoff: double the delay, capped at MAX_RETRY_DELAY (no jitter is applied)
|
||||||
|
retry_delay = std::cmp::min(retry_delay * 2, MAX_RETRY_DELAY);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the watcher exits for some reason, wait a bit before restarting
|
// If the watcher exits for some reason, wait a bit before restarting
|
||||||
warn!(
|
warn!(
|
||||||
"Kubernetes watcher exited, restarting in {} seconds",
|
"Kubernetes watcher exited, restarting in {} seconds",
|
||||||
config.check_interval.as_secs()
|
config_arc.check_interval.as_secs()
|
||||||
);
|
);
|
||||||
time::sleep(config.check_interval).await;
|
time::sleep(config_arc.check_interval).await;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -218,34 +348,64 @@ async fn handle_pod_event(
|
|||||||
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
||||||
router: Arc<Router>,
|
router: Arc<Router>,
|
||||||
port: u16,
|
port: u16,
|
||||||
|
pd_mode: bool,
|
||||||
) {
|
) {
|
||||||
let worker_url = pod_info.worker_url(port);
|
let worker_url = pod_info.worker_url(port);
|
||||||
|
|
||||||
// Check if pod is already tracked
|
// If pod is healthy, try to add it (with atomic check-and-insert)
|
||||||
let already_tracked = {
|
if pod_info.is_healthy() {
|
||||||
let tracker = tracked_pods.lock().unwrap();
|
// Atomic check-and-insert to prevent race conditions
|
||||||
tracker.contains(pod_info)
|
let should_add = {
|
||||||
|
let mut tracker = match tracked_pods.lock() {
|
||||||
|
Ok(tracker) => tracker,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to acquire tracked_pods lock: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// If pod is healthy and not already tracked, add it
|
if tracker.contains(pod_info) {
|
||||||
if pod_info.is_healthy() {
|
false // Already tracked
|
||||||
if !already_tracked {
|
} else {
|
||||||
info!(
|
// Reserve the spot to prevent other threads from adding the same pod
|
||||||
"Healthy pod found: {}. Adding worker: {}",
|
|
||||||
pod_info.name, worker_url
|
|
||||||
);
|
|
||||||
match router.add_worker(&worker_url).await {
|
|
||||||
Ok(msg) => {
|
|
||||||
info!("Router add_worker: {}", msg);
|
|
||||||
let mut tracker = tracked_pods.lock().unwrap();
|
|
||||||
tracker.insert(pod_info.clone());
|
tracker.insert(pod_info.clone());
|
||||||
|
true
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if should_add {
|
||||||
|
info!(
|
||||||
|
"Healthy pod found: {} (type: {:?}). Adding worker: {}",
|
||||||
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = if pd_mode && pod_info.pod_type.is_some() {
|
||||||
|
// Use PD-aware worker management
|
||||||
|
if let Some(pod_type) = &pod_info.pod_type {
|
||||||
|
router
|
||||||
|
.add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port)
|
||||||
|
.await
|
||||||
|
} else {
|
||||||
|
Err("Pod type is None in PD mode".to_string())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fallback to regular worker management
|
||||||
|
router.add_worker(&worker_url).await
|
||||||
|
};
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(msg) => {
|
||||||
|
info!("Successfully added worker: {}", msg);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to add worker {} to router: {}", worker_url, e);
|
||||||
|
// Remove from tracking since addition failed
|
||||||
|
if let Ok(mut tracker) = tracked_pods.lock() {
|
||||||
|
tracker.remove(pod_info);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(e) => error!("Failed to add worker {} to router: {}", worker_url, e),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else if already_tracked {
|
|
||||||
// If pod was healthy before but not anymore, remove it
|
|
||||||
handle_pod_deletion(pod_info, tracked_pods, router, port).await;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -254,22 +414,47 @@ async fn handle_pod_deletion(
|
|||||||
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
||||||
router: Arc<Router>,
|
router: Arc<Router>,
|
||||||
port: u16,
|
port: u16,
|
||||||
|
pd_mode: bool,
|
||||||
) {
|
) {
|
||||||
let worker_url = pod_info.worker_url(port);
|
let worker_url = pod_info.worker_url(port);
|
||||||
let mut tracked = tracked_pods.lock().unwrap();
|
|
||||||
|
|
||||||
if tracked.remove(pod_info) {
|
let was_tracked = {
|
||||||
|
let mut tracked = match tracked_pods.lock() {
|
||||||
|
Ok(tracked) => tracked,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to acquire tracked_pods lock during deletion: {}", e);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
tracked.remove(pod_info)
|
||||||
|
};
|
||||||
|
|
||||||
|
if was_tracked {
|
||||||
info!(
|
info!(
|
||||||
"Pod deleted: {}. Removing worker: {}",
|
"Pod deleted: {} (type: {:?}). Removing worker: {}",
|
||||||
pod_info.name, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if pd_mode && pod_info.pod_type.is_some() {
|
||||||
|
// Use PD-aware worker removal
|
||||||
|
if let Some(pod_type) = &pod_info.pod_type {
|
||||||
|
if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await {
|
||||||
|
error!(
|
||||||
|
"Failed to remove PD worker {} from router: {}",
|
||||||
|
worker_url, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fallback to regular worker removal
|
||||||
router.remove_worker(&worker_url);
|
router.remove_worker(&worker_url);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// This case might occur if a pod is deleted before it was ever marked healthy and added.
|
// This case might occur if a pod is deleted before it was ever marked healthy and added.
|
||||||
// Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
|
// Or if the event is duplicated. No action needed on the router if it wasn't tracked (and thus not added).
|
||||||
debug!(
|
debug!(
|
||||||
"Pod deletion event for untracked/already removed pod: {}. Worker URL: {}",
|
"Pod deletion event for untracked/already removed pod: {} (type: {:?}). Worker URL: {}",
|
||||||
pod_info.name, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -325,6 +510,41 @@ mod tests {
|
|||||||
pod
|
pod
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to create a Pod with PD-specific labels and annotations
|
||||||
|
fn create_pd_k8s_pod(name: &str, ip: &str, pod_type: &str, bootstrap_port: Option<u16>) -> Pod {
|
||||||
|
let mut labels = std::collections::BTreeMap::new();
|
||||||
|
labels.insert("app".to_string(), "sglang".to_string());
|
||||||
|
labels.insert("component".to_string(), pod_type.to_string());
|
||||||
|
|
||||||
|
let mut annotations = std::collections::BTreeMap::new();
|
||||||
|
if let Some(port) = bootstrap_port {
|
||||||
|
annotations.insert("sglang.ai/bootstrap-port".to_string(), port.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Pod {
|
||||||
|
metadata: ObjectMeta {
|
||||||
|
name: Some(name.to_string()),
|
||||||
|
labels: Some(labels),
|
||||||
|
annotations: Some(annotations),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
spec: Some(PodSpec::default()),
|
||||||
|
status: Some(PodStatus {
|
||||||
|
pod_ip: Some(ip.to_string()),
|
||||||
|
phase: Some("Running".to_string()),
|
||||||
|
conditions: Some(vec![PodCondition {
|
||||||
|
type_: "Ready".to_string(),
|
||||||
|
status: "True".to_string(),
|
||||||
|
last_probe_time: None,
|
||||||
|
last_transition_time: None,
|
||||||
|
message: None,
|
||||||
|
reason: None,
|
||||||
|
}]),
|
||||||
|
..Default::default()
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Helper to create a Router instance for testing event handlers
|
// Helper to create a Router instance for testing event handlers
|
||||||
fn create_test_router() -> Arc<Router> {
|
fn create_test_router() -> Arc<Router> {
|
||||||
let worker_urls = Arc::new(RwLock::new(Vec::new()));
|
let worker_urls = Arc::new(RwLock::new(Vec::new()));
|
||||||
@@ -335,14 +555,80 @@ mod tests {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper to create a PD config for testing
|
||||||
|
fn create_pd_config() -> ServiceDiscoveryConfig {
|
||||||
|
let mut prefill_selector = HashMap::new();
|
||||||
|
prefill_selector.insert("app".to_string(), "sglang".to_string());
|
||||||
|
prefill_selector.insert("component".to_string(), "prefill".to_string());
|
||||||
|
|
||||||
|
let mut decode_selector = HashMap::new();
|
||||||
|
decode_selector.insert("app".to_string(), "sglang".to_string());
|
||||||
|
decode_selector.insert("component".to_string(), "decode".to_string());
|
||||||
|
|
||||||
|
ServiceDiscoveryConfig {
|
||||||
|
enabled: true,
|
||||||
|
selector: HashMap::new(),
|
||||||
|
check_interval: Duration::from_secs(60),
|
||||||
|
port: 8080,
|
||||||
|
namespace: None,
|
||||||
|
pd_mode: true,
|
||||||
|
prefill_selector,
|
||||||
|
decode_selector,
|
||||||
|
bootstrap_port_annotation: "sglang.ai/bootstrap-port".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_should_include() {
|
||||||
|
let config = create_pd_config();
|
||||||
|
|
||||||
|
// Test prefill pod should be included
|
||||||
|
let prefill_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
|
||||||
|
assert!(PodInfo::should_include(&prefill_pod, &config));
|
||||||
|
|
||||||
|
// Test decode pod should be included
|
||||||
|
let decode_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
|
||||||
|
assert!(PodInfo::should_include(&decode_pod, &config));
|
||||||
|
|
||||||
|
// Test unmatched pod should not be included
|
||||||
|
let unmatched_pod = create_pd_k8s_pod("other-pod", "10.0.0.3", "other", None);
|
||||||
|
assert!(!PodInfo::should_include(&unmatched_pod, &config));
|
||||||
|
|
||||||
|
// Test regular mode
|
||||||
|
let mut regular_config = ServiceDiscoveryConfig::default();
|
||||||
|
regular_config
|
||||||
|
.selector
|
||||||
|
.insert("app".to_string(), "sglang".to_string());
|
||||||
|
regular_config.pd_mode = false;
|
||||||
|
|
||||||
|
let regular_pod = create_pd_k8s_pod("worker-pod", "10.0.0.4", "worker", None);
|
||||||
|
assert!(PodInfo::should_include(&regular_pod, &regular_config));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_service_discovery_config_default() {
|
fn test_service_discovery_config_default() {
|
||||||
let config = ServiceDiscoveryConfig::default();
|
let config = ServiceDiscoveryConfig::default();
|
||||||
assert!(!config.enabled);
|
assert!(!config.enabled);
|
||||||
assert!(config.selector.is_empty());
|
assert!(config.selector.is_empty());
|
||||||
assert_eq!(config.check_interval, Duration::from_secs(60));
|
assert_eq!(config.check_interval, Duration::from_secs(60));
|
||||||
assert_eq!(config.port, 80);
|
assert_eq!(config.port, 8000);
|
||||||
assert!(config.namespace.is_none());
|
assert!(config.namespace.is_none());
|
||||||
|
assert!(!config.pd_mode);
|
||||||
|
assert!(config.prefill_selector.is_empty());
|
||||||
|
assert!(config.decode_selector.is_empty());
|
||||||
|
assert_eq!(config.bootstrap_port_annotation, "sglang.ai/bootstrap-port");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_type_enum() {
|
||||||
|
// Test that PodType enum has expected variants
|
||||||
|
let prefill = PodType::Prefill;
|
||||||
|
let decode = PodType::Decode;
|
||||||
|
let regular = PodType::Regular;
|
||||||
|
|
||||||
|
assert_eq!(format!("{:?}", prefill), "Prefill");
|
||||||
|
assert_eq!(format!("{:?}", decode), "Decode");
|
||||||
|
assert_eq!(format!("{:?}", regular), "Regular");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -354,11 +640,85 @@ mod tests {
|
|||||||
Some("True"),
|
Some("True"),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let pod_info = PodInfo::from_pod(&k8s_pod).unwrap();
|
let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap();
|
||||||
assert_eq!(pod_info.name, "test-pod");
|
assert_eq!(pod_info.name, "test-pod");
|
||||||
assert_eq!(pod_info.ip, "10.0.0.1");
|
assert_eq!(pod_info.ip, "10.0.0.1");
|
||||||
assert_eq!(pod_info.status, "Running");
|
assert_eq!(pod_info.status, "Running");
|
||||||
assert!(pod_info.is_ready);
|
assert!(pod_info.is_ready);
|
||||||
|
assert!(pod_info.pod_type.is_none());
|
||||||
|
assert!(pod_info.bootstrap_port.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_from_pod_with_pd_config_prefill() {
|
||||||
|
let k8s_pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", Some(8081));
|
||||||
|
let config = create_pd_config();
|
||||||
|
|
||||||
|
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
||||||
|
assert_eq!(pod_info.name, "prefill-pod");
|
||||||
|
assert_eq!(pod_info.ip, "10.0.0.1");
|
||||||
|
assert_eq!(pod_info.status, "Running");
|
||||||
|
assert!(pod_info.is_ready);
|
||||||
|
assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
|
||||||
|
assert_eq!(pod_info.bootstrap_port, Some(8081));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_from_pod_with_pd_config_decode() {
|
||||||
|
let k8s_pod = create_pd_k8s_pod("decode-pod", "10.0.0.2", "decode", None);
|
||||||
|
let config = create_pd_config();
|
||||||
|
|
||||||
|
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
||||||
|
assert_eq!(pod_info.name, "decode-pod");
|
||||||
|
assert_eq!(pod_info.ip, "10.0.0.2");
|
||||||
|
assert_eq!(pod_info.status, "Running");
|
||||||
|
assert!(pod_info.is_ready);
|
||||||
|
assert_eq!(pod_info.pod_type, Some(PodType::Decode));
|
||||||
|
assert!(pod_info.bootstrap_port.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_from_pod_with_pd_config_regular_mode() {
|
||||||
|
let k8s_pod = create_pd_k8s_pod("regular-pod", "10.0.0.3", "worker", None);
|
||||||
|
let mut config = create_pd_config();
|
||||||
|
config.pd_mode = false; // Set to regular mode
|
||||||
|
|
||||||
|
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
||||||
|
assert_eq!(pod_info.name, "regular-pod");
|
||||||
|
assert_eq!(pod_info.ip, "10.0.0.3");
|
||||||
|
assert_eq!(pod_info.status, "Running");
|
||||||
|
assert!(pod_info.is_ready);
|
||||||
|
assert_eq!(pod_info.pod_type, Some(PodType::Regular));
|
||||||
|
assert!(pod_info.bootstrap_port.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_from_pod_with_pd_config_unmatched_labels() {
|
||||||
|
let k8s_pod = create_pd_k8s_pod("unknown-pod", "10.0.0.4", "unknown", None);
|
||||||
|
let config = create_pd_config();
|
||||||
|
|
||||||
|
let pod_info = PodInfo::from_pod(&k8s_pod, Some(&config)).unwrap();
|
||||||
|
assert_eq!(pod_info.name, "unknown-pod");
|
||||||
|
assert_eq!(pod_info.ip, "10.0.0.4");
|
||||||
|
assert_eq!(pod_info.status, "Running");
|
||||||
|
assert!(pod_info.is_ready);
|
||||||
|
assert_eq!(pod_info.pod_type, Some(PodType::Regular));
|
||||||
|
assert!(pod_info.bootstrap_port.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_from_pod_with_pd_config_invalid_bootstrap_port() {
|
||||||
|
let mut pod = create_pd_k8s_pod("prefill-pod", "10.0.0.1", "prefill", None);
|
||||||
|
// Add invalid bootstrap port annotation
|
||||||
|
pod.metadata.annotations.as_mut().unwrap().insert(
|
||||||
|
"sglang.ai/bootstrap-port".to_string(),
|
||||||
|
"invalid".to_string(),
|
||||||
|
);
|
||||||
|
let config = create_pd_config();
|
||||||
|
|
||||||
|
let pod_info = PodInfo::from_pod(&pod, Some(&config)).unwrap();
|
||||||
|
assert_eq!(pod_info.pod_type, Some(PodType::Prefill));
|
||||||
|
assert!(pod_info.bootstrap_port.is_none()); // Should be None for invalid port
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -370,7 +730,7 @@ mod tests {
|
|||||||
Some("False"),
|
Some("False"),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let pod_info = PodInfo::from_pod(&k8s_pod).unwrap();
|
let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap();
|
||||||
assert!(!pod_info.is_ready);
|
assert!(!pod_info.is_ready);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -383,26 +743,26 @@ mod tests {
|
|||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
let pod_info = PodInfo::from_pod(&k8s_pod).unwrap();
|
let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap();
|
||||||
assert!(!pod_info.is_ready);
|
assert!(!pod_info.is_ready);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pod_info_from_pod_missing_name() {
|
fn test_pod_info_from_pod_missing_name() {
|
||||||
let k8s_pod = create_k8s_pod(None, Some("10.0.0.1"), Some("Running"), Some("True"), None);
|
let k8s_pod = create_k8s_pod(None, Some("10.0.0.1"), Some("Running"), Some("True"), None);
|
||||||
assert!(PodInfo::from_pod(&k8s_pod).is_none());
|
assert!(PodInfo::from_pod(&k8s_pod, None).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pod_info_from_pod_missing_ip() {
|
fn test_pod_info_from_pod_missing_ip() {
|
||||||
let k8s_pod = create_k8s_pod(Some("test-pod"), None, Some("Running"), Some("True"), None);
|
let k8s_pod = create_k8s_pod(Some("test-pod"), None, Some("Running"), Some("True"), None);
|
||||||
assert!(PodInfo::from_pod(&k8s_pod).is_none());
|
assert!(PodInfo::from_pod(&k8s_pod, None).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pod_info_from_pod_missing_status_phase() {
|
fn test_pod_info_from_pod_missing_status_phase() {
|
||||||
let k8s_pod = create_k8s_pod(Some("test-pod"), Some("10.0.0.1"), None, Some("True"), None);
|
let k8s_pod = create_k8s_pod(Some("test-pod"), Some("10.0.0.1"), None, Some("True"), None);
|
||||||
let pod_info = PodInfo::from_pod(&k8s_pod).unwrap();
|
let pod_info = PodInfo::from_pod(&k8s_pod, None).unwrap();
|
||||||
assert_eq!(pod_info.status, "Unknown");
|
assert_eq!(pod_info.status, "Unknown");
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -410,7 +770,7 @@ mod tests {
|
|||||||
fn test_pod_info_from_pod_no_status_object() {
|
fn test_pod_info_from_pod_no_status_object() {
|
||||||
let mut k8s_pod = create_k8s_pod(Some("test-pod"), None, None, None, None);
|
let mut k8s_pod = create_k8s_pod(Some("test-pod"), None, None, None, None);
|
||||||
k8s_pod.status = None;
|
k8s_pod.status = None;
|
||||||
assert!(PodInfo::from_pod(&k8s_pod).is_none());
|
assert!(PodInfo::from_pod(&k8s_pod, None).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -420,6 +780,8 @@ mod tests {
|
|||||||
ip: "1.1.1.1".into(),
|
ip: "1.1.1.1".into(),
|
||||||
status: "Running".into(),
|
status: "Running".into(),
|
||||||
is_ready: true,
|
is_ready: true,
|
||||||
|
pod_type: None,
|
||||||
|
bootstrap_port: None,
|
||||||
};
|
};
|
||||||
assert!(healthy_pod.is_healthy());
|
assert!(healthy_pod.is_healthy());
|
||||||
|
|
||||||
@@ -428,6 +790,8 @@ mod tests {
|
|||||||
ip: "1.1.1.2".into(),
|
ip: "1.1.1.2".into(),
|
||||||
status: "Running".into(),
|
status: "Running".into(),
|
||||||
is_ready: false,
|
is_ready: false,
|
||||||
|
pod_type: None,
|
||||||
|
bootstrap_port: None,
|
||||||
};
|
};
|
||||||
assert!(!not_ready_pod.is_healthy());
|
assert!(!not_ready_pod.is_healthy());
|
||||||
|
|
||||||
@@ -436,6 +800,8 @@ mod tests {
|
|||||||
ip: "1.1.1.3".into(),
|
ip: "1.1.1.3".into(),
|
||||||
status: "Pending".into(),
|
status: "Pending".into(),
|
||||||
is_ready: true,
|
is_ready: true,
|
||||||
|
pod_type: None,
|
||||||
|
bootstrap_port: None,
|
||||||
};
|
};
|
||||||
assert!(!not_running_pod.is_healthy());
|
assert!(!not_running_pod.is_healthy());
|
||||||
}
|
}
|
||||||
@@ -447,10 +813,45 @@ mod tests {
|
|||||||
ip: "1.2.3.4".into(),
|
ip: "1.2.3.4".into(),
|
||||||
status: "Running".into(),
|
status: "Running".into(),
|
||||||
is_ready: true,
|
is_ready: true,
|
||||||
|
pod_type: None,
|
||||||
|
bootstrap_port: None,
|
||||||
};
|
};
|
||||||
assert_eq!(pod_info.worker_url(8080), "http://1.2.3.4:8080");
|
assert_eq!(pod_info.worker_url(8080), "http://1.2.3.4:8080");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_pod_info_equality_with_pod_type() {
|
||||||
|
let pod1 = PodInfo {
|
||||||
|
name: "pod1".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Prefill),
|
||||||
|
bootstrap_port: Some(8081),
|
||||||
|
};
|
||||||
|
|
||||||
|
let pod2 = PodInfo {
|
||||||
|
name: "pod1".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Prefill),
|
||||||
|
bootstrap_port: Some(8081),
|
||||||
|
};
|
||||||
|
|
||||||
|
let pod3 = PodInfo {
|
||||||
|
name: "pod1".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Decode),
|
||||||
|
bootstrap_port: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(pod1, pod2);
|
||||||
|
assert_ne!(pod1, pod3);
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_handle_pod_event_add_unhealthy_pod() {
|
async fn test_handle_pod_event_add_unhealthy_pod() {
|
||||||
let router = create_test_router();
|
let router = create_test_router();
|
||||||
@@ -460,6 +861,8 @@ mod tests {
|
|||||||
ip: "1.2.3.4".into(),
|
ip: "1.2.3.4".into(),
|
||||||
status: "Pending".into(),
|
status: "Pending".into(),
|
||||||
is_ready: false,
|
is_ready: false,
|
||||||
|
pod_type: None,
|
||||||
|
bootstrap_port: None,
|
||||||
};
|
};
|
||||||
let port = 8080u16;
|
let port = 8080u16;
|
||||||
|
|
||||||
@@ -468,6 +871,7 @@ mod tests {
|
|||||||
Arc::clone(&tracked_pods),
|
Arc::clone(&tracked_pods),
|
||||||
Arc::clone(&router),
|
Arc::clone(&router),
|
||||||
port,
|
port,
|
||||||
|
false, // pd_mode = false
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
@@ -488,6 +892,8 @@ mod tests {
|
|||||||
ip: "1.2.3.4".into(),
|
ip: "1.2.3.4".into(),
|
||||||
status: "Running".into(),
|
status: "Running".into(),
|
||||||
is_ready: true,
|
is_ready: true,
|
||||||
|
pod_type: None,
|
||||||
|
bootstrap_port: None,
|
||||||
};
|
};
|
||||||
let port = 8080u16;
|
let port = 8080u16;
|
||||||
|
|
||||||
@@ -496,10 +902,221 @@ mod tests {
|
|||||||
Arc::clone(&tracked_pods),
|
Arc::clone(&tracked_pods),
|
||||||
Arc::clone(&router),
|
Arc::clone(&router),
|
||||||
port,
|
port,
|
||||||
|
false, // pd_mode = false
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
assert!(tracked_pods.lock().unwrap().is_empty());
|
assert!(tracked_pods.lock().unwrap().is_empty());
|
||||||
assert!(router.get_worker_urls().read().unwrap().is_empty());
|
assert!(router.get_worker_urls().read().unwrap().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_pd_pod_event_prefill_pod() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "prefill-pod".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Prefill),
|
||||||
|
bootstrap_port: Some(8081),
|
||||||
|
};
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
// This test validates the structure but won't actually add workers since
|
||||||
|
// we're using a regular router instead of PD router
|
||||||
|
handle_pod_event(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
false, // pd_mode = false, so it should fallback to regular handling
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Pod should not be tracked since router.add_worker will fail for non-running server
|
||||||
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_pd_pod_event_decode_pod() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "decode-pod".into(),
|
||||||
|
ip: "1.2.3.5".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Decode),
|
||||||
|
bootstrap_port: None,
|
||||||
|
};
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
handle_pod_event(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
false, // pd_mode = false, so it should fallback to regular handling
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Pod should not be tracked since router.add_worker will fail for non-running server
|
||||||
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_pd_pod_deletion_tracked_pod() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "test-pod".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Prefill),
|
||||||
|
bootstrap_port: Some(8081),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add pod to tracked set first
|
||||||
|
{
|
||||||
|
let mut tracked = tracked_pods.lock().unwrap();
|
||||||
|
tracked.insert(pod_info.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
handle_pod_deletion(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
false, // pd_mode = false
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Pod should be removed from tracking
|
||||||
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_handle_pd_pod_deletion_untracked_pod() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "untracked-pod".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Decode),
|
||||||
|
bootstrap_port: None,
|
||||||
|
};
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
// Don't add pod to tracked set
|
||||||
|
|
||||||
|
handle_pod_deletion(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
true, // pd_mode = true
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Tracked set should remain empty
|
||||||
|
assert!(tracked_pods.lock().unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unified_handler_regular_mode() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "regular-pod".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Regular),
|
||||||
|
bootstrap_port: None,
|
||||||
|
};
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
// Test that unified handler works for regular mode
|
||||||
|
handle_pod_event(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
false, // pd_mode = false
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Pod should not be tracked since router.add_worker will fail for non-running server
|
||||||
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unified_handler_pd_mode_with_prefill() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "prefill-pod".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Prefill),
|
||||||
|
bootstrap_port: Some(8081),
|
||||||
|
};
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
// Test that unified handler works for PD mode with prefill
|
||||||
|
handle_pod_event(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
true, // pd_mode = true
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Pod should not be tracked since router.add_pd_worker will fail for regular router
|
||||||
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_unified_handler_deletion_with_pd_mode() {
|
||||||
|
let router = create_test_router();
|
||||||
|
let tracked_pods = Arc::new(Mutex::new(HashSet::new()));
|
||||||
|
let pod_info = PodInfo {
|
||||||
|
name: "decode-pod".into(),
|
||||||
|
ip: "1.2.3.4".into(),
|
||||||
|
status: "Running".into(),
|
||||||
|
is_ready: true,
|
||||||
|
pod_type: Some(PodType::Decode),
|
||||||
|
bootstrap_port: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Add pod to tracked set first
|
||||||
|
{
|
||||||
|
let mut tracked = tracked_pods.lock().unwrap();
|
||||||
|
tracked.insert(pod_info.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let port = 8080u16;
|
||||||
|
|
||||||
|
// Test that unified handler works for deletion in PD mode
|
||||||
|
handle_pod_deletion(
|
||||||
|
&pod_info,
|
||||||
|
Arc::clone(&tracked_pods),
|
||||||
|
Arc::clone(&router),
|
||||||
|
port,
|
||||||
|
true, // pd_mode = true
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Pod should be removed from tracking
|
||||||
|
assert!(!tracked_pods.lock().unwrap().contains(&pod_info));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
//! - Phase 2: Bootstrap injection and request handling
|
//! - Phase 2: Bootstrap injection and request handling
|
||||||
//! - Phase 3: Cache-aware selection (when implemented)
|
//! - Phase 3: Cache-aware selection (when implemented)
|
||||||
//!
|
//!
|
||||||
//! Note: PD mode is enabled via the pd_disaggregated flag, not as a policy type.
|
//! Note: PD mode is enabled via the pd_disaggregation flag, not as a policy type.
|
||||||
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
//! The policy type (Random, PowerOfTwo, CacheAware) determines the selection algorithm within PD mode.
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
@@ -90,7 +90,7 @@ mod test_pd_routing {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pd_selection_policies() {
|
fn test_pd_selection_policies() {
|
||||||
// Test all PD selection policy variants
|
// Test all PD selection policy variants
|
||||||
// Note: These policies are only used when pd_disaggregated=true
|
// Note: These policies are only used when pd_disaggregation=true
|
||||||
let policies = vec![
|
let policies = vec![
|
||||||
PDSelectionPolicy::Random,
|
PDSelectionPolicy::Random,
|
||||||
PDSelectionPolicy::PowerOfTwo,
|
PDSelectionPolicy::PowerOfTwo,
|
||||||
@@ -122,7 +122,7 @@ mod test_pd_routing {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pd_router_configuration() {
|
fn test_pd_router_configuration() {
|
||||||
// Test PrefillDecodeConfig creation with various policies
|
// Test PrefillDecodeConfig creation with various policies
|
||||||
// This config is used when pd_disaggregated=true
|
// This config is used when pd_disaggregation=true
|
||||||
let configs = vec![
|
let configs = vec![
|
||||||
PolicyConfig::PrefillDecodeConfig {
|
PolicyConfig::PrefillDecodeConfig {
|
||||||
selection_policy: PDSelectionPolicy::Random,
|
selection_policy: PDSelectionPolicy::Random,
|
||||||
@@ -878,7 +878,7 @@ mod test_pd_routing {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_policy_type_to_pd_selection_policy_mapping() {
|
fn test_policy_type_to_pd_selection_policy_mapping() {
|
||||||
// Document the mapping from PolicyType to PDSelectionPolicy
|
// Document the mapping from PolicyType to PDSelectionPolicy
|
||||||
// This mapping happens in lib.rs when pd_disaggregated=true
|
// This mapping happens in lib.rs when pd_disaggregation=true
|
||||||
|
|
||||||
// PolicyType::Random -> PDSelectionPolicy::Random
|
// PolicyType::Random -> PDSelectionPolicy::Random
|
||||||
// PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
|
// PolicyType::PowerOfTwo -> PDSelectionPolicy::PowerOfTwo
|
||||||
|
|||||||
Reference in New Issue
Block a user