[ModelLoader][Feature] Add rfork support for fast model loading (#7392)

### What this PR does / why we need it?
Support a new load format: RFORK.

For implementation details of this feature, please refer to #7441


### Does this PR introduce _any_ user-facing change?

Adds a new option for `--load-format`: `rfork`.

e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

Signed-off-by: Marck <1412354149@qq.com>
Commit 17da96658f (parent 6ddfc41312), authored by Marck on 2026-03-25 16:40:30 +08:00, committed by GitHub.
11 changed files with 1510 additions and 0 deletions.

(Binary file added: `images/rfork_flowchart.jpg`, 160 KiB — not shown.)

@@ -13,6 +13,7 @@ structured_output
lora
eplb_swift_balancer
netloader
rfork
Multi_Token_Prediction
dynamic_batch
epd_disaggregation


@@ -0,0 +1,125 @@
# RFork Guide
This guide explains how to use **RFork** as a model-loader plugin in **vLLM Ascend**.
---
## Overview
RFork is a warm-start weight loading path for vLLM Ascend. Instead of always reading model weights from storage, a new instance can request a compatible **seed** instance from an external planner, then pull weights directly from that seed through `YuanRong TransferEngine`.
The RFork loading flow in the current implementation is:
1. vLLM starts with `--load-format rfork`.
2. RFork builds a **seed key** from the model identity and deployment topology.
3. RFork asks the planner for an available seed matching that key.
4. If a seed is returned, the new instance initializes the model structure on its local NPU, registers local weight memory, fetches the remote transfer-engine metadata from the seed, and performs batch weight transfer into local parameter buffers.
5. If no seed is available, or any step fails, RFork cleans up and falls back to the default loader.
6. After the instance finishes loading, it starts a local seed service and periodically reports heartbeat to the planner, so later instances can reuse it.
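The six steps above can be condensed into a runnable sketch. Here the planner is modeled as a plain dict and all names (`rfork_load`, `load_from_seed`, `load_from_storage`) are illustrative, not the real API:

```python
def rfork_load(planner, seed_key, load_from_seed, load_from_storage):
    """Illustrative sketch of the RFork warm-start flow (steps 2-6 above)."""
    seed = planner.get(seed_key)          # step 3: ask the planner for a seed
    if seed is not None:
        try:
            model = load_from_seed(seed)  # step 4: direct weight transfer
        except Exception:
            model = load_from_storage()   # step 5: clean up, fall back
    else:
        model = load_from_storage()       # no seed yet: default loader
    planner[seed_key] = "this-instance"   # step 6: register as a new seed
    return model

planner = {}
key = "qwen3-8b$default$kv_both$0$0"
# First instance: no compatible seed yet, so it loads from storage.
m1 = rfork_load(planner, key, lambda s: "from-seed", lambda: "from-storage")
# Second instance with the same key reuses the first one as a seed.
m2 = rfork_load(planner, key, lambda s: "from-seed", lambda: "from-storage")
# m1 == "from-storage", m2 == "from-seed"
```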
## Flowchart
![rfork flowchart](./images/rfork_flowchart.jpg)
## Application Scenarios
- **Scale-out after a first successful load**: The first instance may still load from storage, but later instances with the same deployment identity can reuse it as a seed and shorten startup time.
- **Elastic serving clusters**: Because RFork asks a planner for available seeds, it fits clusters where instances are created and reclaimed dynamically.
- **Topology-sensitive deployments**: RFork encodes `kv_role`, `node_rank`, `tp_rank`, and optional `draft` role into the seed key, so only topology-compatible instances are matched together.
---
## Usage
To enable RFork, pass `--load-format rfork` and provide RFork settings through `--model-loader-extra-config` as a JSON string.
### RFork Prerequisites
- Install the runtime dependency `YuanRong TransferEngine` on every RFork instance.
- Run a planner service that implements the RFork seed protocol. A simple mock planner script is provided at [`rfork_planner.py`](../../../../examples/rfork/rfork_planner.py).
### Configuration Fields
| Field Name | Type | Description | Allowed Values / Notes |
|------------|------|-------------|------------------------|
| **model_url** | String | Logical model identifier used to build the RFork seed key. | Required for RFork transfer. Instances that should share seeds must use the same value. |
| **model_deploy_strategy_name** | String | Deployment strategy identifier used together with `model_url` to build the seed key. | Required for RFork transfer. Instances that should share seeds must use the same value. |
| **rfork_scheduler_url** | String | Base URL of the planner service used for seed allocation, release, and heartbeat. | Required for planner-based matching. Example: `http://127.0.0.1:1223`. |
| **rfork_seed_timeout_sec** | Number | Timeout for waiting until the local seed HTTP service becomes healthy after startup. | Optional. Default: `30`. Must be greater than `0`. |
| **rfork_seed_key_separator** | String | Separator used when building the RFork seed key string. | Optional. Default: `$`. Keep the same value across compatible instances. |
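The numeric field is parsed leniently by the loader: ints, floats, and numeric strings are accepted, while non-numeric or non-positive values fall back to the default. A minimal sketch of that rule (the function name is illustrative):

```python
def parse_timeout_sec(value, default=30.0):
    """Lenient parse of rfork_seed_timeout_sec: fall back on bad or <= 0 values."""
    try:
        parsed = float(value)
    except (TypeError, ValueError):
        return default            # not a number at all
    return parsed if parsed > 0 else default  # must be greater than 0

print(parse_timeout_sec("10"))   # 10.0
print(parse_timeout_sec(0))      # 30.0 (must be greater than 0)
print(parse_timeout_sec("abc"))  # 30.0 (not a number)
```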
### How RFork Matches Seeds
RFork does not match instances by `model_url` alone. The local seed key is composed from:
- `model_url`
- `model_deploy_strategy_name`
- disaggregation mode derived from `kv_transfer_config.kv_role` or `kv_both`
- `node_rank`
- `tp_rank`
- optional `draft` suffix when the worker runs as a draft model
This means two instances must agree on both model identity and deployment topology before the planner will treat them as interchangeable seeds.
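A condensed sketch of that key construction, mirroring `get_local_seed_key` in the seed-protocol module below (example values are placeholders):

```python
def build_seed_key(model_url, deploy_strategy, disagg_mode, node_rank, tp_rank,
                   sep="$", is_draft=False):
    """Compose the RFork seed key from model identity plus deployment topology."""
    key = sep.join([model_url, deploy_strategy, disagg_mode,
                    str(node_rank), str(tp_rank)])
    if is_draft:
        key += f"{sep}draft"  # draft-model workers get a distinct key
    return key

print(build_seed_key("qwen3-8b", "default", "kv_both", 0, 0))
# -> qwen3-8b$default$kv_both$0$0
```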
---
## Example Commands & Placeholders
> Replace the `<...>` placeholders before running.
### 1. Install YuanRong TransferEngine
```shell
pip install openyuanrong-transfer-engine
```
### 2. Start the Planner
A simple planner implementation is provided at [`rfork_planner.py`](../../../../examples/rfork/rfork_planner.py).
```shell
python rfork_planner.py \
    --host 0.0.0.0 \
    --port <planner_port>
```
### 3. Start vLLM Instances
Use the same RFork startup command for both the first instance and later instances in the same deployment.
For the first instance, the planner usually has no compatible seed yet, so RFork falls back to the default loader. After loading finishes, that instance starts its local seed service and reports itself to the planner.
For later instances, if the planner can allocate a compatible seed, RFork will try to transfer weights from the existing seed instance before falling back to the default loader.
```shell
export RFORK_CONFIG='{
    "model_url": "<model_url>",
    "model_deploy_strategy_name": "<deploy_strategy>",
    "rfork_scheduler_url": "http://<planner_ip>:<planner_port>"
}'
vllm serve <model_path> \
    --tensor-parallel-size 1 \
    --served-model-name <served_model_name> \
    --port <port> \
    --load-format rfork \
    --model-loader-extra-config "${RFORK_CONFIG}"
```
### Placeholder Descriptions
- `<model_path>`: Model path or model identifier passed to `vllm serve`.
- `<served_model_name>`: Service name exposed by vLLM.
- `<planner_ip>`: IP address or hostname of the RFork planner.
- `<planner_port>`: Listening port of the RFork planner.
- `<model_url>`: Stable model identity string used to build the RFork seed key.
- `<deploy_strategy>`: Stable deployment-strategy name used to build the RFork seed key.
- `<port>`: Serving port of the vLLM instance being started.
---
## Notes & Caveats
- RFork requires `YuanRong TransferEngine` at runtime. If the package is missing, RFork cannot initialize the transfer backend.
- If RFork is used, **each worker process** must bind a listening port; that port is assigned randomly.
- The example [`rfork_planner.py`](../../../../examples/rfork/rfork_planner.py) is only a simple mock implementation. If you need stronger scheduling, capacity management, or production-grade availability behavior, implement your own planner based on the RFork seed protocol.


@@ -0,0 +1,509 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
Standalone rfork planner mock server used by vLLM rfork seed protocol tests.
Usage:
python examples/rfork/rfork_planner.py --host 0.0.0.0 --port 1223
"""
from __future__ import annotations
import argparse
import os
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass
from fastapi import APIRouter, FastAPI, Request, Response, status
@dataclass(frozen=True)
class Settings:
host: str = "0.0.0.0"
port: int = 1223
heartbeat_ttl_sec: int = 60
heartbeat_sweep_sec: int = 5
default_resource_points: int = 1
alloc_policy: str = "fifo"
def __post_init__(self) -> None:
if self.port <= 0:
raise ValueError("port must be > 0")
if self.heartbeat_ttl_sec <= 0:
raise ValueError("heartbeat_ttl_sec must be > 0")
if self.heartbeat_sweep_sec <= 0:
raise ValueError("heartbeat_sweep_sec must be > 0")
if self.default_resource_points <= 0:
raise ValueError("default_resource_points must be > 0")
if self.alloc_policy not in {"fifo", "lru"}:
raise ValueError("alloc_policy must be one of: fifo, lru")
@staticmethod
def from_env() -> Settings:
return Settings(
host=os.getenv("RFORK_MOCK_HOST", "0.0.0.0"),
port=int(os.getenv("RFORK_MOCK_PORT", "1223")),
heartbeat_ttl_sec=int(os.getenv("RFORK_MOCK_HEARTBEAT_TTL_SEC", "60")),
heartbeat_sweep_sec=int(os.getenv("RFORK_MOCK_HEARTBEAT_SWEEP_SEC", "5")),
default_resource_points=int(os.getenv("RFORK_MOCK_DEFAULT_RESOURCE_POINTS", "1")),
alloc_policy=os.getenv("RFORK_MOCK_ALLOC_POLICY", "fifo").lower(),
)
@dataclass
class SeedRecord:
seed_key: str
seed_ip: str
seed_port: int
seed_rank: int
last_heartbeat_ts: float
resource_total: int
resource_used: int = 0
@property
def identity(self) -> str:
return f"{self.seed_ip}:{self.seed_port}:{self.seed_rank}"
@property
def available_points(self) -> int:
return max(self.resource_total - self.resource_used, 0)
@dataclass
class LeaseRecord:
user_id: str
seed_key: str
seed_identity: str
allocated_points: int
leased_at: float
class Scheduler:
def __init__(self, alloc_policy: str = "fifo") -> None:
if alloc_policy not in {"fifo", "lru"}:
raise ValueError(f"unsupported alloc policy: {alloc_policy}")
self.alloc_policy = alloc_policy
def choose_seed(self, seeds: Iterable[SeedRecord]) -> SeedRecord | None:
candidates = [seed for seed in seeds if seed.available_points > 0]
if not candidates:
return None
if self.alloc_policy == "fifo":
return min(candidates, key=lambda s: (s.last_heartbeat_ts, s.identity))
return max(candidates, key=lambda s: (s.last_heartbeat_ts, s.identity))
class Store:
def __init__(
self,
*,
heartbeat_ttl_sec: int,
default_resource_points: int,
scheduler: Scheduler,
time_fn: Callable[[], float] | None = None,
) -> None:
self._lock = threading.RLock()
self._seeds: dict[str, SeedRecord] = {}
self._seeds_by_key: dict[str, set[str]] = defaultdict(set)
self._leases: dict[str, LeaseRecord] = {}
self._heartbeat_ttl_sec = heartbeat_ttl_sec
self._default_resource_points = default_resource_points
self._scheduler = scheduler
self._time = time_fn or time.time
@staticmethod
def _seed_identity(seed_ip: str, seed_port: int, seed_rank: int) -> str:
return f"{seed_ip}:{seed_port}:{seed_rank}"
def add_seed(
self,
*,
seed_key: str,
seed_ip: str,
seed_port: int,
seed_rank: int,
resource_total: int | None = None,
) -> SeedRecord:
identity = self._seed_identity(seed_ip, seed_port, seed_rank)
now = self._time()
total = self._default_resource_points if resource_total is None else max(resource_total, 1)
with self._lock:
current = self._seeds.get(identity)
if current is None:
current = SeedRecord(
seed_key=seed_key,
seed_ip=seed_ip,
seed_port=seed_port,
seed_rank=seed_rank,
last_heartbeat_ts=now,
resource_total=total,
resource_used=0,
)
self._seeds[identity] = current
self._seeds_by_key[seed_key].add(identity)
return current
if current.seed_key != seed_key:
self._seeds_by_key[current.seed_key].discard(identity)
self._seeds_by_key[seed_key].add(identity)
current.seed_key = seed_key
current.last_heartbeat_ts = now
if resource_total is not None:
current.resource_total = max(resource_total, 1)
if current.resource_used > current.resource_total:
current.resource_used = current.resource_total
return current
def get_seed(self, *, seed_key: str) -> tuple[SeedRecord, LeaseRecord] | None:
with self._lock:
self.gc_stale_seeds_locked()
seed_identities = self._seeds_by_key.get(seed_key, set())
seeds = [self._seeds[sid] for sid in seed_identities if sid in self._seeds]
selected = self._scheduler.choose_seed(seeds)
if selected is None:
return None
selected.resource_used += 1
user_id = uuid.uuid4().hex
lease = LeaseRecord(
user_id=user_id,
seed_key=seed_key,
seed_identity=selected.identity,
allocated_points=1,
leased_at=self._time(),
)
self._leases[user_id] = lease
return selected, lease
def put_seed(self, *, seed_ip: str, seed_port: int, seed_rank: int, user_id: str) -> bool:
identity = self._seed_identity(seed_ip, seed_port, seed_rank)
with self._lock:
lease = self._leases.get(user_id)
if lease is None:
return False
if lease.seed_identity != identity:
return False
seed = self._seeds.get(identity)
if seed is not None:
seed.resource_used = max(0, seed.resource_used - lease.allocated_points)
del self._leases[user_id]
return True
def gc_stale_seeds(self) -> int:
with self._lock:
return self.gc_stale_seeds_locked()
def gc_stale_seeds_locked(self) -> int:
now = self._time()
stale_ids = [
sid for sid, seed in self._seeds.items() if (now - seed.last_heartbeat_ts) > self._heartbeat_ttl_sec
]
if not stale_ids:
return 0
stale_set = set(stale_ids)
for sid in stale_ids:
seed = self._seeds.pop(sid)
self._seeds_by_key[seed.seed_key].discard(sid)
if not self._seeds_by_key[seed.seed_key]:
del self._seeds_by_key[seed.seed_key]
lease_ids = [uid for uid, lease in self._leases.items() if lease.seed_identity in stale_set]
for uid in lease_ids:
del self._leases[uid]
return len(stale_ids)
def debug_snapshot(self) -> dict[str, object]:
with self._lock:
return {
"seed_count": len(self._seeds),
"lease_count": len(self._leases),
"seeds": {
sid: {
"seed_key": s.seed_key,
"resource_total": s.resource_total,
"resource_used": s.resource_used,
"last_heartbeat_ts": s.last_heartbeat_ts,
}
for sid, s in self._seeds.items()
},
"leases": {
uid: {
"seed_identity": lease.seed_identity,
"seed_key": lease.seed_key,
"allocated_points": lease.allocated_points,
}
for uid, lease in self._leases.items()
},
}
class HeartbeatGc:
def __init__(self, store: Store, sweep_interval_sec: int) -> None:
self._store = store
self._sweep_interval_sec = max(sweep_interval_sec, 1)
self._stop = threading.Event()
self._thread: threading.Thread | None = None
def start(self) -> None:
if self._thread is not None and self._thread.is_alive():
return
self._stop.clear()
self._thread = threading.Thread(target=self._run, name="rfork-heartbeat-gc", daemon=True)
self._thread.start()
def stop(self) -> None:
self._stop.set()
if self._thread is not None:
self._thread.join(timeout=2)
def _run(self) -> None:
while not self._stop.is_set():
self._store.gc_stale_seeds()
time.sleep(self._sweep_interval_sec)
class HeaderError(ValueError):
pass
@dataclass(frozen=True)
class AddSeedHeaders:
seed_key: str
seed_ip: str
seed_port: int
seed_rank: int
seed_refcnt: int
@dataclass(frozen=True)
class GetSeedHeaders:
seed_key: str
@dataclass(frozen=True)
class PutSeedHeaders:
seed_ip: str
seed_port: int
seed_rank: int
user_id: str
def _required(headers: Mapping[str, str], key: str) -> str:
value = headers.get(key)
if value is None or value == "":
raise HeaderError(f"missing required header: {key}")
return value
def _parse_int(value: str, key: str, *, minimum: int = 0) -> int:
try:
parsed = int(value)
except ValueError as exc:
raise HeaderError(f"invalid integer header {key}: {value}") from exc
if parsed < minimum:
raise HeaderError(f"header {key} must be >= {minimum}, got {parsed}")
return parsed
def parse_add_seed_headers(headers: Mapping[str, str]) -> AddSeedHeaders:
return AddSeedHeaders(
seed_key=_required(headers, "SEED_KEY"),
seed_ip=_required(headers, "SEED_IP"),
seed_port=_parse_int(_required(headers, "SEED_PORT"), "SEED_PORT", minimum=1),
seed_rank=_parse_int(_required(headers, "SEED_RANK"), "SEED_RANK", minimum=0),
seed_refcnt=_parse_int(_required(headers, "SEED_REFCNT"), "SEED_REFCNT", minimum=0),
)
def parse_get_seed_headers(headers: Mapping[str, str]) -> GetSeedHeaders:
return GetSeedHeaders(seed_key=_required(headers, "SEED_KEY"))
def parse_put_seed_headers(headers: Mapping[str, str]) -> PutSeedHeaders:
return PutSeedHeaders(
seed_ip=_required(headers, "SEED_IP"),
seed_port=_parse_int(_required(headers, "SEED_PORT"), "SEED_PORT", minimum=1),
seed_rank=_parse_int(_required(headers, "SEED_RANK"), "SEED_RANK", minimum=0),
user_id=_required(headers, "USER_ID"),
)
def build_router(store: Store):
router = APIRouter()
@router.post("/add_seed")
def add_seed(request: Request) -> Response:
try:
parsed = parse_add_seed_headers(request.headers)
except HeaderError as err:
return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)
store.add_seed(
seed_key=parsed.seed_key,
seed_ip=parsed.seed_ip,
seed_port=parsed.seed_port,
seed_rank=parsed.seed_rank,
# vLLM currently sends SEED_REFCNT=0 as heartbeat metadata.
# Capacity is controlled by planner config, not by this field.
resource_total=None,
)
return Response(status_code=status.HTTP_200_OK)
@router.get("/get_seed")
def get_seed(request: Request) -> Response:
try:
parsed = parse_get_seed_headers(request.headers)
except HeaderError as err:
return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)
result = store.get_seed(seed_key=parsed.seed_key)
if result is None:
return Response(content="no available seed", status_code=status.HTTP_404_NOT_FOUND)
seed, lease = result
response = Response(status_code=status.HTTP_200_OK)
response.headers["SEED_IP"] = seed.seed_ip
response.headers["SEED_PORT"] = str(seed.seed_port)
response.headers["SEED_RANK"] = str(seed.seed_rank)
response.headers["USER_ID"] = lease.user_id
return response
@router.post("/put_seed")
def put_seed(request: Request) -> Response:
try:
parsed = parse_put_seed_headers(request.headers)
except HeaderError as err:
return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)
released = store.put_seed(
seed_ip=parsed.seed_ip,
seed_port=parsed.seed_port,
seed_rank=parsed.seed_rank,
user_id=parsed.user_id,
)
if not released:
return Response(content="lease not found", status_code=status.HTTP_404_NOT_FOUND)
return Response(status_code=status.HTTP_200_OK)
@router.get("/healthz")
def healthz() -> dict[str, str]:
return {"status": "ok"}
@router.get("/debug/snapshot")
def debug_snapshot() -> dict[str, object]:
return store.debug_snapshot()
return router
def create_app(settings: Settings):
scheduler = Scheduler(settings.alloc_policy)
store = Store(
heartbeat_ttl_sec=settings.heartbeat_ttl_sec,
default_resource_points=settings.default_resource_points,
scheduler=scheduler,
)
gc_runner = HeartbeatGc(store, settings.heartbeat_sweep_sec)
@asynccontextmanager
async def lifespan(_: FastAPI):
gc_runner.start()
try:
yield
finally:
gc_runner.stop()
app = FastAPI(title="rfork planner mock", version="0.1.0", lifespan=lifespan)
app.include_router(build_router(store))
app.state.settings = settings
app.state.store = store
app.state.gc_runner = gc_runner
return app
def _build_arg_parser() -> argparse.ArgumentParser:
defaults = Settings.from_env()
parser = argparse.ArgumentParser(description="Standalone rfork planner server")
parser.add_argument("--host", default=defaults.host, help="bind host (default: env RFORK_MOCK_HOST or 0.0.0.0)")
parser.add_argument(
"--port",
type=int,
default=defaults.port,
help="bind port (default: env RFORK_MOCK_PORT or 1223)",
)
parser.add_argument(
"--heartbeat-ttl-sec",
type=int,
default=defaults.heartbeat_ttl_sec,
help="seed heartbeat ttl in seconds (default: env RFORK_MOCK_HEARTBEAT_TTL_SEC or 60)",
)
parser.add_argument(
"--heartbeat-sweep-sec",
type=int,
default=defaults.heartbeat_sweep_sec,
help="gc sweep interval in seconds (default: env RFORK_MOCK_HEARTBEAT_SWEEP_SEC or 5)",
)
parser.add_argument(
"--default-resource-points",
type=int,
default=defaults.default_resource_points,
help="default seed capacity points (default: env RFORK_MOCK_DEFAULT_RESOURCE_POINTS or 1)",
)
parser.add_argument(
"--alloc-policy",
choices=["fifo", "lru"],
default=defaults.alloc_policy,
help="seed allocation policy (default: env RFORK_MOCK_ALLOC_POLICY or fifo)",
)
return parser
def main() -> None:
parser = _build_arg_parser()
args = parser.parse_args()
settings = Settings(
host=args.host,
port=args.port,
heartbeat_ttl_sec=args.heartbeat_ttl_sec,
heartbeat_sweep_sec=args.heartbeat_sweep_sec,
default_resource_points=args.default_resource_points,
alloc_policy=args.alloc_policy,
)
app = create_app(settings)
try:
import uvicorn
except ModuleNotFoundError as exc:
raise SystemExit("missing dependency: uvicorn. Install it with: python -m pip install uvicorn") from exc
uvicorn.run(app, host=settings.host, port=settings.port)
if __name__ == "__main__":
main()
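The `Store` above implements point-based lease accounting: `get_seed` consumes one capacity point and issues a lease id, and `put_seed` releases it. A dependency-free sketch of just that accounting (hypothetical `MiniStore`; no TTL, locking, or HTTP):

```python
import uuid

class MiniStore:
    """Stripped-down sketch of the planner Store's lease accounting."""
    def __init__(self, capacity=1):
        self.capacity = capacity
        self.used = 0
        self.leases = {}

    def get_seed(self):
        # Allocate one point if capacity remains, returning a lease id.
        if self.capacity - self.used <= 0:
            return None
        self.used += 1
        lease_id = uuid.uuid4().hex
        self.leases[lease_id] = 1
        return lease_id

    def put_seed(self, lease_id):
        # Release a previously issued lease; unknown ids are rejected.
        points = self.leases.pop(lease_id, None)
        if points is None:
            return False
        self.used = max(0, self.used - points)
        return True

store = MiniStore(capacity=1)
lease = store.get_seed()          # allocates the only point
assert store.get_seed() is None   # capacity exhausted -> 404 in the real server
assert store.put_seed(lease)      # releasing frees the point again
assert store.get_seed() is not None
```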


@@ -30,8 +30,10 @@ def register_connector():
def register_model_loader():
from .model_loader.netloader import register_netloader
from .model_loader.rfork import register_rforkloader
register_netloader()
register_rforkloader()
def register_service_profiling():


@@ -0,0 +1,20 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def register_rforkloader() -> None:
"""Register the RFork model loader plugin."""
from .rfork_loader import RForkModelLoader # noqa: F401


@@ -0,0 +1,188 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import time
import torch
import torch.nn as nn
from torch.nn import Module
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import logger
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm_ascend.model_loader.rfork.rfork_worker import RForkWorker
@register_model_loader("rfork")
class RForkModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
config = load_config.model_loader_extra_config
if not isinstance(config, dict):
raise RuntimeError("RFork requires --model-loader-extra-config to be a JSON object.")
def _get_extra_config(key: str, default: str = "") -> str:
value = config.get(key)
return value if isinstance(value, str) and value else default
def _get_extra_config_float(key: str, default: float) -> float:
value = config.get(key)
parsed_value = default
if isinstance(value, (int, float)):
parsed_value = float(value)
elif isinstance(value, str) and value:
try:
parsed_value = float(value)
except ValueError:
return default
if parsed_value <= 0:
return default
return parsed_value
self.model_url = _get_extra_config("model_url", "")
self.model_deploy_strategy_name = _get_extra_config("model_deploy_strategy_name", "")
self.scheduler_url = _get_extra_config("rfork_scheduler_url", "")
self.seed_timeout_sec = _get_extra_config_float("rfork_seed_timeout_sec", 5.0)
self.seed_key_separator = _get_extra_config("rfork_seed_key_separator", "$")
logger.info(
"Initializing rfork with config: "
"MODEL_URL=%s, MODEL_DEPLOY_STRATEGY_NAME=%s, "
"SCHEDULER_URL=%s, SEED_TIMEOUT_SEC=%s, "
"SEED_KEY_SEPARATOR=%s",
self.model_url,
self.model_deploy_strategy_name,
self.scheduler_url,
self.seed_timeout_sec,
self.seed_key_separator,
)
def download_model(self, model_config: ModelConfig) -> None:
raise NotImplementedError
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
raise NotImplementedError
def _ensure_rfork_worker(self, vllm_config: VllmConfig) -> RForkWorker:
rfork_worker = getattr(self.load_config, "rfork_worker", None)
if rfork_worker is None:
kv_transfer_config = vllm_config.kv_transfer_config
disaggregation_mode = "kv_both" if kv_transfer_config is None else str(kv_transfer_config.kv_role)
is_draft_model = (
getattr(vllm_config.model_config, "runner_type", None) == "draft"
or getattr(vllm_config.scheduler_config, "runner_type", None) == "draft"
)
device_id = torch.distributed.get_rank()
self.load_config.rfork_worker = RForkWorker(
disaggregation_mode=disaggregation_mode,
node_rank=vllm_config.parallel_config.node_rank,
tp_rank=get_tensor_model_parallel_rank(),
device_id=device_id,
scheduler_url=self.scheduler_url,
model_url=self.model_url,
model_deploy_strategy_name=self.model_deploy_strategy_name,
seed_timeout_sec=self.seed_timeout_sec,
seed_key_separator=self.seed_key_separator,
is_draft_model=is_draft_model,
)
logger.info("RFork worker initialized, load_format=rfork")
rfork_worker = self.load_config.rfork_worker
return rfork_worker
def load_model(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
prefix: str = "",
) -> Module | None:
device_config = vllm_config.device_config
load_config = self.load_config
load_device = device_config.device if load_config.device is None else load_config.device
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
need_del = False
rfork_worker = self._ensure_rfork_worker(vllm_config)
try:
if not rfork_worker.is_seed_available():
raise RuntimeError("seed is not available.")
with target_device:
model = initialize_model(
vllm_config=vllm_config,
model_config=model_config,
prefix=prefix,
)
need_del = True
weight_load_start_time = time.time()
if not rfork_worker.pre_transfer(model):
raise RuntimeError("pre_transfer failed.")
if not rfork_worker.transfer(model):
raise RuntimeError("transfer failed.")
if not rfork_worker.post_transfer():
raise RuntimeError("post_transfer failed.")
logger.info(
"Loading model weights took %.2f seconds",
time.time() - weight_load_start_time,
)
rfork_worker.start_seed_service(model)
process_weights_after_loading(model, model_config, target_device)
return model.eval()
except Exception as e:
logger.warning(f"RFork transfer failed: {e}, clean up and fall back to default loader")
rfork_worker.post_transfer()
if need_del:
del model
gc.collect()
torch.npu.empty_cache()
for _ in range(3):
gc.collect()
torch.npu.empty_cache()
self.load_config.load_format = "auto"
self.load_config.model_loader_extra_config = {}
from vllm.model_executor.model_loader import get_model
model = get_model(
vllm_config=vllm_config,
model_config=model_config,
prefix=prefix,
)
try:
rfork_worker.start_seed_service(model)
except Exception as e:
logger.warning(
"Fallback model loaded, but start_seed_service failed: %s",
e,
)
return model


@@ -0,0 +1,119 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
from vllm.logger import logger
from vllm_ascend.model_loader.rfork.seed_protocol import RForkSeedProtocol
from vllm_ascend.model_loader.rfork.seed_server import start_rfork_server
from vllm_ascend.model_loader.rfork.transfer_backend import (
RForkTransferBackend,
)
class RForkWorker:
def __init__(
self,
disaggregation_mode: str,
node_rank: int,
tp_rank: int,
device_id: int,
scheduler_url: str,
model_url: str,
model_deploy_strategy_name: str,
seed_timeout_sec: float = 30.0,
seed_key_separator: str = "$",
is_draft_model: bool = False,
):
self.device_id = device_id
self.rfork_seed = None
self.transfer_backend = RForkTransferBackend()
self.ready_to_start_seed_service = False
self.seed_service_started = False
self.seed_timeout_sec = seed_timeout_sec
self.seed_protocol = RForkSeedProtocol(
disaggregation_mode=disaggregation_mode,
node_rank=node_rank,
tp_rank=tp_rank,
scheduler_url=scheduler_url,
model_url=model_url,
model_deploy_strategy_name=model_deploy_strategy_name,
seed_key_separator=seed_key_separator,
is_draft_worker=is_draft_model,
)
def is_seed_available(self) -> bool:
self.rfork_seed = self.seed_protocol.get_seed()
return self.rfork_seed is not None
def pre_transfer(self, model) -> bool:
try:
assert self.transfer_backend.is_initialized(), "transfer_backend is not initialized, cannot pre_transfer."
result = self.transfer_backend.register_memory_region(model)
self.ready_to_start_seed_service = result
return result
except AssertionError as e:
logger.exception("Pre-transfer failed: %s", e)
return False
def transfer(self, model) -> bool:
try:
assert self.transfer_backend.is_initialized(), "transfer_backend is not initialized, cannot transfer."
assert self.rfork_seed is not None, "rfork seed is None, cannot transfer."
return self.transfer_backend.recv_from_source(
model=model,
seed_instance_ip=self.rfork_seed["seed_ip"],
seed_instance_service_port=self.rfork_seed["seed_port"],
local_seed_key=self.seed_protocol.get_local_seed_key(),
)
except AssertionError as e:
logger.exception("Transfer failed: %s", e)
return False
def post_transfer(self):
if self.rfork_seed is None:
logger.info("rfork seed is None, no need to release.")
return True
self.seed_protocol.release_seed(self.rfork_seed)
return True
def start_seed_service(self, model):
if self.seed_service_started:
logger.info("Seed service already started, skipping.")
return
if not self.ready_to_start_seed_service:
if not self.pre_transfer(model):
return
port = start_rfork_server(
self.seed_protocol.get_local_seed_key(),
(
self.transfer_backend.rfork_transfer_engine_session_id,
self.transfer_backend.rfork_transfer_engine_weights_info_dict,
),
health_timeout_sec=self.seed_timeout_sec,
)
if port > 0:
self.rfork_heartbeat_thread = threading.Thread(
target=self.seed_protocol.report_seed,
args=(port,),
daemon=True,
name="RForkHeartbeat",
)
self.rfork_heartbeat_thread.start()
self.seed_service_started = True


@@ -0,0 +1,208 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from urllib.error import HTTPError
import requests
from vllm.logger import logger
from vllm.utils.network_utils import get_ip
REQUEST_TIMEOUT_SEC = 10.0
HEARTBEAT_LOG_EVERY_N = 4
def get_local_seed_key(
disaggregation_mode: str,
node_rank: int,
tp_rank: int,
model_url: str,
model_deploy_strategy_name: str,
seed_key_separator: str = "$",
is_draft_worker: bool = False,
) -> str:
if not model_url or not model_deploy_strategy_name:
raise RuntimeError(
"RFork seed key is not set. Ensure model_loader_extra_config contains "
"`model_url` and `model_deploy_strategy_name`."
)
seed_key = f"{model_url}{seed_key_separator}{model_deploy_strategy_name}"
key_suffix = f"{disaggregation_mode}{seed_key_separator}{node_rank}{seed_key_separator}{tp_rank}"
if is_draft_worker:
key_suffix += f"{seed_key_separator}draft"
return f"{seed_key}{seed_key_separator}{key_suffix}"
class RForkSeedProtocol:
def __init__(
self,
*,
disaggregation_mode: str,
node_rank: int,
tp_rank: int,
scheduler_url: str,
model_url: str,
model_deploy_strategy_name: str,
seed_key_separator: str = "$",
is_draft_worker: bool = False,
):
self.disaggregation_mode = disaggregation_mode
self.node_rank = node_rank
self.tp_rank = tp_rank
self.scheduler_url = scheduler_url
self.model_url = model_url
self.model_deploy_strategy_name = model_deploy_strategy_name
self.seed_key_separator = seed_key_separator
self.is_draft_worker = is_draft_worker
self._local_seed_key = get_local_seed_key(
disaggregation_mode=self.disaggregation_mode,
node_rank=self.node_rank,
tp_rank=self.tp_rank,
model_url=self.model_url,
model_deploy_strategy_name=self.model_deploy_strategy_name,
seed_key_separator=self.seed_key_separator,
is_draft_worker=self.is_draft_worker,
)
def get_local_seed_key(self) -> str:
return self._local_seed_key
@staticmethod
def _request_timeout_sec() -> float:
return REQUEST_TIMEOUT_SEC
def _ensure_scheduler_url_set(self) -> None:
if not self.scheduler_url:
raise RuntimeError("rfork_scheduler_url is not set. Cannot interact with the scheduler.")
def get_seed(self):
try:
self._ensure_scheduler_url_set()
response = requests.get(
f"{self.scheduler_url}/get_seed",
headers={
"SEED_KEY": self.get_local_seed_key(),
},
timeout=self._request_timeout_sec(),
)
if response.status_code != 200:
raise RuntimeError(f"Failed to get seed from the planner, {response.status_code}")
seed_ip = response.headers.get("SEED_IP")
seed_port = response.headers.get("SEED_PORT")
user_id = response.headers.get("USER_ID")
seed_rank = response.headers.get("SEED_RANK")
logger.debug(
"seed_ip: %s, seed_port: %s, user_id: %s, seed_rank: %s",
seed_ip,
seed_port,
user_id,
seed_rank,
)
return {
"seed_ip": seed_ip,
"seed_port": seed_port,
"user_id": user_id,
"seed_rank": seed_rank,
}
except RuntimeError as e:
logger.warning("get_seed from planner RuntimeError: %s", e)
return None
except HTTPError as e:
logger.exception("get_seed from planner HTTPError: %s", e)
return None
except Exception as e:
logger.exception("get_seed from planner Exception: %s", e)
return None
def release_seed(self, seed) -> bool:
try:
self._ensure_scheduler_url_set()
user_id = seed["user_id"]
seed_ip = seed["seed_ip"]
seed_port = str(seed["seed_port"])
seed_rank = str(seed["seed_rank"])
response = requests.post(
f"{self.scheduler_url}/put_seed",
headers={
"SEED_IP": seed_ip,
"SEED_PORT": seed_port,
"USER_ID": user_id,
"SEED_RANK": seed_rank,
},
timeout=self._request_timeout_sec(),
)
if response.status_code != 200:
raise RuntimeError(f"Failed to release seed to the planner, {response.status_code}")
return True
except RuntimeError as e:
logger.exception("release_seed to planner RuntimeError: %s", e)
return False
except HTTPError as e:
logger.exception("release_seed to planner HTTPError: %s", e)
return False
except Exception as e:
logger.exception("release_seed to planner Exception: %s", e)
return False
def report_seed(self, port: int, sleep_interval: int = 30):
heartbeat_idx = 0
log_every_n = HEARTBEAT_LOG_EVERY_N
try:
self._ensure_scheduler_url_set()
seed_ip = get_ip()
seed_key = self.get_local_seed_key()
except Exception as e:
logger.exception("report_seed setup Exception: %s", e)
return
while True:
heartbeat_idx += 1
result = False
try:
response = requests.post(
f"{self.scheduler_url}/add_seed",
headers={
"SEED_KEY": seed_key,
"SEED_IP": seed_ip,
"SEED_PORT": str(port),
"SEED_RANK": str(self.tp_rank),
"SEED_REFCNT": str(0),
},
timeout=self._request_timeout_sec(),
)
if response.status_code == 200:
result = True
except HTTPError as e:
logger.exception("report_seed to planner HTTPError: %s", e)
except Exception as e:
logger.exception("report_seed to planner Exception: %s", e)
# Keep heartbeat frequency unchanged, but reduce log noise.
# Always print failures immediately; print success once every N times.
if (not result) or (heartbeat_idx % log_every_n == 0):
logger.info(
"[rfork_heartbeat] report seed to planner result: %s (%d/%d)",
result,
heartbeat_idx % log_every_n if heartbeat_idx % log_every_n != 0 else log_every_n,
log_every_n,
)
time.sleep(sleep_interval)
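For reference, the seed key produced by `get_local_seed_key` above is a `$`-joined tuple: model identity first, deployment topology as the suffix. A standalone sketch of the same layout (the values below are illustrative, not from a real deployment):

```python
SEP = "$"

def build_seed_key(disaggregation_mode: str, node_rank: int, tp_rank: int,
                   model_url: str, strategy: str, is_draft_worker: bool = False) -> str:
    """Mirror of get_local_seed_key: model identity, then topology suffix."""
    key = f"{model_url}{SEP}{strategy}{SEP}{disaggregation_mode}{SEP}{node_rank}{SEP}{tp_rank}"
    if is_draft_worker:
        key += f"{SEP}draft"
    return key

# e.g. a prefill worker on node 0, TP rank 2:
print(build_seed_key("prefill", 0, 2, "/workspace/models/Qwen3-8B", "tp4"))
# → /workspace/models/Qwen3-8B$tp4$prefill$0$2
```

Two instances match (and weight transfer is attempted) only when every component of this key is identical, which is why draft workers get an extra `draft` suffix to keep them from seeding main-model workers.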

View File

@@ -0,0 +1,126 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import queue
import socket
import threading
import time
from http import HTTPStatus
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.responses import Response
from vllm.logger import logger
def start_fastapi_server(
port_queue: queue.Queue[int],
local_seed_key,
info,
):
logger.info("[RFork Seed] Preparing socket with dynamic port...")
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.bind(("0.0.0.0", 0))
_, port = sock.getsockname()
logger.info("[RFork Seed] Assigned dynamic port: %s", port)
app = FastAPI()
@app.get("/get_rfork_transfer_engine_info")
def get_rfork_transfer_engine_info(seed_key: str):
if seed_key == local_seed_key:
return {"rfork_transfer_engine_info": info}
return {"rfork_transfer_engine_info": None}
@app.get("/rfork_fetch_seed")
def rfork_fetch_seed():
return {"status": "ok"}
@app.get("/health_check_with_key")
def health_check_with_key(seed_key: str):
if seed_key == local_seed_key:
return Response(status_code=HTTPStatus.OK)
return Response(status_code=HTTPStatus.BAD_REQUEST)
config = uvicorn.Config(app, host=None, port=None, log_level="warning")
server = uvicorn.Server(config)
try:
port_queue.put(port)
except Exception as e:
logger.error("[RFork Seed] Failed to send port via queue: %s", e)
sock.close()
return
logger.info("[RFork Seed] FastAPI server starting on port %s...", port)
server.run(sockets=[sock])
sock.close()
def start_rfork_server(local_seed_key, rfork_transfer_engine_info, health_timeout_sec: float = 30.0) -> int:
port_queue: queue.Queue[int] = queue.Queue()
server_thread = threading.Thread(
target=start_fastapi_server,
args=(port_queue, local_seed_key, rfork_transfer_engine_info),
daemon=True,
)
server_thread.start()
try:
port = port_queue.get(timeout=15)
if port == -1:
raise RuntimeError("Server thread failed to start")
except Exception as e:
logger.error("[RFork Seed] start server error: %s", e)
return -1
deadline = time.time() + health_timeout_sec
healthy = False
retry_count = 0
last_error = None
while time.time() < deadline:
time.sleep(0.01)
url = f"http://127.0.0.1:{port}/health_check_with_key"
try:
response = requests.get(
url,
params={"seed_key": local_seed_key},
timeout=10,
)
if response.status_code == 200:
healthy = True
break
last_error = f"unexpected status code {response.status_code} from health check"
except Exception as e:
last_error = str(e)
retry_count += 1
if healthy:
if retry_count > 1:
logger.info(
"[RFork Seed] health check passed after %d retries for port %s",
retry_count - 1,
port,
)
return port
logger.error(
"[RFork Seed] health check timed out after %.1fs for port %s, last error: %s",
health_timeout_sec,
port,
last_error,
)
return -1

View File

@@ -0,0 +1,212 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from typing import Any
import requests
import torch
from vllm.logger import logger
from vllm.utils.network_utils import get_ip, get_open_port, join_host_port
class RForkTransferBackend:
def __init__(self):
self.rfork_transfer_engine: Any | None = None
self.rfork_transfer_engine_session_id = None
self.rfork_transfer_engine_weights_info_dict = None
self.registered_weight_blocks = []
self._is_initialized = False
self.init_transfer_engine()
def init_transfer_engine(self):
try:
from yr.datasystem import TransferEngine # type: ignore[import-not-found]
except ImportError as e:
raise ImportError("Please install @yuanrong-datasystem/transfer_engine first.") from e
transfer_engine = TransferEngine()
local_hostname = join_host_port(get_ip(), get_open_port())
ret = transfer_engine.initialize(local_hostname, "ascend", f"npu:{torch.npu.current_device()}")
if ret.is_error():
raise RuntimeError(
"TransferEngine initialization failed: "
f"initialize({local_hostname}, ascend"
f"npu:{int(torch.npu.current_device())}) -> {ret.to_string()}"
)
self.rfork_transfer_engine = transfer_engine
self.rfork_transfer_engine_session_id = local_hostname
self._is_initialized = True
def is_initialized(self) -> bool:
return self._is_initialized
def _get_transfer_engine(self) -> Any:
if self.rfork_transfer_engine is None:
raise RuntimeError("TransferEngine is not initialized.")
return self.rfork_transfer_engine
def register_memory_region(self, model):
transfer_engine = self._get_transfer_engine()
start_reg_mr_tic = time.time()
weight_mr_dict = {}
weight_addr_set = set()
for name, weight in model.named_parameters():
weight_mr_dict[name] = (
weight.data_ptr(),
weight.numel(),
weight.element_size(),
)
weight_addr_set.add(weight.data_ptr())
memory_snapshot = torch.npu.memory.memory_snapshot()
weight_blocks_for_reg_mr = []
for segment in memory_snapshot:
current_weight_block = None
for block in segment.get("blocks", []):
address = block.get("address", -1)
size = block.get("size", -1)
state = block.get("state", "")
if address < 0 or size < 0 or state == "":
continue
if state == "active_allocated" and address in weight_addr_set:
if current_weight_block is None:
current_weight_block = (address, size)
elif current_weight_block[0] + current_weight_block[1] == address:
current_weight_block = (
current_weight_block[0],
current_weight_block[1] + size,
)
else:
weight_blocks_for_reg_mr.append(current_weight_block)
current_weight_block = (address, size)
if current_weight_block is not None:
weight_blocks_for_reg_mr.append(current_weight_block)
addresses, sizes = zip(*weight_blocks_for_reg_mr) if weight_blocks_for_reg_mr else ((), ())
ret = transfer_engine.batch_register_memory(addresses, sizes)
if ret.is_error():
logger.error(
"batch_register_memory failed for %d blocks, ret: %s",
len(weight_blocks_for_reg_mr),
ret.to_string(),
)
return False
self.rfork_transfer_engine_weights_info_dict = weight_mr_dict
self.registered_weight_blocks = weight_blocks_for_reg_mr
logger.info(
"register_memory_region time: %.4fs",
time.time() - start_reg_mr_tic,
)
return True
def unregister_memory_region(self) -> bool:
transfer_engine = self._get_transfer_engine()
start_unreg_mr_tic = time.time()
ret = transfer_engine.batch_unregister_memory([address for address, _ in self.registered_weight_blocks])
if ret.is_error():
logger.error(
"batch_unregister_memory failed for %d blocks, ret: %s",
len(self.registered_weight_blocks),
ret.to_string(),
)
return False
self.rfork_transfer_engine_weights_info_dict = None
self.registered_weight_blocks = []
logger.info(
"unregister_memory_region time: %.4fs",
time.time() - start_unreg_mr_tic,
)
return True
def recv_from_source(
self,
model,
seed_instance_ip,
seed_instance_service_port,
local_seed_key,
):
transfer_engine = self._get_transfer_engine()
seed_url = f"http://{seed_instance_ip}:{seed_instance_service_port}"
seed_session_id, seed_weight_info = get_remote_instance_transfer_engine_info(seed_url, local_seed_key)
if seed_session_id is None or seed_weight_info is None:
logger.error("Cannot get transfer engine session or weight info.")
return False
seed_ptr_list = []
client_ptr_list = []
client_len_list = []
for name, tensor in model.named_parameters():
weight_info = seed_weight_info.get(name, None)
if weight_info is None:
logger.error("Cannot find weight info for %s.", name)
return False
seed_ptr, seed_len, seed_size = weight_info
if seed_len != tensor.numel() or seed_size != tensor.element_size():
logger.error(
"Weight info mismatch for %s, expected (%s, %s), got (%s, %s)",
name,
seed_len,
seed_size,
tensor.numel(),
tensor.element_size(),
)
return False
seed_ptr_list.append(seed_ptr)
client_ptr_list.append(tensor.data_ptr())
client_len_list.append(tensor.numel() * tensor.element_size())
start_transfer_tic = time.time()
ret = transfer_engine.batch_transfer_sync_read(
seed_session_id,
client_ptr_list,
seed_ptr_list,
client_len_list,
)
if ret.is_error():
logger.error("Failed to transfer weights from remote instance, ret=%s", ret.to_string())
return False
logger.info("transfer weights time: %.4fs", time.time() - start_transfer_tic)
return True
def get_remote_instance_transfer_engine_info(seed_url: str, local_seed_key: str):
try:
response = requests.get(
f"{seed_url}/get_rfork_transfer_engine_info",
params={"seed_key": local_seed_key},
)
if response.status_code != 200:
logger.error("request.get failed: %s", response.status_code)
return None, None
data = response.json()
info = data.get("rfork_transfer_engine_info", None)
if info is not None and isinstance(info, list) and len(info) == 2:
return info[0], info[1]
logger.error("Failed to get `rfork_transfer_engine_info` in response.")
return None, None
except Exception as e:
logger.error("Exception: %s", e)
return None, None
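The contiguous-block coalescing inside `register_memory_region` can be exercised off-device. A minimal sketch, assuming a simplified snapshot of `(address, size, state)` tuples (the real code walks each segment of `torch.npu.memory.memory_snapshot()` separately; the helper name here is illustrative):

```python
def coalesce_weight_blocks(blocks, weight_addrs):
    """Merge byte-contiguous weight blocks into single regions,
    so batch_register_memory sees fewer, larger registrations."""
    merged = []
    current = None
    for address, size, state in blocks:
        # Only blocks that back a named parameter are registered.
        if state != "active_allocated" or address not in weight_addrs:
            continue
        if current is None:
            current = (address, size)
        elif current[0] + current[1] == address:
            # Starts exactly where the previous region ends: extend it.
            current = (current[0], current[1] + size)
        else:
            merged.append(current)
            current = (address, size)
    if current is not None:
        merged.append(current)
    return merged

blocks = [
    (0x1000, 0x100, "active_allocated"),
    (0x1100, 0x200, "active_allocated"),  # contiguous with the previous block
    (0x2000, 0x080, "active_allocated"),  # gap -> starts a new region
]
print(coalesce_weight_blocks(blocks, {0x1000, 0x1100, 0x2000}))
# → [(4096, 768), (8192, 128)]
```

Fewer, larger registered regions keep the `batch_register_memory` call cheap even for models with thousands of parameters.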