### What this PR does / why we need it?
Support a new load format: RFORK.
For implementation details of this feature, please refer to #7441
### Does this PR introduce _any_ user-facing change?
Adds a new option for `--load-format`: `rfork`.
e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: Marck <1412354149@qq.com>
510 lines · 17 KiB · Python
#
|
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
"""
|
|
Standalone rfork planner mock server used by vLLM rfork seed protocol tests.
|
|
|
|
Usage:
|
|
python examples/rfork/rfork_planner.py --host 0.0.0.0 --port 1223
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from collections import defaultdict
|
|
from collections.abc import Callable, Iterable, Mapping
|
|
from contextlib import asynccontextmanager
|
|
from dataclasses import dataclass
|
|
|
|
from fastapi import APIRouter, FastAPI, Request, Response, status
|
|
|
|
|
|
@dataclass(frozen=True)
class Settings:
    """Runtime configuration for the rfork planner mock server.

    Values may be passed directly, read from ``RFORK_MOCK_*`` environment
    variables via :meth:`from_env`, or supplied on the command line.
    Every field is validated eagerly in ``__post_init__``.
    """

    host: str = "0.0.0.0"
    port: int = 1223
    # Seeds whose last heartbeat is older than this many seconds are evicted.
    heartbeat_ttl_sec: int = 60
    # Seconds between background sweeps for stale seeds.
    heartbeat_sweep_sec: int = 5
    # Capacity (lease points) given to a seed when none is specified.
    default_resource_points: int = 1
    # Seed selection policy: "fifo" or "lru".
    alloc_policy: str = "fifo"

    def __post_init__(self) -> None:
        """Validate all fields.

        Raises:
            ValueError: if any field is out of its allowed range.
        """
        if self.port <= 0:
            raise ValueError("port must be > 0")
        # TCP ports are 16-bit; reject values the socket layer cannot bind
        # here rather than failing later inside the server.
        if self.port > 65535:
            raise ValueError("port must be <= 65535")
        if self.heartbeat_ttl_sec <= 0:
            raise ValueError("heartbeat_ttl_sec must be > 0")
        if self.heartbeat_sweep_sec <= 0:
            raise ValueError("heartbeat_sweep_sec must be > 0")
        if self.default_resource_points <= 0:
            raise ValueError("default_resource_points must be > 0")
        if self.alloc_policy not in {"fifo", "lru"}:
            raise ValueError("alloc_policy must be one of: fifo, lru")

    @staticmethod
    def from_env() -> Settings:
        """Build settings from ``RFORK_MOCK_*`` environment variables.

        Unset variables fall back to the same defaults as the dataclass
        fields; ``RFORK_MOCK_ALLOC_POLICY`` is lower-cased before validation.

        Raises:
            ValueError: if an integer variable cannot be parsed or any value
                fails validation.
        """
        return Settings(
            host=os.getenv("RFORK_MOCK_HOST", "0.0.0.0"),
            port=int(os.getenv("RFORK_MOCK_PORT", "1223")),
            heartbeat_ttl_sec=int(os.getenv("RFORK_MOCK_HEARTBEAT_TTL_SEC", "60")),
            heartbeat_sweep_sec=int(os.getenv("RFORK_MOCK_HEARTBEAT_SWEEP_SEC", "5")),
            default_resource_points=int(os.getenv("RFORK_MOCK_DEFAULT_RESOURCE_POINTS", "1")),
            alloc_policy=os.getenv("RFORK_MOCK_ALLOC_POLICY", "fifo").lower(),
        )
|
|
|
|
|
|
@dataclass
class SeedRecord:
    """Mutable state tracked for one registered seed endpoint."""

    seed_key: str
    seed_ip: str
    seed_port: int
    seed_rank: int
    last_heartbeat_ts: float  # timestamp of the most recent heartbeat
    resource_total: int  # capacity, in lease points
    resource_used: int = 0  # lease points currently handed out

    @property
    def identity(self) -> str:
        """Unique ``ip:port:rank`` string for this seed."""
        return f"{self.seed_ip}:{self.seed_port}:{self.seed_rank}"

    @property
    def available_points(self) -> int:
        """Remaining capacity, never reported as negative."""
        remaining = self.resource_total - self.resource_used
        return remaining if remaining > 0 else 0
|
|
|
|
|
|
@dataclass
class LeaseRecord:
    """One outstanding allocation of seed capacity to a client."""

    user_id: str  # lease token that the client presents when releasing
    seed_key: str  # key the client requested a seed for
    seed_identity: str  # ``ip:port:rank`` of the leased seed
    allocated_points: int  # capacity consumed by this lease
    leased_at: float  # timestamp at which the lease was granted
|
|
|
|
|
|
class Scheduler:
    """Choose which seed with free capacity should serve the next lease.

    Candidates are ranked by ``(last_heartbeat_ts, identity)``:

    * ``"fifo"`` returns the lowest-ranked seed (oldest heartbeat).
    * ``"lru"`` returns the highest-ranked seed (newest heartbeat).
    """

    def __init__(self, alloc_policy: str = "fifo") -> None:
        if alloc_policy not in {"lru", "fifo"}:
            raise ValueError(f"unsupported alloc policy: {alloc_policy}")
        self.alloc_policy = alloc_policy

    def choose_seed(self, seeds: Iterable[SeedRecord]) -> SeedRecord | None:
        """Return the preferred seed among ``seeds``, or ``None`` if all are full."""
        eligible = [candidate for candidate in seeds if candidate.available_points > 0]
        if not eligible:
            return None

        def rank(candidate: SeedRecord) -> tuple[float, str]:
            # identity breaks ties deterministically.
            return candidate.last_heartbeat_ts, candidate.identity

        # NOTE(review): "lru" selects the most recently heartbeated seed,
        # not the least recently *leased* one; confirm this is intended.
        pick = min if self.alloc_policy == "fifo" else max
        return pick(eligible, key=rank)
|
|
|
|
|
|
class Store:
    """In-memory registry of rfork seeds and the leases handed out on them.

    Seeds are stored by identity (``ip:port:rank``) and additionally indexed
    by ``seed_key`` so :meth:`get_seed` can choose among every seed serving
    the same key. All state is guarded by a re-entrant lock: public methods
    acquire it themselves, while methods whose names end in ``_locked``
    expect the caller to already hold it.
    """

    def __init__(
        self,
        *,
        heartbeat_ttl_sec: int,
        default_resource_points: int,
        scheduler: Scheduler,
        time_fn: Callable[[], float] | None = None,
    ) -> None:
        """Create an empty store.

        Args:
            heartbeat_ttl_sec: seeds not heard from for longer than this many
                seconds are removed by :meth:`gc_stale_seeds`.
            default_resource_points: capacity given to a seed registered
                without an explicit ``resource_total``.
            scheduler: object whose ``choose_seed`` picks a seed in
                :meth:`get_seed`.
            time_fn: clock returning seconds; defaults to ``time.time``.
                Injectable so tests can control time.
        """
        self._lock = threading.RLock()
        # identity -> seed record
        self._seeds: dict[str, SeedRecord] = {}
        # seed_key -> identities of the seeds registered under that key
        self._seeds_by_key: dict[str, set[str]] = defaultdict(set)
        # user_id -> outstanding lease
        self._leases: dict[str, LeaseRecord] = {}
        self._heartbeat_ttl_sec = heartbeat_ttl_sec
        self._default_resource_points = default_resource_points
        self._scheduler = scheduler
        self._time = time_fn or time.time

    @staticmethod
    def _seed_identity(seed_ip: str, seed_port: int, seed_rank: int) -> str:
        """Build the ``ip:port:rank`` key that identifies a seed."""
        return f"{seed_ip}:{seed_port}:{seed_rank}"

    def _unindex_locked(self, seed_key: str, identity: str) -> None:
        """Remove ``identity`` from the ``seed_key`` index, dropping the key once empty.

        Caller must hold ``self._lock``.
        """
        identities = self._seeds_by_key.get(seed_key)
        if identities is None:
            return
        identities.discard(identity)
        if not identities:
            del self._seeds_by_key[seed_key]

    def add_seed(
        self,
        *,
        seed_key: str,
        seed_ip: str,
        seed_port: int,
        seed_rank: int,
        resource_total: int | None = None,
    ) -> SeedRecord:
        """Register a new seed, or refresh an existing one (heartbeat).

        A new seed starts with no capacity in use. For an already-known seed
        this method does three things:

        * refreshes the heartbeat timestamp;
        * moves the seed under ``seed_key`` if the key changed;
        * updates its capacity when ``resource_total`` is given.

        Args:
            seed_key: key the seed serves.
            seed_ip: seed address.
            seed_port: seed port.
            seed_rank: seed rank.
            resource_total: capacity in lease points. Values below 1 are
                raised to 1. ``None`` keeps the current capacity, or uses the
                store default for a new seed.

        Returns:
            The live (mutable) record stored for this seed.
        """
        identity = self._seed_identity(seed_ip, seed_port, seed_rank)
        now = self._time()
        total = self._default_resource_points if resource_total is None else max(resource_total, 1)

        with self._lock:
            current = self._seeds.get(identity)
            if current is None:
                current = SeedRecord(
                    seed_key=seed_key,
                    seed_ip=seed_ip,
                    seed_port=seed_port,
                    seed_rank=seed_rank,
                    last_heartbeat_ts=now,
                    resource_total=total,
                    resource_used=0,
                )
                self._seeds[identity] = current
                self._seeds_by_key[seed_key].add(identity)
                return current

            if current.seed_key != seed_key:
                # Prune the old key's set when it becomes empty, matching
                # gc_stale_seeds_locked. Otherwise every re-keyed seed leaves
                # an empty set behind in the index.
                self._unindex_locked(current.seed_key, identity)
                self._seeds_by_key[seed_key].add(identity)
                current.seed_key = seed_key
            current.last_heartbeat_ts = now
            if resource_total is not None:
                current.resource_total = total
            # If capacity shrank below what is already leased, clamp usage.
            if current.resource_used > current.resource_total:
                current.resource_used = current.resource_total
            return current

    def get_seed(self, *, seed_key: str) -> tuple[SeedRecord, LeaseRecord] | None:
        """Lease one point of capacity on a live seed serving ``seed_key``.

        Stale seeds are collected first. The scheduler then picks among the
        remaining seeds under ``seed_key`` that still have free capacity.

        Returns:
            ``(seed, lease)`` on success, or ``None`` when no seed is available.
        """
        with self._lock:
            self.gc_stale_seeds_locked()
            seed_identities = self._seeds_by_key.get(seed_key, set())
            seeds = [self._seeds[sid] for sid in seed_identities if sid in self._seeds]
            selected = self._scheduler.choose_seed(seeds)
            if selected is None:
                return None

            selected.resource_used += 1
            user_id = uuid.uuid4().hex
            lease = LeaseRecord(
                user_id=user_id,
                seed_key=seed_key,
                seed_identity=selected.identity,
                allocated_points=1,
                leased_at=self._time(),
            )
            self._leases[user_id] = lease
            return selected, lease

    def put_seed(self, *, seed_ip: str, seed_port: int, seed_rank: int, user_id: str) -> bool:
        """Release lease ``user_id`` on the seed ``seed_ip:seed_port:seed_rank``.

        Returns:
            ``True`` if the lease was released. Returns ``False`` if no such
            lease exists (never issued, already released, or dropped together
            with a stale seed), or if it was issued for a different seed.
        """
        identity = self._seed_identity(seed_ip, seed_port, seed_rank)
        with self._lock:
            lease = self._leases.get(user_id)
            if lease is None:
                return False
            if lease.seed_identity != identity:
                return False

            seed = self._seeds.get(identity)
            if seed is not None:
                seed.resource_used = max(0, seed.resource_used - lease.allocated_points)
            del self._leases[user_id]
            return True

    def gc_stale_seeds(self) -> int:
        """Evict stale seeds under the lock; return how many were removed."""
        with self._lock:
            return self.gc_stale_seeds_locked()

    def gc_stale_seeds_locked(self) -> int:
        """Evict seeds whose last heartbeat is older than the TTL.

        Leases held on evicted seeds are dropped as well. The caller must
        hold ``self._lock``.

        Returns:
            The number of seeds removed.
        """
        now = self._time()
        stale_ids = [
            sid for sid, seed in self._seeds.items() if (now - seed.last_heartbeat_ts) > self._heartbeat_ttl_sec
        ]
        if not stale_ids:
            return 0

        stale_set = set(stale_ids)
        for sid in stale_ids:
            seed = self._seeds.pop(sid)
            self._unindex_locked(seed.seed_key, sid)

        lease_ids = [uid for uid, lease in self._leases.items() if lease.seed_identity in stale_set]
        for uid in lease_ids:
            del self._leases[uid]

        return len(stale_ids)

    def debug_snapshot(self) -> dict[str, object]:
        """Return a plain-dict copy of every seed and lease, for debugging."""
        with self._lock:
            return {
                "seed_count": len(self._seeds),
                "lease_count": len(self._leases),
                "seeds": {
                    sid: {
                        "seed_key": s.seed_key,
                        "resource_total": s.resource_total,
                        "resource_used": s.resource_used,
                        "last_heartbeat_ts": s.last_heartbeat_ts,
                    }
                    for sid, s in self._seeds.items()
                },
                "leases": {
                    uid: {
                        "seed_identity": lease.seed_identity,
                        "seed_key": lease.seed_key,
                        "allocated_points": lease.allocated_points,
                    }
                    for uid, lease in self._leases.items()
                },
            }
|
|
|
|
|
|
class HeartbeatGc:
    """Background daemon thread that periodically evicts stale seeds from a store."""

    def __init__(self, store: Store, sweep_interval_sec: int) -> None:
        """Prepare (but do not start) a sweeper for ``store``.

        Args:
            store: object whose ``gc_stale_seeds()`` is called on each sweep.
            sweep_interval_sec: seconds between sweeps; values below 1 are
                raised to 1 so the loop never spins.
        """
        self._store = store
        self._sweep_interval_sec = max(sweep_interval_sec, 1)
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the sweeper thread; no-op if it is already running."""
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, name="rfork-heartbeat-gc", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the sweeper to exit and wait up to 2 seconds for it to finish."""
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=2)

    def _run(self) -> None:
        """Sweep loop: collect stale seeds, then wait for the next interval or stop()."""
        while not self._stop.is_set():
            self._store.gc_stale_seeds()
            # Event.wait returns as soon as stop() sets the event. A plain
            # time.sleep() cannot be interrupted, so with an interval longer
            # than stop()'s 2 s join timeout the thread would outlive stop().
            self._stop.wait(self._sweep_interval_sec)
|
|
|
|
|
|
class HeaderError(ValueError):
    """A request header is missing, empty, non-integer, or out of range."""
|
|
|
|
|
|
@dataclass(frozen=True)
class AddSeedHeaders:
    """Validated headers of an ``/add_seed`` (register/heartbeat) request."""

    seed_key: str  # SEED_KEY
    seed_ip: str  # SEED_IP
    seed_port: int  # SEED_PORT
    seed_rank: int  # SEED_RANK
    seed_refcnt: int  # SEED_REFCNT
|
|
|
|
|
|
@dataclass(frozen=True)
class GetSeedHeaders:
    """Validated headers of a ``/get_seed`` (lease) request."""

    seed_key: str  # SEED_KEY
|
|
|
|
|
|
@dataclass(frozen=True)
class PutSeedHeaders:
    """Validated headers of a ``/put_seed`` (release) request."""

    seed_ip: str  # SEED_IP
    seed_port: int  # SEED_PORT
    seed_rank: int  # SEED_RANK
    user_id: str  # USER_ID (lease token returned by /get_seed)
|
|
|
|
|
|
def _required(headers: Mapping[str, str], key: str) -> str:
|
|
value = headers.get(key)
|
|
if value is None or value == "":
|
|
raise HeaderError(f"missing required header: {key}")
|
|
return value
|
|
|
|
|
|
def _parse_int(value: str, key: str, *, minimum: int = 0) -> int:
|
|
try:
|
|
parsed = int(value)
|
|
except ValueError as exc:
|
|
raise HeaderError(f"invalid integer header {key}: {value}") from exc
|
|
if parsed < minimum:
|
|
raise HeaderError(f"header {key} must be >= {minimum}, got {parsed}")
|
|
return parsed
|
|
|
|
|
|
def parse_add_seed_headers(headers: Mapping[str, str]) -> AddSeedHeaders:
    """Validate the headers of an ``/add_seed`` request.

    Raises:
        HeaderError: on the first missing, non-integer, or out-of-range header,
            checked in the order SEED_KEY, SEED_IP, SEED_PORT, SEED_RANK,
            SEED_REFCNT.
    """

    def int_header(name: str, floor: int) -> int:
        return _parse_int(_required(headers, name), name, minimum=floor)

    return AddSeedHeaders(
        seed_key=_required(headers, "SEED_KEY"),
        seed_ip=_required(headers, "SEED_IP"),
        seed_port=int_header("SEED_PORT", 1),
        seed_rank=int_header("SEED_RANK", 0),
        seed_refcnt=int_header("SEED_REFCNT", 0),
    )
|
|
|
|
|
|
def parse_get_seed_headers(headers: Mapping[str, str]) -> GetSeedHeaders:
    """Validate the headers of a ``/get_seed`` request (only SEED_KEY is required).

    Raises:
        HeaderError: if SEED_KEY is missing or empty.
    """
    seed_key = _required(headers, "SEED_KEY")
    return GetSeedHeaders(seed_key=seed_key)
|
|
|
|
|
|
def parse_put_seed_headers(headers: Mapping[str, str]) -> PutSeedHeaders:
    """Validate the headers of a ``/put_seed`` request.

    Raises:
        HeaderError: on the first missing, non-integer, or out-of-range header,
            checked in the order SEED_IP, SEED_PORT, SEED_RANK, USER_ID.
    """

    def int_header(name: str, floor: int) -> int:
        return _parse_int(_required(headers, name), name, minimum=floor)

    return PutSeedHeaders(
        seed_ip=_required(headers, "SEED_IP"),
        seed_port=int_header("SEED_PORT", 1),
        seed_rank=int_header("SEED_RANK", 0),
        user_id=_required(headers, "USER_ID"),
    )
|
|
|
|
|
|
def build_router(store: Store):
    """Create the planner's HTTP routes, all backed by ``store``.

    Routes:
        POST /add_seed        register or heartbeat a seed
        GET  /get_seed        lease a seed for SEED_KEY; the reply carries
                              SEED_IP, SEED_PORT, SEED_RANK and USER_ID headers
        POST /put_seed        release a lease identified by USER_ID
        GET  /healthz         liveness probe
        GET  /debug/snapshot  dump of the store's internal state

    Malformed request headers yield 400. An unavailable seed or an unknown
    lease yields 404.
    """
    router = APIRouter()

    def _bad_request(err: HeaderError) -> Response:
        # Shared 400 reply carrying the validation message as the body.
        return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)

    @router.post("/add_seed")
    def add_seed(request: Request) -> Response:
        try:
            hdrs = parse_add_seed_headers(request.headers)
        except HeaderError as err:
            return _bad_request(err)

        store.add_seed(
            seed_key=hdrs.seed_key,
            seed_ip=hdrs.seed_ip,
            seed_port=hdrs.seed_port,
            seed_rank=hdrs.seed_rank,
            # vLLM currently sends SEED_REFCNT=0 as heartbeat metadata.
            # Capacity is controlled by planner config, not by this field.
            resource_total=None,
        )
        return Response(status_code=status.HTTP_200_OK)

    @router.get("/get_seed")
    def get_seed(request: Request) -> Response:
        try:
            hdrs = parse_get_seed_headers(request.headers)
        except HeaderError as err:
            return _bad_request(err)

        leased = store.get_seed(seed_key=hdrs.seed_key)
        if leased is None:
            return Response(content="no available seed", status_code=status.HTTP_404_NOT_FOUND)

        seed, lease = leased
        response = Response(status_code=status.HTTP_200_OK)
        for header_name, header_value in (
            ("SEED_IP", seed.seed_ip),
            ("SEED_PORT", str(seed.seed_port)),
            ("SEED_RANK", str(seed.seed_rank)),
            ("USER_ID", lease.user_id),
        ):
            response.headers[header_name] = header_value
        return response

    @router.post("/put_seed")
    def put_seed(request: Request) -> Response:
        try:
            hdrs = parse_put_seed_headers(request.headers)
        except HeaderError as err:
            return _bad_request(err)

        ok = store.put_seed(
            seed_ip=hdrs.seed_ip,
            seed_port=hdrs.seed_port,
            seed_rank=hdrs.seed_rank,
            user_id=hdrs.user_id,
        )
        if ok:
            return Response(status_code=status.HTTP_200_OK)
        return Response(content="lease not found", status_code=status.HTTP_404_NOT_FOUND)

    @router.get("/healthz")
    def healthz() -> dict[str, str]:
        return {"status": "ok"}

    @router.get("/debug/snapshot")
    def debug_snapshot() -> dict[str, object]:
        return store.debug_snapshot()

    return router
|
|
|
|
|
|
def create_app(settings: Settings):
    """Assemble the planner FastAPI app from ``settings``.

    Wires a :class:`Store` (with the configured scheduler) into the HTTP
    router. A :class:`HeartbeatGc` is tied to the app's lifespan.
    """
    store = Store(
        heartbeat_ttl_sec=settings.heartbeat_ttl_sec,
        default_resource_points=settings.default_resource_points,
        scheduler=Scheduler(settings.alloc_policy),
    )
    sweeper = HeartbeatGc(store, settings.heartbeat_sweep_sec)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # Start the stale-seed sweeper at startup; always stop it at shutdown.
        sweeper.start()
        try:
            yield
        finally:
            sweeper.stop()

    app = FastAPI(title="rfork planner mock", version="0.1.0", lifespan=lifespan)
    app.include_router(build_router(store))

    # Keep handles on app.state so callers can reach the live objects.
    app.state.settings = settings
    app.state.store = store
    app.state.gc_runner = sweeper
    return app
|
|
|
|
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser, taking every option's default from the environment.

    Defaults come from ``Settings.from_env()``, so ``RFORK_MOCK_*`` variables
    act as fallbacks that explicit flags override.
    """
    env = Settings.from_env()
    parser = argparse.ArgumentParser(description="Standalone rfork planner server")
    parser.add_argument("--host", default=env.host, help="bind host (default: env RFORK_MOCK_HOST or 0.0.0.0)")

    # Integer-valued options, registered in the order they appear in --help.
    integer_options = (
        ("--port", env.port, "bind port (default: env RFORK_MOCK_PORT or 1223)"),
        (
            "--heartbeat-ttl-sec",
            env.heartbeat_ttl_sec,
            "seed heartbeat ttl in seconds (default: env RFORK_MOCK_HEARTBEAT_TTL_SEC or 60)",
        ),
        (
            "--heartbeat-sweep-sec",
            env.heartbeat_sweep_sec,
            "gc sweep interval in seconds (default: env RFORK_MOCK_HEARTBEAT_SWEEP_SEC or 5)",
        ),
        (
            "--default-resource-points",
            env.default_resource_points,
            "default seed capacity points (default: env RFORK_MOCK_DEFAULT_RESOURCE_POINTS or 1)",
        ),
    )
    for flag, default_value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)

    parser.add_argument(
        "--alloc-policy",
        choices=["fifo", "lru"],
        default=env.alloc_policy,
        help="seed allocation policy (default: env RFORK_MOCK_ALLOC_POLICY or fifo)",
    )
    return parser
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments, build the app, and serve it with uvicorn."""
    parser = _build_arg_parser()
    args = parser.parse_args()
    try:
        settings = Settings(
            host=args.host,
            port=args.port,
            heartbeat_ttl_sec=args.heartbeat_ttl_sec,
            heartbeat_sweep_sec=args.heartbeat_sweep_sec,
            default_resource_points=args.default_resource_points,
            alloc_policy=args.alloc_policy,
        )
    except ValueError as exc:
        # Report out-of-range values (e.g. --port 0) as a normal CLI usage
        # error instead of an uncaught traceback. parser.error exits with
        # status 2.
        parser.error(str(exc))
    app = create_app(settings)
    try:
        import uvicorn
    except ModuleNotFoundError as exc:
        raise SystemExit("missing dependency: uvicorn. Install it with: python -m pip install uvicorn") from exc
    uvicorn.run(app, host=settings.host, port=settings.port)
|
|
|
|
|
|
# Run the planner server when this file is executed directly as a script.
if __name__ == "__main__":
    main()
|