Files
xc-llm-ascend/examples/rfork/rfork_planner.py
Marck 17da96658f [ModelLoader][Feature] Add rfork support for fast model loading (#7392)
### What this PR does / why we need it?
Support a new load format: RFORK

For implementation details of this feature, please refer to #7441


### Does this PR introduce _any_ user-facing change?

Add a new option for load-format: rfork

e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

Signed-off-by: Marck <1412354149@qq.com>
2026-03-25 16:40:30 +08:00

510 lines
17 KiB
Python

#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
Standalone rfork planner mock server used by vLLM rfork seed protocol tests.
Usage:
python examples/rfork/rfork_planner.py --host 0.0.0.0 --port 1223
"""
from __future__ import annotations
import argparse
import os
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass
from fastapi import APIRouter, FastAPI, Request, Response, status
@dataclass(frozen=True)
class Settings:
    """Immutable runtime configuration for the rfork planner mock server.

    Every numeric option must be strictly positive and ``alloc_policy`` must
    be ``"fifo"`` or ``"lru"``; a violation raises ``ValueError`` when the
    instance is constructed.
    """

    host: str = "0.0.0.0"
    port: int = 1223
    heartbeat_ttl_sec: int = 60
    heartbeat_sweep_sec: int = 5
    default_resource_points: int = 1
    alloc_policy: str = "fifo"

    def __post_init__(self) -> None:
        """Validate the field values.

        Raises:
            ValueError: for the first field (in declaration order) that is
                out of range, or for an unknown allocation policy.
        """
        # Pairs are listed in declaration order so the first offending field
        # is the one that gets reported.
        positive_checks = (
            (self.port, "port must be > 0"),
            (self.heartbeat_ttl_sec, "heartbeat_ttl_sec must be > 0"),
            (self.heartbeat_sweep_sec, "heartbeat_sweep_sec must be > 0"),
            (self.default_resource_points, "default_resource_points must be > 0"),
        )
        for value, message in positive_checks:
            if value <= 0:
                raise ValueError(message)
        if self.alloc_policy not in {"fifo", "lru"}:
            raise ValueError("alloc_policy must be one of: fifo, lru")

    @staticmethod
    def from_env() -> Settings:
        """Build settings from the ``RFORK_MOCK_*`` environment variables.

        Unset variables fall back to the same values as the field defaults.
        Integer variables are converted with ``int`` and the allocation
        policy is lower-cased before validation.
        """

        def int_from_env(name: str, fallback: str) -> int:
            return int(os.getenv(name, fallback))

        return Settings(
            host=os.getenv("RFORK_MOCK_HOST", "0.0.0.0"),
            port=int_from_env("RFORK_MOCK_PORT", "1223"),
            heartbeat_ttl_sec=int_from_env("RFORK_MOCK_HEARTBEAT_TTL_SEC", "60"),
            heartbeat_sweep_sec=int_from_env("RFORK_MOCK_HEARTBEAT_SWEEP_SEC", "5"),
            default_resource_points=int_from_env("RFORK_MOCK_DEFAULT_RESOURCE_POINTS", "1"),
            alloc_policy=os.getenv("RFORK_MOCK_ALLOC_POLICY", "fifo").lower(),
        )
@dataclass
class SeedRecord:
    """Mutable bookkeeping entry for one registered seed endpoint."""

    seed_key: str
    seed_ip: str
    seed_port: int
    seed_rank: int
    last_heartbeat_ts: float
    resource_total: int
    resource_used: int = 0

    @property
    def identity(self) -> str:
        """Return the ``ip:port:rank`` string that identifies this seed."""
        parts = (self.seed_ip, self.seed_port, self.seed_rank)
        return ":".join(str(part) for part in parts)

    @property
    def available_points(self) -> int:
        """Return the unused capacity, never less than zero."""
        remaining = self.resource_total - self.resource_used
        return max(remaining, 0)
@dataclass
class LeaseRecord:
    """One outstanding allocation of seed capacity to a requester."""

    # Identifier handed back to the requester; used again to release the lease.
    user_id: str
    # Key the requester asked for when the lease was created.
    seed_key: str
    # ``ip:port:rank`` identity of the seed the lease was taken on.
    seed_identity: str
    # Number of resource points this lease holds on the seed.
    allocated_points: int
    # Timestamp at which the lease was created.
    leased_at: float
class Scheduler:
    """Chooses which seed serves the next request under a fixed policy."""

    def __init__(self, alloc_policy: str = "fifo") -> None:
        """Store the allocation policy.

        Raises:
            ValueError: if ``alloc_policy`` is neither ``"fifo"`` nor ``"lru"``.
        """
        supported = {"fifo", "lru"}
        if alloc_policy not in supported:
            raise ValueError(f"unsupported alloc policy: {alloc_policy}")
        self.alloc_policy = alloc_policy

    def choose_seed(self, seeds: Iterable[SeedRecord]) -> SeedRecord | None:
        """Pick one seed that still has free capacity, or ``None`` if none does.

        Under ``"fifo"`` the seed with the smallest ``(last_heartbeat_ts,
        identity)`` pair wins; under the other policy the largest pair wins.
        """

        def heartbeat_order(seed):
            return (seed.last_heartbeat_ts, seed.identity)

        eligible = []
        for seed in seeds:
            if seed.available_points > 0:
                eligible.append(seed)
        if not eligible:
            return None
        pick = min if self.alloc_policy == "fifo" else max
        return pick(eligible, key=heartbeat_order)
class Store:
    """In-memory registry of seeds and of the leases handed out against them.

    Seeds are indexed both by identity (``ip:port:rank``) and by seed key.
    Every method that touches the indexes does so while holding an internal
    re-entrant lock.
    """

    def __init__(
        self,
        *,
        heartbeat_ttl_sec: int,
        default_resource_points: int,
        scheduler: Scheduler,
        time_fn: Callable[[], float] | None = None,
    ) -> None:
        """Create an empty store.

        Args:
            heartbeat_ttl_sec: age after which a seed without a heartbeat is
                considered stale.
            default_resource_points: capacity given to a new seed that does
                not specify one.
            scheduler: object whose ``choose_seed`` picks among candidates.
            time_fn: clock returning seconds as a float; defaults to
                ``time.time``.
        """
        self._lock = threading.RLock()
        # identity -> seed record
        self._seeds: dict[str, SeedRecord] = {}
        # seed key -> identities currently registered under that key
        self._seeds_by_key: dict[str, set[str]] = defaultdict(set)
        # user id -> lease record
        self._leases: dict[str, LeaseRecord] = {}
        self._heartbeat_ttl_sec = heartbeat_ttl_sec
        self._default_resource_points = default_resource_points
        self._scheduler = scheduler
        self._time = time_fn or time.time

    @staticmethod
    def _seed_identity(seed_ip: str, seed_port: int, seed_rank: int) -> str:
        """Return the ``ip:port:rank`` identity string for a seed."""
        return f"{seed_ip}:{seed_port}:{seed_rank}"

    def add_seed(
        self,
        *,
        seed_key: str,
        seed_ip: str,
        seed_port: int,
        seed_rank: int,
        resource_total: int | None = None,
    ) -> SeedRecord:
        """Register a seed, or refresh the heartbeat of a known one.

        For a known identity the heartbeat timestamp is updated, the seed is
        moved to ``seed_key`` if that changed, and, when ``resource_total`` is
        given, the capacity is replaced (minimum 1) with ``resource_used``
        clamped to the new capacity.

        Returns:
            The stored record for the seed.
        """
        identity = self._seed_identity(seed_ip, seed_port, seed_rank)
        now = self._time()
        total = self._default_resource_points if resource_total is None else max(resource_total, 1)
        with self._lock:
            current = self._seeds.get(identity)
            if current is None:
                current = SeedRecord(
                    seed_key=seed_key,
                    seed_ip=seed_ip,
                    seed_port=seed_port,
                    seed_rank=seed_rank,
                    last_heartbeat_ts=now,
                    resource_total=total,
                    resource_used=0,
                )
                self._seeds[identity] = current
                self._seeds_by_key[seed_key].add(identity)
                return current
            if current.seed_key != seed_key:
                old_key = current.seed_key
                old_bucket = self._seeds_by_key[old_key]
                old_bucket.discard(identity)
                # Drop the bucket once it is empty, matching what the stale
                # seed sweep does, so abandoned keys do not accumulate.
                if not old_bucket:
                    del self._seeds_by_key[old_key]
                self._seeds_by_key[seed_key].add(identity)
                current.seed_key = seed_key
            current.last_heartbeat_ts = now
            if resource_total is not None:
                current.resource_total = max(resource_total, 1)
                if current.resource_used > current.resource_total:
                    current.resource_used = current.resource_total
            return current

    def get_seed(self, *, seed_key: str) -> tuple[SeedRecord, LeaseRecord] | None:
        """Lease one resource point on a seed registered under ``seed_key``.

        Stale seeds are evicted first. The scheduler then picks among the
        remaining seeds for the key.

        Returns:
            ``(seed, lease)`` on success, or ``None`` when the scheduler finds
            no usable seed. The lease carries a fresh ``uuid4`` hex user id.
        """
        with self._lock:
            self.gc_stale_seeds_locked()
            # ``.get`` avoids creating an empty bucket for an unknown key.
            seed_identities = self._seeds_by_key.get(seed_key, set())
            seeds = [self._seeds[sid] for sid in seed_identities if sid in self._seeds]
            selected = self._scheduler.choose_seed(seeds)
            if selected is None:
                return None
            selected.resource_used += 1
            user_id = uuid.uuid4().hex
            lease = LeaseRecord(
                user_id=user_id,
                seed_key=seed_key,
                seed_identity=selected.identity,
                allocated_points=1,
                leased_at=self._time(),
            )
            self._leases[user_id] = lease
            return selected, lease

    def put_seed(self, *, seed_ip: str, seed_port: int, seed_rank: int, user_id: str) -> bool:
        """Release the lease held by ``user_id`` on the given seed.

        Returns:
            ``False`` when no lease exists for ``user_id`` or the lease was
            taken on a different seed; ``True`` once the lease is removed.
            If the seed record no longer exists the lease is still removed.
        """
        identity = self._seed_identity(seed_ip, seed_port, seed_rank)
        with self._lock:
            lease = self._leases.get(user_id)
            if lease is None:
                return False
            if lease.seed_identity != identity:
                return False
            seed = self._seeds.get(identity)
            if seed is not None:
                seed.resource_used = max(0, seed.resource_used - lease.allocated_points)
            del self._leases[user_id]
            return True

    def gc_stale_seeds(self) -> int:
        """Evict stale seeds under the lock and return how many were removed."""
        with self._lock:
            return self.gc_stale_seeds_locked()

    def gc_stale_seeds_locked(self) -> int:
        """Evict seeds whose last heartbeat is older than the TTL.

        Callers in this class invoke it while holding the lock. Leases taken
        on an evicted seed are removed as well.

        Returns:
            The number of seeds removed.
        """
        now = self._time()
        stale_ids = [
            sid for sid, seed in self._seeds.items() if (now - seed.last_heartbeat_ts) > self._heartbeat_ttl_sec
        ]
        if not stale_ids:
            return 0
        stale_set = set(stale_ids)
        for sid in stale_ids:
            seed = self._seeds.pop(sid)
            self._seeds_by_key[seed.seed_key].discard(sid)
            if not self._seeds_by_key[seed.seed_key]:
                del self._seeds_by_key[seed.seed_key]
        lease_ids = [uid for uid, lease in self._leases.items() if lease.seed_identity in stale_set]
        for uid in lease_ids:
            del self._leases[uid]
        return len(stale_ids)

    def debug_snapshot(self) -> dict[str, object]:
        """Return a plain-dict view of the current seeds and leases."""
        with self._lock:
            return {
                "seed_count": len(self._seeds),
                "lease_count": len(self._leases),
                "seeds": {
                    sid: {
                        "seed_key": s.seed_key,
                        "resource_total": s.resource_total,
                        "resource_used": s.resource_used,
                        "last_heartbeat_ts": s.last_heartbeat_ts,
                    }
                    for sid, s in self._seeds.items()
                },
                "leases": {
                    uid: {
                        "seed_identity": lease.seed_identity,
                        "seed_key": lease.seed_key,
                        "allocated_points": lease.allocated_points,
                    }
                    for uid, lease in self._leases.items()
                },
            }
class HeartbeatGc:
    """Background thread that periodically evicts stale seeds from a store."""

    def __init__(self, store: Store, sweep_interval_sec: int) -> None:
        """Remember the store and the sweep interval (at least one second)."""
        self._store = store
        self._sweep_interval_sec = max(sweep_interval_sec, 1)
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        """Start the sweeper thread unless one is already running."""
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, name="rfork-heartbeat-gc", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        """Signal the sweeper to exit and wait up to two seconds for it."""
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=2)

    def _run(self) -> None:
        """Sweep the store, then pause for the interval, until stopped."""
        while not self._stop.is_set():
            self._store.gc_stale_seeds()
            # Wait on the stop event rather than sleeping: the wait returns as
            # soon as stop() sets the flag, so shutdown is not delayed by a
            # full sweep interval.
            self._stop.wait(self._sweep_interval_sec)
class HeaderError(ValueError):
    """Raised when a required request header is missing or malformed."""
@dataclass(frozen=True)
class AddSeedHeaders:
    """Parsed header values of an ``/add_seed`` request."""

    seed_key: str
    seed_ip: str
    seed_port: int
    seed_rank: int
    seed_refcnt: int
@dataclass(frozen=True)
class GetSeedHeaders:
    """Parsed header values of a ``/get_seed`` request."""

    seed_key: str
@dataclass(frozen=True)
class PutSeedHeaders:
    """Parsed header values of a ``/put_seed`` request."""

    seed_ip: str
    seed_port: int
    seed_rank: int
    user_id: str
def _required(headers: Mapping[str, str], key: str) -> str:
    """Return the value of header ``key``.

    Raises:
        HeaderError: if the header is absent or is the empty string.
    """
    found = headers.get(key)
    if found is not None and found != "":
        return found
    raise HeaderError(f"missing required header: {key}")
def _parse_int(value: str, key: str, *, minimum: int = 0) -> int:
    """Convert header text to an integer no smaller than ``minimum``.

    Args:
        value: raw header text.
        key: header name, used only in error messages.
        minimum: smallest accepted value.

    Raises:
        HeaderError: if ``value`` is not an integer or is below ``minimum``.
    """
    try:
        number = int(value)
    except ValueError as exc:
        raise HeaderError(f"invalid integer header {key}: {value}") from exc
    if number >= minimum:
        return number
    raise HeaderError(f"header {key} must be >= {minimum}, got {number}")
def parse_add_seed_headers(headers: Mapping[str, str]) -> AddSeedHeaders:
    """Extract and validate the headers of an ``/add_seed`` request.

    Headers are read in the order SEED_KEY, SEED_IP, SEED_PORT, SEED_RANK,
    SEED_REFCNT; the first missing or invalid one raises ``HeaderError``.
    """
    key = _required(headers, "SEED_KEY")
    ip = _required(headers, "SEED_IP")
    port = _parse_int(_required(headers, "SEED_PORT"), "SEED_PORT", minimum=1)
    rank = _parse_int(_required(headers, "SEED_RANK"), "SEED_RANK", minimum=0)
    refcnt = _parse_int(_required(headers, "SEED_REFCNT"), "SEED_REFCNT", minimum=0)
    return AddSeedHeaders(
        seed_key=key,
        seed_ip=ip,
        seed_port=port,
        seed_rank=rank,
        seed_refcnt=refcnt,
    )
def parse_get_seed_headers(headers: Mapping[str, str]) -> GetSeedHeaders:
    """Extract the SEED_KEY header of a ``/get_seed`` request.

    Raises ``HeaderError`` when the header is missing or empty.
    """
    key = _required(headers, "SEED_KEY")
    return GetSeedHeaders(seed_key=key)
def parse_put_seed_headers(headers: Mapping[str, str]) -> PutSeedHeaders:
    """Extract and validate the headers of a ``/put_seed`` request.

    Headers are read in the order SEED_IP, SEED_PORT, SEED_RANK, USER_ID; the
    first missing or invalid one raises ``HeaderError``.
    """
    ip = _required(headers, "SEED_IP")
    port = _parse_int(_required(headers, "SEED_PORT"), "SEED_PORT", minimum=1)
    rank = _parse_int(_required(headers, "SEED_RANK"), "SEED_RANK", minimum=0)
    user = _required(headers, "USER_ID")
    return PutSeedHeaders(seed_ip=ip, seed_port=port, seed_rank=rank, user_id=user)
def build_router(store: Store):
    """Create the API router whose handlers operate on ``store``.

    Routes: ``POST /add_seed``, ``GET /get_seed``, ``POST /put_seed``,
    ``GET /healthz`` and ``GET /debug/snapshot``. The seed routes take their
    parameters from request headers and answer 400 with the parser's message
    when a header is missing or invalid.
    """
    router = APIRouter()

    def _bad_request(err: HeaderError) -> Response:
        # Shared 400 reply carrying the header parser's message.
        return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)

    @router.post("/add_seed")
    def add_seed(request: Request) -> Response:
        try:
            fields = parse_add_seed_headers(request.headers)
        except HeaderError as err:
            return _bad_request(err)
        # vLLM currently sends SEED_REFCNT=0 as heartbeat metadata.
        # Capacity is controlled by planner config, not by this field.
        store.add_seed(
            seed_key=fields.seed_key,
            seed_ip=fields.seed_ip,
            seed_port=fields.seed_port,
            seed_rank=fields.seed_rank,
            resource_total=None,
        )
        return Response(status_code=status.HTTP_200_OK)

    @router.get("/get_seed")
    def get_seed(request: Request) -> Response:
        try:
            fields = parse_get_seed_headers(request.headers)
        except HeaderError as err:
            return _bad_request(err)
        leased = store.get_seed(seed_key=fields.seed_key)
        if leased is None:
            return Response(content="no available seed", status_code=status.HTTP_404_NOT_FOUND)
        seed, lease = leased
        reply = Response(status_code=status.HTTP_200_OK)
        # The chosen seed and the lease id travel back as response headers.
        outgoing = (
            ("SEED_IP", seed.seed_ip),
            ("SEED_PORT", str(seed.seed_port)),
            ("SEED_RANK", str(seed.seed_rank)),
            ("USER_ID", lease.user_id),
        )
        for name, value in outgoing:
            reply.headers[name] = value
        return reply

    @router.post("/put_seed")
    def put_seed(request: Request) -> Response:
        try:
            fields = parse_put_seed_headers(request.headers)
        except HeaderError as err:
            return _bad_request(err)
        released = store.put_seed(
            seed_ip=fields.seed_ip,
            seed_port=fields.seed_port,
            seed_rank=fields.seed_rank,
            user_id=fields.user_id,
        )
        if released:
            return Response(status_code=status.HTTP_200_OK)
        return Response(content="lease not found", status_code=status.HTTP_404_NOT_FOUND)

    @router.get("/healthz")
    def healthz() -> dict[str, str]:
        return {"status": "ok"}

    @router.get("/debug/snapshot")
    def debug_snapshot() -> dict[str, object]:
        return store.debug_snapshot()

    return router
def create_app(settings: Settings):
    """Assemble the FastAPI application for the given settings.

    Builds the scheduler, store and heartbeat sweeper, and ties the sweeper's
    start/stop to the application lifespan. The settings, store and sweeper
    are also exposed on ``app.state``.
    """
    store = Store(
        heartbeat_ttl_sec=settings.heartbeat_ttl_sec,
        default_resource_points=settings.default_resource_points,
        scheduler=Scheduler(settings.alloc_policy),
    )
    gc_runner = HeartbeatGc(store, settings.heartbeat_sweep_sec)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # Run the sweeper for exactly as long as the application is serving.
        gc_runner.start()
        try:
            yield
        finally:
            gc_runner.stop()

    app = FastAPI(title="rfork planner mock", version="0.1.0", lifespan=lifespan)
    app.include_router(build_router(store))
    shared_state = (
        ("settings", settings),
        ("store", store),
        ("gc_runner", gc_runner),
    )
    for attr_name, attr_value in shared_state:
        setattr(app.state, attr_name, attr_value)
    return app
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser.

    Each option's default comes from ``Settings.from_env()``, so an
    environment variable applies unless the flag is given explicitly.
    """
    env_defaults = Settings.from_env()
    parser = argparse.ArgumentParser(description="Standalone rfork planner server")
    parser.add_argument(
        "--host",
        default=env_defaults.host,
        help="bind host (default: env RFORK_MOCK_HOST or 0.0.0.0)",
    )
    # The integer options share the same shape, so they are declared as data.
    int_options = (
        (
            "--port",
            env_defaults.port,
            "bind port (default: env RFORK_MOCK_PORT or 1223)",
        ),
        (
            "--heartbeat-ttl-sec",
            env_defaults.heartbeat_ttl_sec,
            "seed heartbeat ttl in seconds (default: env RFORK_MOCK_HEARTBEAT_TTL_SEC or 60)",
        ),
        (
            "--heartbeat-sweep-sec",
            env_defaults.heartbeat_sweep_sec,
            "gc sweep interval in seconds (default: env RFORK_MOCK_HEARTBEAT_SWEEP_SEC or 5)",
        ),
        (
            "--default-resource-points",
            env_defaults.default_resource_points,
            "default seed capacity points (default: env RFORK_MOCK_DEFAULT_RESOURCE_POINTS or 1)",
        ),
    )
    for flag, default_value, help_text in int_options:
        parser.add_argument(flag, type=int, default=default_value, help=help_text)
    parser.add_argument(
        "--alloc-policy",
        choices=["fifo", "lru"],
        default=env_defaults.alloc_policy,
        help="seed allocation policy (default: env RFORK_MOCK_ALLOC_POLICY or fifo)",
    )
    return parser
def main() -> None:
    """Parse the command line, build the app and serve it with uvicorn.

    Option values that ``Settings`` rejects (for example ``--port 0``) are
    reported as a normal command-line usage error instead of a traceback.

    Raises:
        SystemExit: on invalid options, or when uvicorn is not installed.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    try:
        settings = Settings(
            host=args.host,
            port=args.port,
            heartbeat_ttl_sec=args.heartbeat_ttl_sec,
            heartbeat_sweep_sec=args.heartbeat_sweep_sec,
            default_resource_points=args.default_resource_points,
            alloc_policy=args.alloc_policy,
        )
    except ValueError as exc:
        # parser.error prints the usage line plus the message and exits, so
        # control never continues past this point with settings unbound.
        parser.error(str(exc))
    app = create_app(settings)
    try:
        import uvicorn
    except ModuleNotFoundError as exc:
        raise SystemExit("missing dependency: uvicorn. Install it with: python -m pip install uvicorn") from exc
    uvicorn.run(app, host=settings.host, port=settings.port)


if __name__ == "__main__":
    main()