[ModelLoader][Feature] Add rfork support for fast model loading (#7392)
### What this PR does / why we need it?
Support a new load format: RFORK.
For implementation details of this feature, please refer to #7441
### Does this PR introduce _any_ user-facing change?
Adds a new option for `--load-format`: `rfork`.
e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
Signed-off-by: Marck <1412354149@qq.com>
BIN
docs/source/user_guide/feature_guide/images/rfork_flowchart.jpg
Normal file
After Width: | Height: | Size: 160 KiB
@@ -13,6 +13,7 @@ structured_output
lora
eplb_swift_balancer
netloader
rfork
Multi_Token_Prediction
dynamic_batch
epd_disaggregation
125
docs/source/user_guide/feature_guide/rfork.md
Normal file
@@ -0,0 +1,125 @@
# RFork Guide

This guide explains how to use **RFork** as a model-loader plugin in **vLLM Ascend**.

---

## Overview

RFork is a warm-start weight-loading path for vLLM Ascend. Instead of always reading model weights from storage, a new instance can request a compatible **seed** instance from an external planner, then pull weights directly from that seed through `YuanRong TransferEngine`.

The RFork loading flow in the current implementation is:

1. vLLM starts with `--load-format rfork`.
2. RFork builds a **seed key** from the model identity and deployment topology.
3. RFork asks the planner for an available seed matching that key.
4. If a seed is returned, the new instance initializes the model structure on its local NPU, registers local weight memory, fetches the remote transfer-engine metadata from the seed, and performs a batch weight transfer into local parameter buffers.
5. If no seed is available, or any step fails, RFork cleans up and falls back to the default loader.
6. After the instance finishes loading, it starts a local seed service and periodically reports a heartbeat to the planner so that later instances can reuse it.
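The control flow above can be sketched as follows. This is a minimal illustration with hypothetical helper names (`try_transfer`, `default_load`, and the planner methods are stand-ins), not the real loader API:

```python
# Minimal sketch of the RFork load flow described above. The planner,
# transfer, and seed-service calls are stand-in stubs, not the real API.
def rfork_load(planner, seed_key, default_load, try_transfer, start_seed_service):
    seed = planner.get_seed(seed_key)      # step 3: ask the planner for a seed
    model = None
    if seed is not None:
        try:
            model = try_transfer(seed)     # step 4: pull weights from the seed
        except Exception:
            model = None                   # step 5: clean up on any failure
    if model is None:
        model = default_load()             # step 5: fall back to default loader
    start_seed_service(model)              # step 6: become a seed ourselves
    planner.heartbeat(seed_key)            # step 6: report heartbeat
    return model
```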
## Flowchart

![rfork flowchart](./images/rfork_flowchart.jpg)

## Application Scenarios

- **Scale-out after a first successful load**: The first instance may still load from storage, but later instances with the same deployment identity can reuse it as a seed and shorten startup time.
- **Elastic serving clusters**: Because RFork asks a planner for available seeds, it fits clusters where instances are created and reclaimed dynamically.
- **Topology-sensitive deployments**: RFork encodes `kv_role`, `node_rank`, `tp_rank`, and an optional `draft` role into the seed key, so only topology-compatible instances are matched together.
---

## Usage

To enable RFork, pass `--load-format rfork` and provide RFork settings through `--model-loader-extra-config` as a JSON string.

### RFork Prerequisites

- Install the runtime dependency `YuanRong TransferEngine` on every RFork instance.
- Run a planner service that implements the RFork seed protocol. A simple mock planner script is provided at [`rfork_planner.py`](../../../../examples/rfork/rfork_planner.py).

### Configuration Fields

| Field Name | Type | Description | Allowed Values / Notes |
|------------|------|-------------|------------------------|
| **model_url** | String | Logical model identifier used to build the RFork seed key. | Required for RFork transfer. Instances that should share seeds must use the same value. |
| **model_deploy_strategy_name** | String | Deployment-strategy identifier used together with `model_url` to build the seed key. | Required for RFork transfer. Instances that should share seeds must use the same value. |
| **rfork_scheduler_url** | String | Base URL of the planner service used for seed allocation, release, and heartbeat. | Required for planner-based matching. Example: `http://127.0.0.1:1223`. |
| **rfork_seed_timeout_sec** | Number | Timeout for waiting until the local seed HTTP service becomes healthy after startup. | Optional. Default: `30`. Must be greater than `0`. |
| **rfork_seed_key_separator** | String | Separator used when building the RFork seed key string. | Optional. Default: `$`. Keep the same value across compatible instances. |
### How RFork Matches Seeds

RFork does not match instances by `model_url` alone. The local seed key is composed from:

- `model_url`
- `model_deploy_strategy_name`
- the disaggregation mode, derived from `kv_transfer_config.kv_role` (or `kv_both` when disaggregation is not configured)
- `node_rank`
- `tp_rank`
- an optional `draft` suffix when the worker runs as a draft model

This means two instances must agree on both model identity and deployment topology before the planner will treat them as interchangeable seeds.
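A seed key built from the components listed above can be sketched as follows. This is an illustration only: the field names follow this guide, but the exact ordering and formatting in the real implementation may differ.

```python
# Hypothetical sketch of assembling an RFork seed key from the documented
# components. Ordering is an assumption; the real implementation may differ.
def build_seed_key(
    model_url: str,
    model_deploy_strategy_name: str,
    disaggregation_mode: str,   # kv_transfer_config.kv_role, or "kv_both"
    node_rank: int,
    tp_rank: int,
    is_draft_model: bool = False,
    separator: str = "$",       # rfork_seed_key_separator (default "$")
) -> str:
    parts = [
        model_url,
        model_deploy_strategy_name,
        disaggregation_mode,
        str(node_rank),
        str(tp_rank),
    ]
    if is_draft_model:
        parts.append("draft")   # optional draft suffix
    return separator.join(parts)
```

Two instances produce the same key, and are therefore interchangeable to the planner, only when every component matches.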
---

## Example Commands & Placeholders

> Replace the `<...>` placeholders before running.

### 1. Install YuanRong TransferEngine

```shell
pip install openyuanrong-transfer-engine
```

### 2. Start the Planner

A simple planner implementation is provided at [`rfork_planner.py`](../../../../examples/rfork/rfork_planner.py).

```shell
python rfork_planner.py \
    --host 0.0.0.0 \
    --port <planner_port>
```

### 3. Start vLLM Instances

Use the same RFork startup command for both the first instance and later instances in the same deployment.

For the first instance, the planner usually has no compatible seed yet, so RFork falls back to the default loader. After loading finishes, that instance starts its local seed service and reports itself to the planner.

For later instances, if the planner can allocate a compatible seed, RFork will try to transfer weights from the existing seed instance before falling back to the default loader.
```shell
export RFORK_CONFIG='{
    "model_url": "<model_url>",
    "model_deploy_strategy_name": "<deploy_strategy>",
    "rfork_scheduler_url": "http://<planner_ip>:<planner_port>"
}'

vllm serve <model_path> \
    --tensor-parallel-size 1 \
    --served-model-name <served_model_name> \
    --port <port> \
    --load-format rfork \
    --model-loader-extra-config "${RFORK_CONFIG}"
```
### Placeholder Descriptions

- `<model_path>`: Model path or model identifier passed to `vllm serve`.
- `<served_model_name>`: Service name exposed by vLLM.
- `<planner_ip>`: IP address or hostname of the RFork planner.
- `<planner_port>`: Listening port of the RFork planner.
- `<model_url>`: Stable model-identity string used to build the RFork seed key.
- `<deploy_strategy>`: Stable deployment-strategy name used to build the RFork seed key.
- `<port>`: Serving port of the vLLM instance being started.

---

## Notes & Caveats

- RFork requires `YuanRong TransferEngine` at runtime. If the package is missing, RFork cannot initialize the transfer backend.
- If RFork is used, **each worker process** must bind a listening port. That port is assigned randomly.
- The example [`rfork_planner.py`](../../../../examples/rfork/rfork_planner.py) is only a simple mock implementation. If you need stronger scheduling, capacity management, or production-grade availability behavior, implement your own planner based on the RFork seed protocol.
509
examples/rfork/rfork_planner.py
Normal file
@@ -0,0 +1,509 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
"""
Standalone rfork planner mock server used by vLLM rfork seed protocol tests.

Usage:
    python examples/rfork/rfork_planner.py --host 0.0.0.0 --port 1223
"""

from __future__ import annotations

import argparse
import os
import threading
import time
import uuid
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping
from contextlib import asynccontextmanager
from dataclasses import dataclass

from fastapi import APIRouter, FastAPI, Request, Response, status
@dataclass(frozen=True)
class Settings:
    host: str = "0.0.0.0"
    port: int = 1223
    heartbeat_ttl_sec: int = 60
    heartbeat_sweep_sec: int = 5
    default_resource_points: int = 1
    alloc_policy: str = "fifo"

    def __post_init__(self) -> None:
        if self.port <= 0:
            raise ValueError("port must be > 0")
        if self.heartbeat_ttl_sec <= 0:
            raise ValueError("heartbeat_ttl_sec must be > 0")
        if self.heartbeat_sweep_sec <= 0:
            raise ValueError("heartbeat_sweep_sec must be > 0")
        if self.default_resource_points <= 0:
            raise ValueError("default_resource_points must be > 0")
        if self.alloc_policy not in {"fifo", "lru"}:
            raise ValueError("alloc_policy must be one of: fifo, lru")

    @staticmethod
    def from_env() -> Settings:
        return Settings(
            host=os.getenv("RFORK_MOCK_HOST", "0.0.0.0"),
            port=int(os.getenv("RFORK_MOCK_PORT", "1223")),
            heartbeat_ttl_sec=int(os.getenv("RFORK_MOCK_HEARTBEAT_TTL_SEC", "60")),
            heartbeat_sweep_sec=int(os.getenv("RFORK_MOCK_HEARTBEAT_SWEEP_SEC", "5")),
            default_resource_points=int(os.getenv("RFORK_MOCK_DEFAULT_RESOURCE_POINTS", "1")),
            alloc_policy=os.getenv("RFORK_MOCK_ALLOC_POLICY", "fifo").lower(),
        )
@dataclass
class SeedRecord:
    seed_key: str
    seed_ip: str
    seed_port: int
    seed_rank: int
    last_heartbeat_ts: float
    resource_total: int
    resource_used: int = 0

    @property
    def identity(self) -> str:
        return f"{self.seed_ip}:{self.seed_port}:{self.seed_rank}"

    @property
    def available_points(self) -> int:
        return max(self.resource_total - self.resource_used, 0)


@dataclass
class LeaseRecord:
    user_id: str
    seed_key: str
    seed_identity: str
    allocated_points: int
    leased_at: float
class Scheduler:
    def __init__(self, alloc_policy: str = "fifo") -> None:
        if alloc_policy not in {"fifo", "lru"}:
            raise ValueError(f"unsupported alloc policy: {alloc_policy}")
        self.alloc_policy = alloc_policy

    def choose_seed(self, seeds: Iterable[SeedRecord]) -> SeedRecord | None:
        candidates = [seed for seed in seeds if seed.available_points > 0]
        if not candidates:
            return None

        if self.alloc_policy == "fifo":
            return min(candidates, key=lambda s: (s.last_heartbeat_ts, s.identity))

        return max(candidates, key=lambda s: (s.last_heartbeat_ts, s.identity))
class Store:
    def __init__(
        self,
        *,
        heartbeat_ttl_sec: int,
        default_resource_points: int,
        scheduler: Scheduler,
        time_fn: Callable[[], float] | None = None,
    ) -> None:
        self._lock = threading.RLock()
        self._seeds: dict[str, SeedRecord] = {}
        self._seeds_by_key: dict[str, set[str]] = defaultdict(set)
        self._leases: dict[str, LeaseRecord] = {}
        self._heartbeat_ttl_sec = heartbeat_ttl_sec
        self._default_resource_points = default_resource_points
        self._scheduler = scheduler
        self._time = time_fn or time.time

    @staticmethod
    def _seed_identity(seed_ip: str, seed_port: int, seed_rank: int) -> str:
        return f"{seed_ip}:{seed_port}:{seed_rank}"

    def add_seed(
        self,
        *,
        seed_key: str,
        seed_ip: str,
        seed_port: int,
        seed_rank: int,
        resource_total: int | None = None,
    ) -> SeedRecord:
        identity = self._seed_identity(seed_ip, seed_port, seed_rank)
        now = self._time()
        total = self._default_resource_points if resource_total is None else max(resource_total, 1)

        with self._lock:
            current = self._seeds.get(identity)
            if current is None:
                current = SeedRecord(
                    seed_key=seed_key,
                    seed_ip=seed_ip,
                    seed_port=seed_port,
                    seed_rank=seed_rank,
                    last_heartbeat_ts=now,
                    resource_total=total,
                    resource_used=0,
                )
                self._seeds[identity] = current
                self._seeds_by_key[seed_key].add(identity)
                return current

            if current.seed_key != seed_key:
                self._seeds_by_key[current.seed_key].discard(identity)
                self._seeds_by_key[seed_key].add(identity)
                current.seed_key = seed_key
            current.last_heartbeat_ts = now
            if resource_total is not None:
                current.resource_total = max(resource_total, 1)
                if current.resource_used > current.resource_total:
                    current.resource_used = current.resource_total
            return current

    def get_seed(self, *, seed_key: str) -> tuple[SeedRecord, LeaseRecord] | None:
        with self._lock:
            self.gc_stale_seeds_locked()
            seed_identities = self._seeds_by_key.get(seed_key, set())
            seeds = [self._seeds[sid] for sid in seed_identities if sid in self._seeds]
            selected = self._scheduler.choose_seed(seeds)
            if selected is None:
                return None

            selected.resource_used += 1
            user_id = uuid.uuid4().hex
            lease = LeaseRecord(
                user_id=user_id,
                seed_key=seed_key,
                seed_identity=selected.identity,
                allocated_points=1,
                leased_at=self._time(),
            )
            self._leases[user_id] = lease
            return selected, lease

    def put_seed(self, *, seed_ip: str, seed_port: int, seed_rank: int, user_id: str) -> bool:
        identity = self._seed_identity(seed_ip, seed_port, seed_rank)
        with self._lock:
            lease = self._leases.get(user_id)
            if lease is None:
                return False
            if lease.seed_identity != identity:
                return False

            seed = self._seeds.get(identity)
            if seed is not None:
                seed.resource_used = max(0, seed.resource_used - lease.allocated_points)
            del self._leases[user_id]
            return True

    def gc_stale_seeds(self) -> int:
        with self._lock:
            return self.gc_stale_seeds_locked()

    def gc_stale_seeds_locked(self) -> int:
        now = self._time()
        stale_ids = [
            sid for sid, seed in self._seeds.items() if (now - seed.last_heartbeat_ts) > self._heartbeat_ttl_sec
        ]
        if not stale_ids:
            return 0

        stale_set = set(stale_ids)
        for sid in stale_ids:
            seed = self._seeds.pop(sid)
            self._seeds_by_key[seed.seed_key].discard(sid)
            if not self._seeds_by_key[seed.seed_key]:
                del self._seeds_by_key[seed.seed_key]

        lease_ids = [uid for uid, lease in self._leases.items() if lease.seed_identity in stale_set]
        for uid in lease_ids:
            del self._leases[uid]

        return len(stale_ids)

    def debug_snapshot(self) -> dict[str, object]:
        with self._lock:
            return {
                "seed_count": len(self._seeds),
                "lease_count": len(self._leases),
                "seeds": {
                    sid: {
                        "seed_key": s.seed_key,
                        "resource_total": s.resource_total,
                        "resource_used": s.resource_used,
                        "last_heartbeat_ts": s.last_heartbeat_ts,
                    }
                    for sid, s in self._seeds.items()
                },
                "leases": {
                    uid: {
                        "seed_identity": lease.seed_identity,
                        "seed_key": lease.seed_key,
                        "allocated_points": lease.allocated_points,
                    }
                    for uid, lease in self._leases.items()
                },
            }
class HeartbeatGc:
    def __init__(self, store: Store, sweep_interval_sec: int) -> None:
        self._store = store
        self._sweep_interval_sec = max(sweep_interval_sec, 1)
        self._stop = threading.Event()
        self._thread: threading.Thread | None = None

    def start(self) -> None:
        if self._thread is not None and self._thread.is_alive():
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, name="rfork-heartbeat-gc", daemon=True)
        self._thread.start()

    def stop(self) -> None:
        self._stop.set()
        if self._thread is not None:
            self._thread.join(timeout=2)

    def _run(self) -> None:
        while not self._stop.is_set():
            self._store.gc_stale_seeds()
            time.sleep(self._sweep_interval_sec)
class HeaderError(ValueError):
    pass


@dataclass(frozen=True)
class AddSeedHeaders:
    seed_key: str
    seed_ip: str
    seed_port: int
    seed_rank: int
    seed_refcnt: int


@dataclass(frozen=True)
class GetSeedHeaders:
    seed_key: str


@dataclass(frozen=True)
class PutSeedHeaders:
    seed_ip: str
    seed_port: int
    seed_rank: int
    user_id: str


def _required(headers: Mapping[str, str], key: str) -> str:
    value = headers.get(key)
    if value is None or value == "":
        raise HeaderError(f"missing required header: {key}")
    return value


def _parse_int(value: str, key: str, *, minimum: int = 0) -> int:
    try:
        parsed = int(value)
    except ValueError as exc:
        raise HeaderError(f"invalid integer header {key}: {value}") from exc
    if parsed < minimum:
        raise HeaderError(f"header {key} must be >= {minimum}, got {parsed}")
    return parsed


def parse_add_seed_headers(headers: Mapping[str, str]) -> AddSeedHeaders:
    return AddSeedHeaders(
        seed_key=_required(headers, "SEED_KEY"),
        seed_ip=_required(headers, "SEED_IP"),
        seed_port=_parse_int(_required(headers, "SEED_PORT"), "SEED_PORT", minimum=1),
        seed_rank=_parse_int(_required(headers, "SEED_RANK"), "SEED_RANK", minimum=0),
        seed_refcnt=_parse_int(_required(headers, "SEED_REFCNT"), "SEED_REFCNT", minimum=0),
    )


def parse_get_seed_headers(headers: Mapping[str, str]) -> GetSeedHeaders:
    return GetSeedHeaders(seed_key=_required(headers, "SEED_KEY"))


def parse_put_seed_headers(headers: Mapping[str, str]) -> PutSeedHeaders:
    return PutSeedHeaders(
        seed_ip=_required(headers, "SEED_IP"),
        seed_port=_parse_int(_required(headers, "SEED_PORT"), "SEED_PORT", minimum=1),
        seed_rank=_parse_int(_required(headers, "SEED_RANK"), "SEED_RANK", minimum=0),
        user_id=_required(headers, "USER_ID"),
    )
def build_router(store: Store):
    router = APIRouter()

    @router.post("/add_seed")
    def add_seed(request: Request) -> Response:
        try:
            parsed = parse_add_seed_headers(request.headers)
        except HeaderError as err:
            return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)

        store.add_seed(
            seed_key=parsed.seed_key,
            seed_ip=parsed.seed_ip,
            seed_port=parsed.seed_port,
            seed_rank=parsed.seed_rank,
            # vLLM currently sends SEED_REFCNT=0 as heartbeat metadata.
            # Capacity is controlled by planner config, not by this field.
            resource_total=None,
        )
        return Response(status_code=status.HTTP_200_OK)

    @router.get("/get_seed")
    def get_seed(request: Request) -> Response:
        try:
            parsed = parse_get_seed_headers(request.headers)
        except HeaderError as err:
            return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)

        result = store.get_seed(seed_key=parsed.seed_key)
        if result is None:
            return Response(content="no available seed", status_code=status.HTTP_404_NOT_FOUND)

        seed, lease = result
        response = Response(status_code=status.HTTP_200_OK)
        response.headers["SEED_IP"] = seed.seed_ip
        response.headers["SEED_PORT"] = str(seed.seed_port)
        response.headers["SEED_RANK"] = str(seed.seed_rank)
        response.headers["USER_ID"] = lease.user_id
        return response

    @router.post("/put_seed")
    def put_seed(request: Request) -> Response:
        try:
            parsed = parse_put_seed_headers(request.headers)
        except HeaderError as err:
            return Response(content=str(err), status_code=status.HTTP_400_BAD_REQUEST)

        released = store.put_seed(
            seed_ip=parsed.seed_ip,
            seed_port=parsed.seed_port,
            seed_rank=parsed.seed_rank,
            user_id=parsed.user_id,
        )
        if not released:
            return Response(content="lease not found", status_code=status.HTTP_404_NOT_FOUND)

        return Response(status_code=status.HTTP_200_OK)

    @router.get("/healthz")
    def healthz() -> dict[str, str]:
        return {"status": "ok"}

    @router.get("/debug/snapshot")
    def debug_snapshot() -> dict[str, object]:
        return store.debug_snapshot()

    return router
def create_app(settings: Settings):
    scheduler = Scheduler(settings.alloc_policy)
    store = Store(
        heartbeat_ttl_sec=settings.heartbeat_ttl_sec,
        default_resource_points=settings.default_resource_points,
        scheduler=scheduler,
    )
    gc_runner = HeartbeatGc(store, settings.heartbeat_sweep_sec)

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        gc_runner.start()
        try:
            yield
        finally:
            gc_runner.stop()

    app = FastAPI(title="rfork planner mock", version="0.1.0", lifespan=lifespan)
    app.include_router(build_router(store))

    app.state.settings = settings
    app.state.store = store
    app.state.gc_runner = gc_runner
    return app


def _build_arg_parser() -> argparse.ArgumentParser:
    defaults = Settings.from_env()
    parser = argparse.ArgumentParser(description="Standalone rfork planner server")
    parser.add_argument("--host", default=defaults.host, help="bind host (default: env RFORK_MOCK_HOST or 0.0.0.0)")
    parser.add_argument(
        "--port",
        type=int,
        default=defaults.port,
        help="bind port (default: env RFORK_MOCK_PORT or 1223)",
    )
    parser.add_argument(
        "--heartbeat-ttl-sec",
        type=int,
        default=defaults.heartbeat_ttl_sec,
        help="seed heartbeat ttl in seconds (default: env RFORK_MOCK_HEARTBEAT_TTL_SEC or 60)",
    )
    parser.add_argument(
        "--heartbeat-sweep-sec",
        type=int,
        default=defaults.heartbeat_sweep_sec,
        help="gc sweep interval in seconds (default: env RFORK_MOCK_HEARTBEAT_SWEEP_SEC or 5)",
    )
    parser.add_argument(
        "--default-resource-points",
        type=int,
        default=defaults.default_resource_points,
        help="default seed capacity points (default: env RFORK_MOCK_DEFAULT_RESOURCE_POINTS or 1)",
    )
    parser.add_argument(
        "--alloc-policy",
        choices=["fifo", "lru"],
        default=defaults.alloc_policy,
        help="seed allocation policy (default: env RFORK_MOCK_ALLOC_POLICY or fifo)",
    )
    return parser
def main() -> None:
    parser = _build_arg_parser()
    args = parser.parse_args()
    settings = Settings(
        host=args.host,
        port=args.port,
        heartbeat_ttl_sec=args.heartbeat_ttl_sec,
        heartbeat_sweep_sec=args.heartbeat_sweep_sec,
        default_resource_points=args.default_resource_points,
        alloc_policy=args.alloc_policy,
    )
    app = create_app(settings)
    try:
        import uvicorn
    except ModuleNotFoundError as exc:
        raise SystemExit("missing dependency: uvicorn. Install it with: python -m pip install uvicorn") from exc
    uvicorn.run(app, host=settings.host, port=settings.port)


if __name__ == "__main__":
    main()
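As a quick illustration of the allocation policy above, the following self-contained sketch mirrors the `choose_seed` logic (a simplified stand-in, independent of the planner module): `fifo` picks the candidate with the oldest heartbeat, `lru` the newest, and fully-leased seeds are skipped.

```python
from dataclasses import dataclass

@dataclass
class Seed:
    identity: str
    last_heartbeat_ts: float
    resource_total: int
    resource_used: int = 0

    @property
    def available_points(self) -> int:
        return max(self.resource_total - self.resource_used, 0)

def choose_seed(seeds, policy="fifo"):
    # Only seeds with spare capacity are candidates.
    candidates = [s for s in seeds if s.available_points > 0]
    if not candidates:
        return None
    # "fifo" prefers the oldest heartbeat; "lru" the newest.
    pick = min if policy == "fifo" else max
    return pick(candidates, key=lambda s: (s.last_heartbeat_ts, s.identity))

seeds = [
    Seed("10.0.0.1:9000:0", last_heartbeat_ts=100.0, resource_total=1),
    Seed("10.0.0.2:9000:0", last_heartbeat_ts=200.0, resource_total=1, resource_used=1),
    Seed("10.0.0.3:9000:0", last_heartbeat_ts=300.0, resource_total=1),
]
# The fully-leased seed (10.0.0.2) is skipped.
print(choose_seed(seeds).identity)         # fifo -> 10.0.0.1:9000:0
print(choose_seed(seeds, "lru").identity)  # lru  -> 10.0.0.3:9000:0
```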
@@ -30,8 +30,10 @@ def register_connector():


def register_model_loader():
    from .model_loader.netloader import register_netloader
    from .model_loader.rfork import register_rforkloader

    register_netloader()
    register_rforkloader()


def register_service_profiling():
20
vllm_ascend/model_loader/rfork/__init__.py
Normal file
@@ -0,0 +1,20 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


def register_rforkloader() -> None:
    """Register the RFork model loader plugin."""
    from .rfork_loader import RForkModelLoader  # noqa: F401
188
vllm_ascend/model_loader/rfork/rfork_loader.py
Normal file
@@ -0,0 +1,188 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import gc
import time

import torch
import torch.nn as nn
from torch.nn import Module
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import logger
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype

from vllm_ascend.model_loader.rfork.rfork_worker import RForkWorker
@register_model_loader("rfork")
class RForkModelLoader(BaseModelLoader):
    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        config = load_config.model_loader_extra_config
        if not isinstance(config, dict):
            raise RuntimeError("RFork requires --model-loader-extra-config to be a JSON object.")

        def _get_extra_config(key: str, default: str = "") -> str:
            value = config.get(key)
            return value if isinstance(value, str) and value else default

        def _get_extra_config_float(key: str, default: float) -> float:
            value = config.get(key)
            parsed_value = default
            if isinstance(value, (int, float)):
                parsed_value = float(value)
            elif isinstance(value, str) and value:
                try:
                    parsed_value = float(value)
                except ValueError:
                    return default

            if parsed_value <= 0:
                return default

            return parsed_value

        self.model_url = _get_extra_config("model_url", "")
        self.model_deploy_strategy_name = _get_extra_config("model_deploy_strategy_name", "")
        self.scheduler_url = _get_extra_config("rfork_scheduler_url", "")
        self.seed_timeout_sec = _get_extra_config_float("rfork_seed_timeout_sec", 5.0)
        self.seed_key_separator = _get_extra_config("rfork_seed_key_separator", "$")

        logger.info(
            "Initializing rfork with config: "
            "MODEL_URL=%s, MODEL_DEPLOY_STRATEGY_NAME=%s, "
            "SCHEDULER_URL=%s, SEED_TIMEOUT_SEC=%s, "
            "SEED_KEY_SEPARATOR=%s",
            self.model_url,
            self.model_deploy_strategy_name,
            self.scheduler_url,
            self.seed_timeout_sec,
            self.seed_key_separator,
        )
    def download_model(self, model_config: ModelConfig) -> None:
        raise NotImplementedError

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        raise NotImplementedError

    def _ensure_rfork_worker(self, vllm_config: VllmConfig) -> RForkWorker:
        rfork_worker = getattr(self.load_config, "rfork_worker", None)
        if rfork_worker is None:
            kv_transfer_config = vllm_config.kv_transfer_config
            disaggregation_mode = "kv_both" if kv_transfer_config is None else str(kv_transfer_config.kv_role)
            is_draft_model = (
                getattr(vllm_config.model_config, "runner_type", None) == "draft"
                or getattr(vllm_config.scheduler_config, "runner_type", None) == "draft"
            )
            device_id = torch.distributed.get_rank()
            self.load_config.rfork_worker = RForkWorker(
                disaggregation_mode=disaggregation_mode,
                node_rank=vllm_config.parallel_config.node_rank,
                tp_rank=get_tensor_model_parallel_rank(),
                device_id=device_id,
                scheduler_url=self.scheduler_url,
                model_url=self.model_url,
                model_deploy_strategy_name=self.model_deploy_strategy_name,
                seed_timeout_sec=self.seed_timeout_sec,
                seed_key_separator=self.seed_key_separator,
                is_draft_model=is_draft_model,
            )
            logger.info("RFork worker initialized, load_format=rfork")
            rfork_worker = self.load_config.rfork_worker
        return rfork_worker
    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        prefix: str = "",
    ) -> Module | None:
        device_config = vllm_config.device_config
        load_config = self.load_config
        load_device = device_config.device if load_config.device is None else load_config.device
        target_device = torch.device(load_device)

        with set_default_torch_dtype(model_config.dtype):
            need_del = False
            rfork_worker = self._ensure_rfork_worker(vllm_config)
            try:
                if not rfork_worker.is_seed_available():
                    raise RuntimeError("seed is not available.")

                with target_device:
                    model = initialize_model(
                        vllm_config=vllm_config,
                        model_config=model_config,
                        prefix=prefix,
                    )
                need_del = True

                weight_load_start_time = time.time()
                if not rfork_worker.pre_transfer(model):
                    raise RuntimeError("pre_transfer failed.")
                if not rfork_worker.transfer(model):
                    raise RuntimeError("transfer failed.")
                if not rfork_worker.post_transfer():
                    raise RuntimeError("post_transfer failed.")
                logger.info(
                    "Loading model weights took %.2f seconds",
                    time.time() - weight_load_start_time,
                )

                rfork_worker.start_seed_service(model)
                process_weights_after_loading(model, model_config, target_device)

                return model.eval()
            except Exception as e:
                logger.warning(f"RFork transfer failed: {e}, clean up and fall back to default loader")

                rfork_worker.post_transfer()

                if need_del:
                    del model
                    gc.collect()
                    torch.npu.empty_cache()
                    for _ in range(3):
                        gc.collect()
                        torch.npu.empty_cache()

                self.load_config.load_format = "auto"
                self.load_config.model_loader_extra_config = {}

                from vllm.model_executor.model_loader import get_model

                model = get_model(
                    vllm_config=vllm_config,
                    model_config=model_config,
                    prefix=prefix,
                )
                try:
                    rfork_worker.start_seed_service(model)
                except Exception as e:
                    logger.warning(
|
||||
"Fallback model loaded, but start_seed_service failed: %s",
|
||||
e,
|
||||
)
|
||||
return model
|
||||
119 vllm_ascend/model_loader/rfork/rfork_worker.py Normal file
@@ -0,0 +1,119 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import threading

from vllm.logger import logger

from vllm_ascend.model_loader.rfork.seed_protocol import RForkSeedProtocol
from vllm_ascend.model_loader.rfork.seed_server import start_rfork_server
from vllm_ascend.model_loader.rfork.transfer_backend import (
    RForkTransferBackend,
)


class RForkWorker:
    def __init__(
        self,
        disaggregation_mode: str,
        node_rank: int,
        tp_rank: int,
        device_id: int,
        scheduler_url: str,
        model_url: str,
        model_deploy_strategy_name: str,
        seed_timeout_sec: float = 30.0,
        seed_key_separator: str = "$",
        is_draft_model: bool = False,
    ):
        self.device_id = device_id
        self.rfork_seed = None
        self.transfer_backend = RForkTransferBackend()
        self.ready_to_start_seed_service = False
        self.seed_service_started = False
        self.seed_timeout_sec = seed_timeout_sec
        self.seed_protocol = RForkSeedProtocol(
            disaggregation_mode=disaggregation_mode,
            node_rank=node_rank,
            tp_rank=tp_rank,
            scheduler_url=scheduler_url,
            model_url=model_url,
            model_deploy_strategy_name=model_deploy_strategy_name,
            seed_key_separator=seed_key_separator,
            is_draft_worker=is_draft_model,
        )

    def is_seed_available(self) -> bool:
        self.rfork_seed = self.seed_protocol.get_seed()
        return self.rfork_seed is not None

    def pre_transfer(self, model) -> bool:
        try:
            assert self.transfer_backend.is_initialized(), "transfer_backend is not initialized, cannot pre_transfer."
            result = self.transfer_backend.register_memory_region(model)
            self.ready_to_start_seed_service = result
            return result
        except AssertionError as e:
            logger.exception("Pre-transfer failed: %s", e)
            return False

    def transfer(self, model) -> bool:
        try:
            assert self.transfer_backend.is_initialized(), "transfer_backend is not initialized, cannot transfer."
            assert self.rfork_seed is not None, "rfork seed is None, cannot transfer."
            return self.transfer_backend.recv_from_source(
                model=model,
                seed_instance_ip=self.rfork_seed["seed_ip"],
                seed_instance_service_port=self.rfork_seed["seed_port"],
                local_seed_key=self.seed_protocol.get_local_seed_key(),
            )
        except AssertionError as e:
            logger.exception("Transfer failed: %s", e)
            return False

    def post_transfer(self) -> bool:
        if self.rfork_seed is None:
            logger.info("rfork seed is None, no need to release.")
            return True
        self.seed_protocol.release_seed(self.rfork_seed)
        return True

    def start_seed_service(self, model):
        if self.seed_service_started:
            logger.info("Seed service already started, skipping.")
            return

        if not self.ready_to_start_seed_service:
            if not self.pre_transfer(model):
                return

        port = start_rfork_server(
            self.seed_protocol.get_local_seed_key(),
            (
                self.transfer_backend.rfork_transfer_engine_session_id,
                self.transfer_backend.rfork_transfer_engine_weights_info_dict,
            ),
            health_timeout_sec=self.seed_timeout_sec,
        )
        if port > 0:
            self.rfork_heartbeat_thread = threading.Thread(
                target=self.seed_protocol.report_seed,
                args=(port,),
                daemon=True,
                name="RForkHeartbeat",
            )
            self.rfork_heartbeat_thread.start()
            self.seed_service_started = True
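Once a worker becomes a seed, `start_seed_service` spawns a daemon heartbeat thread that periodically re-registers the seed with the planner. An observable sketch of that pattern (the HTTP call to the planner is replaced here with an in-process list append, purely for illustration):

```python
import threading
import time

# Illustrative stand-in for the heartbeat loop started by start_seed_service.
# The real loop runs forever and POSTs to the planner's /add_seed endpoint;
# here it runs a bounded number of beats so the behavior is checkable.
reports = []


def report_seed(port, sleep_interval=0.01, max_beats=3):
    for _ in range(max_beats):
        reports.append(port)  # stands in for the planner registration call
        time.sleep(sleep_interval)


t = threading.Thread(
    target=report_seed,
    args=(4321,),
    daemon=True,          # daemon=True: the heartbeat never blocks shutdown
    name="RForkHeartbeat",
)
t.start()
t.join(timeout=2.0)
```

Making the thread a daemon matters: the heartbeat must not keep the serving process alive after the engine exits.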
208 vllm_ascend/model_loader/rfork/seed_protocol.py Normal file
@@ -0,0 +1,208 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
from urllib.error import HTTPError

import requests
from vllm.logger import logger
from vllm.utils.network_utils import get_ip

REQUEST_TIMEOUT_SEC = 10.0
HEARTBEAT_LOG_EVERY_N = 4


def get_local_seed_key(
    disaggregation_mode: str,
    node_rank: int,
    tp_rank: int,
    model_url: str,
    model_deploy_strategy_name: str,
    seed_key_separator: str = "$",
    is_draft_worker: bool = False,
) -> str:
    if not model_url or not model_deploy_strategy_name:
        raise RuntimeError(
            "RFork seed key is not set. Ensure model_loader_extra_config contains "
            "`model_url` and `model_deploy_strategy_name`."
        )

    seed_key = f"{model_url}{seed_key_separator}{model_deploy_strategy_name}"
    key_suffix = f"{disaggregation_mode}{seed_key_separator}{node_rank}{seed_key_separator}{tp_rank}"
    if is_draft_worker:
        key_suffix += f"{seed_key_separator}draft"
    return f"{seed_key}{seed_key_separator}{key_suffix}"


class RForkSeedProtocol:
    def __init__(
        self,
        *,
        disaggregation_mode: str,
        node_rank: int,
        tp_rank: int,
        scheduler_url: str,
        model_url: str,
        model_deploy_strategy_name: str,
        seed_key_separator: str = "$",
        is_draft_worker: bool = False,
    ):
        self.disaggregation_mode = disaggregation_mode
        self.node_rank = node_rank
        self.tp_rank = tp_rank
        self.scheduler_url = scheduler_url
        self.model_url = model_url
        self.model_deploy_strategy_name = model_deploy_strategy_name
        self.seed_key_separator = seed_key_separator
        self.is_draft_worker = is_draft_worker

        self._local_seed_key = get_local_seed_key(
            disaggregation_mode=self.disaggregation_mode,
            node_rank=self.node_rank,
            tp_rank=self.tp_rank,
            model_url=self.model_url,
            model_deploy_strategy_name=self.model_deploy_strategy_name,
            seed_key_separator=self.seed_key_separator,
            is_draft_worker=self.is_draft_worker,
        )

    def get_local_seed_key(self) -> str:
        return self._local_seed_key

    @staticmethod
    def _request_timeout_sec() -> float:
        return REQUEST_TIMEOUT_SEC

    def _ensure_scheduler_url_set(self) -> None:
        if not self.scheduler_url:
            raise RuntimeError("rfork_scheduler_url is not set. Cannot interact with the scheduler.")

    def get_seed(self):
        try:
            self._ensure_scheduler_url_set()
            response = requests.get(
                f"{self.scheduler_url}/get_seed",
                headers={
                    "SEED_KEY": self.get_local_seed_key(),
                },
                timeout=self._request_timeout_sec(),
            )
            if response.status_code != 200:
                raise RuntimeError(f"Failed to get seed from the planner, {response.status_code}")

            seed_ip = response.headers.get("SEED_IP")
            seed_port = response.headers.get("SEED_PORT")
            user_id = response.headers.get("USER_ID")
            seed_rank = response.headers.get("SEED_RANK")
            logger.debug(
                "seed_ip: %s, seed_port: %s, user_id: %s, seed_rank: %s",
                seed_ip,
                seed_port,
                user_id,
                seed_rank,
            )
            return {
                "seed_ip": seed_ip,
                "seed_port": seed_port,
                "user_id": user_id,
                "seed_rank": seed_rank,
            }

        except RuntimeError as e:
            logger.warning("get_seed from planner RuntimeError: %s", e)
            return None
        except HTTPError as e:
            logger.exception("get_seed from planner HTTPError: %s", e)
            return None
        except Exception as e:
            logger.exception("get_seed from planner Exception: %s", e)
            return None

    def release_seed(self, seed) -> bool:
        try:
            self._ensure_scheduler_url_set()
            user_id = seed["user_id"]
            seed_ip = seed["seed_ip"]
            seed_port = str(seed["seed_port"])
            seed_rank = str(seed["seed_rank"])

            response = requests.post(
                f"{self.scheduler_url}/put_seed",
                headers={
                    "SEED_IP": seed_ip,
                    "SEED_PORT": seed_port,
                    "USER_ID": user_id,
                    "SEED_RANK": seed_rank,
                },
                timeout=self._request_timeout_sec(),
            )

            if response.status_code != 200:
                raise RuntimeError(f"Failed to release seed to the planner, {response.status_code}")
            return True
        except RuntimeError as e:
            logger.exception("release_seed to planner RuntimeError: %s", e)
            return False
        except HTTPError as e:
            logger.exception("release_seed to planner HTTPError: %s", e)
            return False
        except Exception as e:
            logger.exception("release_seed to planner Exception: %s", e)
            return False

    def report_seed(self, port: int, sleep_interval: int = 30):
        heartbeat_idx = 0
        log_every_n = HEARTBEAT_LOG_EVERY_N
        try:
            self._ensure_scheduler_url_set()
            seed_ip = get_ip()
            seed_key = self.get_local_seed_key()
        except Exception as e:
            logger.exception("report_seed setup Exception: %s", e)
            return

        while True:
            heartbeat_idx += 1
            result = False
            try:
                response = requests.post(
                    f"{self.scheduler_url}/add_seed",
                    headers={
                        "SEED_KEY": seed_key,
                        "SEED_IP": seed_ip,
                        "SEED_PORT": str(port),
                        "SEED_RANK": str(self.tp_rank),
                        "SEED_REFCNT": str(0),
                    },
                    timeout=self._request_timeout_sec(),
                )
                if response.status_code == 200:
                    result = True
            except HTTPError as e:
                logger.exception("report_seed to planner HTTPError: %s", e)
            except Exception as e:
                logger.exception("report_seed to planner Exception: %s", e)

            # Keep heartbeat frequency unchanged, but reduce log noise:
            # always print failures immediately; print success once every N beats.
            if (not result) or (heartbeat_idx % log_every_n == 0):
                logger.info(
                    "[rfork_heartbeat] report seed to planner result: %s (%d/%d)",
                    result,
                    heartbeat_idx % log_every_n if heartbeat_idx % log_every_n != 0 else log_every_n,
                    log_every_n,
                )
            time.sleep(sleep_interval)
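The seed key produced by `get_local_seed_key` encodes model identity first and deployment topology second, so only instances serving the same model with the same layout can fork from each other. A standalone reproduction of that layout (the values below are example inputs, not defaults):

```python
# Standalone copy of the seed-key composition in get_local_seed_key:
# "<model>$<strategy>$<mode>$<node_rank>$<tp_rank>[$draft]".
def build_seed_key(model_url, strategy, mode, node_rank, tp_rank,
                   sep="$", is_draft=False):
    seed_key = f"{model_url}{sep}{strategy}"
    suffix = f"{mode}{sep}{node_rank}{sep}{tp_rank}"
    if is_draft:
        suffix += f"{sep}draft"  # draft workers get a distinct key space
    return f"{seed_key}{sep}{suffix}"


key = build_seed_key("/workspace/models/Qwen3-8B", "tp2", "kv_both", 0, 1)
draft_key = build_seed_key("/workspace/models/Qwen3-8B", "tp2", "kv_both", 0, 1,
                           is_draft=True)
```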
126 vllm_ascend/model_loader/rfork/seed_server.py Normal file
@@ -0,0 +1,126 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import queue
import socket
import threading
import time
from http import HTTPStatus

import requests
import uvicorn
from fastapi import FastAPI
from fastapi.responses import Response
from vllm.logger import logger


def start_fastapi_server(
    port_queue: queue.Queue[int],
    local_seed_key,
    info,
):
    logger.info("[RFork Seed] Preparing socket with dynamic port...")

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("0.0.0.0", 0))
    _, port = sock.getsockname()
    logger.info("[RFork Seed] Assigned dynamic port: %s", port)

    app = FastAPI()

    @app.get("/get_rfork_transfer_engine_info")
    def get_rfork_transfer_engine_info(seed_key: str):
        if seed_key == local_seed_key:
            return {"rfork_transfer_engine_info": info}
        return {"rfork_transfer_engine_info": None}

    @app.get("/rfork_fetch_seed")
    def rfork_fetch_seed():
        return {"status": "ok"}

    @app.get("/health_check_with_key")
    def health_check_with_key(seed_key: str):
        if seed_key == local_seed_key:
            return Response(status_code=HTTPStatus.OK)
        return Response(status_code=HTTPStatus.BAD_REQUEST)

    config = uvicorn.Config(app, host=None, port=None, log_level="warning")
    server = uvicorn.Server(config)

    try:
        port_queue.put(port)
    except Exception as e:
        logger.error("[RFork Seed] Failed to send port via queue: %s", e)
        sock.close()
        return

    logger.info("[RFork Seed] FastAPI server starting on port %s...", port)
    server.run(sockets=[sock])
    sock.close()


def start_rfork_server(local_seed_key, rfork_transfer_engine_info, health_timeout_sec: float = 30.0) -> int:
    port_queue: queue.Queue[int] = queue.Queue()
    server_thread = threading.Thread(
        target=start_fastapi_server,
        args=(port_queue, local_seed_key, rfork_transfer_engine_info),
        daemon=True,
    )
    server_thread.start()

    try:
        port = port_queue.get(timeout=15)
        if port == -1:
            raise RuntimeError("Server thread failed to start")
    except Exception as e:
        logger.error("[RFork Seed] start server error: %s", e)
        return -1

    deadline = time.time() + health_timeout_sec
    healthy = False
    retry_count = 0
    last_error = None
    while time.time() < deadline:
        time.sleep(0.01)
        url = f"http://127.0.0.1:{port}/health_check_with_key"
        try:
            response = requests.get(
                url,
                params={"seed_key": local_seed_key},
                timeout=10,
            )
            if response.status_code == 200:
                healthy = True
                break
            last_error = f"unexpected status code {response.status_code} from health check"
        except Exception as e:
            last_error = str(e)
        retry_count += 1
    if healthy:
        if retry_count > 1:
            logger.info(
                "[RFork Seed] health check passed after %d retries for port %s",
                retry_count - 1,
                port,
            )
        return port
    logger.error(
        "[RFork Seed] health check timed out after %.1fs for port %s, last error: %s",
        health_timeout_sec,
        port,
        last_error,
    )
    return -1
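The seed server avoids port collisions by binding to port 0 (letting the kernel pick an ephemeral port) in the server thread and handing the chosen port back to the caller through a queue before serving begins. That handoff in isolation, with the actual uvicorn server omitted:

```python
import queue
import socket
import threading


# Sketch of the dynamic-port handoff in start_fastapi_server/start_rfork_server:
# the thread binds to port 0, reads back the kernel-assigned port, and
# publishes it via a queue so the caller can unblock with a known port.
def server_thread(port_queue):
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("127.0.0.1", 0))       # port 0 => kernel assigns a free port
    _, port = sock.getsockname()
    port_queue.put(port)              # caller's blocking get() returns here
    sock.close()                      # the real code serves on this socket instead


q: queue.Queue = queue.Queue()
threading.Thread(target=server_thread, args=(q,), daemon=True).start()
port = q.get(timeout=5)               # known-good port, no race on binding
```

Passing the already-bound socket to the server (as the real code does with `server.run(sockets=[sock])`) avoids the race between reading the port and binding it.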
212 vllm_ascend/model_loader/rfork/transfer_backend.py Normal file
@@ -0,0 +1,212 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
from typing import Any

import requests
import torch
from vllm.logger import logger
from vllm.utils.network_utils import get_ip, get_open_port, join_host_port


class RForkTransferBackend:
    def __init__(self):
        self.rfork_transfer_engine: Any | None = None
        self.rfork_transfer_engine_session_id = None
        self.rfork_transfer_engine_weights_info_dict = None
        self.registered_weight_blocks = []
        self._is_initialized = False
        self.init_transfer_engine()

    def init_transfer_engine(self):
        try:
            from yr.datasystem import TransferEngine  # type: ignore[import-not-found]
        except ImportError as e:
            raise ImportError("Please install @yuanrong-datasystem/transfer_engine first.") from e

        transfer_engine = TransferEngine()
        local_hostname = join_host_port(get_ip(), get_open_port())
        ret = transfer_engine.initialize(local_hostname, "ascend", f"npu:{torch.npu.current_device()}")
        if ret.is_error():
            raise RuntimeError(
                "TransferEngine initialization failed: "
                f"initialize({local_hostname}, ascend, "
                f"npu:{int(torch.npu.current_device())}) -> {ret.to_string()}"
            )

        self.rfork_transfer_engine = transfer_engine
        self.rfork_transfer_engine_session_id = local_hostname
        self._is_initialized = True

    def is_initialized(self) -> bool:
        return self._is_initialized

    def _get_transfer_engine(self) -> Any:
        if self.rfork_transfer_engine is None:
            raise RuntimeError("TransferEngine is not initialized.")
        return self.rfork_transfer_engine

    def register_memory_region(self, model):
        transfer_engine = self._get_transfer_engine()
        start_reg_mr_tic = time.time()

        weight_mr_dict = {}
        weight_addr_set = set()
        for name, weight in model.named_parameters():
            weight_mr_dict[name] = (
                weight.data_ptr(),
                weight.numel(),
                weight.element_size(),
            )
            weight_addr_set.add(weight.data_ptr())

        # Walk the allocator snapshot and coalesce physically adjacent weight
        # blocks, so the transfer engine registers a few large regions instead
        # of one region per tensor.
        memory_snapshot = torch.npu.memory.memory_snapshot()
        weight_blocks_for_reg_mr = []
        for segment in memory_snapshot:
            current_weight_block = None
            for block in segment.get("blocks", []):
                address = block.get("address", -1)
                size = block.get("size", -1)
                state = block.get("state", "")
                if address < 0 or size < 0 or state == "":
                    continue
                if state == "active_allocated" and address in weight_addr_set:
                    if current_weight_block is None:
                        current_weight_block = (address, size)
                    elif current_weight_block[0] + current_weight_block[1] == address:
                        current_weight_block = (
                            current_weight_block[0],
                            current_weight_block[1] + size,
                        )
                    else:
                        weight_blocks_for_reg_mr.append(current_weight_block)
                        current_weight_block = (address, size)
            if current_weight_block is not None:
                weight_blocks_for_reg_mr.append(current_weight_block)

        addresses, sizes = zip(*weight_blocks_for_reg_mr) if weight_blocks_for_reg_mr else ((), ())
        ret = transfer_engine.batch_register_memory(addresses, sizes)
        if ret.is_error():
            logger.error(
                "batch_register_memory failed for %d blocks, ret: %s",
                len(weight_blocks_for_reg_mr),
                ret.to_string(),
            )
            return False

        self.rfork_transfer_engine_weights_info_dict = weight_mr_dict
        self.registered_weight_blocks = weight_blocks_for_reg_mr

        logger.info(
            "register_memory_region time: %.4fs",
            time.time() - start_reg_mr_tic,
        )
        return True

    def unregister_memory_region(self) -> bool:
        transfer_engine = self._get_transfer_engine()
        start_unreg_mr_tic = time.time()
        ret = transfer_engine.batch_unregister_memory([address for address, _ in self.registered_weight_blocks])
        if ret.is_error():
            logger.error(
                "batch_unregister_memory failed for %d blocks, ret: %s",
                len(self.registered_weight_blocks),
                ret.to_string(),
            )
            return False
        self.rfork_transfer_engine_weights_info_dict = None
        self.registered_weight_blocks = []
        logger.info(
            "unregister_memory_region time: %.4fs",
            time.time() - start_unreg_mr_tic,
        )
        return True

    def recv_from_source(
        self,
        model,
        seed_instance_ip,
        seed_instance_service_port,
        local_seed_key,
    ):
        transfer_engine = self._get_transfer_engine()
        seed_url = f"http://{seed_instance_ip}:{seed_instance_service_port}"
        seed_session_id, seed_weight_info = get_remote_instance_transfer_engine_info(seed_url, local_seed_key)
        if seed_session_id is None or seed_weight_info is None:
            logger.error("Cannot get transfer engine session or weight info.")
            return False

        seed_ptr_list = []
        client_ptr_list = []
        client_len_list = []
        for name, tensor in model.named_parameters():
            weight_info = seed_weight_info.get(name, None)
            if weight_info is None:
                logger.error("Cannot find weight info for %s.", name)
                return False

            seed_ptr, seed_len, seed_size = weight_info
            if seed_len != tensor.numel() or seed_size != tensor.element_size():
                logger.error(
                    "Weight info mismatch for %s, expected (%s, %s), got (%s, %s)",
                    name,
                    seed_len,
                    seed_size,
                    tensor.numel(),
                    tensor.element_size(),
                )
                return False

            seed_ptr_list.append(seed_ptr)
            client_ptr_list.append(tensor.data_ptr())
            client_len_list.append(tensor.numel() * tensor.element_size())

        start_transfer_tic = time.time()
        ret = transfer_engine.batch_transfer_sync_read(
            seed_session_id,
            client_ptr_list,
            seed_ptr_list,
            client_len_list,
        )
        if ret.is_error():
            logger.error("Failed to transfer weights from remote instance, ret=%s", ret.to_string())
            return False

        logger.info("transfer weights time: %.4fs", time.time() - start_transfer_tic)
        return True


def get_remote_instance_transfer_engine_info(seed_url: str, local_seed_key: str):
    try:
        response = requests.get(
            f"{seed_url}/get_rfork_transfer_engine_info",
            params={"seed_key": local_seed_key},
            timeout=10.0,
        )
        if response.status_code != 200:
            logger.error("requests.get failed: %s", response.status_code)
            return None, None

        data = response.json()
        info = data.get("rfork_transfer_engine_info", None)
        if info is not None and isinstance(info, list) and len(info) == 2:
            return info[0], info[1]

        logger.error("Failed to get `rfork_transfer_engine_info` in response.")
        return None, None
    except Exception as e:
        logger.error("Exception: %s", e)
        return None, None
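The core of `register_memory_region` is the merge step: weight tensors that sit back-to-back in the allocator are coalesced into one registration. The merge rule in isolation, on synthetic `(address, size)` pairs:

```python
# Standalone copy of the block-coalescing rule from register_memory_region:
# two blocks merge when the first block ends exactly where the second begins.
def coalesce(blocks):
    merged = []
    current = None
    for address, size in blocks:
        if current is None:
            current = (address, size)
        elif current[0] + current[1] == address:   # directly adjacent
            current = (current[0], current[1] + size)
        else:                                      # gap: flush and restart
            merged.append(current)
            current = (address, size)
    if current is not None:
        merged.append(current)
    return merged


# Two adjacent blocks merge; the block after the gap stays separate.
result = coalesce([(0, 100), (100, 50), (400, 8)])
```

Fewer, larger registered regions keep the registration call count (and per-region bookkeeping in the transfer engine) proportional to memory segments rather than to parameter count.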