[ModelLoader][Feature] Add rfork support for fast model loading (#7392)
### What this PR does / why we need it?
Support a new load format: RFORK
For implementation details of this feature, please refer to #7441
### Does this PR introduce _any_ user-facing change?
adds a new option for load-format: rfork
e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
Signed-off-by: Marck <1412354149@qq.com>
This commit is contained in:
@@ -30,8 +30,10 @@ def register_connector():
|
||||
|
||||
def register_model_loader():
    """Register vLLM-Ascend's custom model loaders with vLLM.

    Imports are kept inside the function so the loader modules (and their
    registration side effects) are only pulled in when registration runs.
    """
    # Each loader is registered immediately after its import; the order
    # (netloader first, then rfork) matches the previous behavior.
    from .model_loader.netloader import register_netloader

    register_netloader()

    from .model_loader.rfork import register_rforkloader

    register_rforkloader()
|
||||
|
||||
def register_service_profiling():
|
||||
|
||||
20
vllm_ascend/model_loader/rfork/__init__.py
Normal file
20
vllm_ascend/model_loader/rfork/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
|
||||
def register_rforkloader() -> None:
    """Register the RFork model loader plugin."""
    # Importing the module is sufficient: the @register_model_loader("rfork")
    # decorator on RForkModelLoader runs at import time.
    from . import rfork_loader  # noqa: F401
|
||||
188
vllm_ascend/model_loader/rfork/rfork_loader.py
Normal file
188
vllm_ascend/model_loader/rfork/rfork_loader.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import Module
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.model_loader import register_model_loader
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
from vllm_ascend.model_loader.rfork.rfork_worker import RForkWorker
|
||||
|
||||
|
||||
@register_model_loader("rfork")
class RForkModelLoader(BaseModelLoader):
    """Model loader for ``--load-format rfork``.

    Instead of reading weights from disk, it asks an rfork scheduler for a
    running "seed" instance and copies the weights directly from that
    instance's device memory via the RFork transfer engine. On any failure it
    releases the seed, frees NPU memory and falls back to the default
    ("auto") loader so serving can still start.
    """

    def __init__(self, load_config: LoadConfig):
        """Parse rfork settings out of ``--model-loader-extra-config``.

        Recognized keys: ``model_url``, ``model_deploy_strategy_name``,
        ``rfork_scheduler_url``, ``rfork_seed_timeout_sec``,
        ``rfork_seed_key_separator``.

        Raises:
            RuntimeError: if ``model_loader_extra_config`` is not a dict.
        """
        super().__init__(load_config)
        config = load_config.model_loader_extra_config
        if not isinstance(config, dict):
            raise RuntimeError("RFork requires --model-loader-extra-config to be a JSON object.")

        def _get_extra_config(key: str, default: str = "") -> str:
            # Only non-empty strings are accepted; anything else -> default.
            value = config.get(key)
            return value if isinstance(value, str) and value else default

        def _get_extra_config_float(key: str, default: float) -> float:
            # Accepts int/float or a numeric string; unparsable or
            # non-positive values fall back to the default.
            value = config.get(key)
            parsed_value = default
            if isinstance(value, (int, float)):
                parsed_value = float(value)
            elif isinstance(value, str) and value:
                try:
                    parsed_value = float(value)
                except ValueError:
                    return default

            if parsed_value <= 0:
                return default

            return parsed_value

        self.model_url = _get_extra_config("model_url", "")
        self.model_deploy_strategy_name = _get_extra_config("model_deploy_strategy_name", "")
        self.scheduler_url = _get_extra_config("rfork_scheduler_url", "")
        self.seed_timeout_sec = _get_extra_config_float("rfork_seed_timeout_sec", 5.0)
        self.seed_key_separator = _get_extra_config("rfork_seed_key_separator", "$")

        logger.info(
            "Initializing rfork with config: "
            "MODEL_URL=%s, MODEL_DEPLOY_STRATEGY_NAME=%s, "
            "SCHEDULER_URL=%s, SEED_TIMEOUT_SEC=%s, "
            "SEED_KEY_SEPARATOR=%s",
            self.model_url,
            self.model_deploy_strategy_name,
            self.scheduler_url,
            self.seed_timeout_sec,
            self.seed_key_separator,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        # rfork copies weights over the network from a seed instance; there
        # is nothing to download to disk.
        raise NotImplementedError

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        # Weight loading is handled as part of load_model() via the worker.
        raise NotImplementedError

    def _ensure_rfork_worker(self, vllm_config: VllmConfig) -> RForkWorker:
        """Lazily create the RForkWorker and cache it on ``load_config``.

        Caching on load_config lets repeated load_model() calls in the same
        process (e.g. target + draft model) share a single worker.
        """
        rfork_worker = getattr(self.load_config, "rfork_worker", None)
        if rfork_worker is None:
            kv_transfer_config = vllm_config.kv_transfer_config
            # Without P/D disaggregation the instance plays both roles
            # ("kv_both"); otherwise use the configured kv role.
            disaggregation_mode = "kv_both" if kv_transfer_config is None else str(kv_transfer_config.kv_role)
            # Checked on both configs — presumably only one of them carries
            # runner_type depending on the vLLM version; verify against caller.
            is_draft_model = (
                getattr(vllm_config.model_config, "runner_type", None) == "draft"
                or getattr(vllm_config.scheduler_config, "runner_type", None) == "draft"
            )
            device_id = torch.distributed.get_rank()
            self.load_config.rfork_worker = RForkWorker(
                disaggregation_mode=disaggregation_mode,
                node_rank=vllm_config.parallel_config.node_rank,
                tp_rank=get_tensor_model_parallel_rank(),
                device_id=device_id,
                scheduler_url=self.scheduler_url,
                model_url=self.model_url,
                model_deploy_strategy_name=self.model_deploy_strategy_name,
                seed_timeout_sec=self.seed_timeout_sec,
                seed_key_separator=self.seed_key_separator,
                is_draft_model=is_draft_model,
            )
            logger.info("RFork worker initialized, load_format=rfork")
            rfork_worker = self.load_config.rfork_worker
        return rfork_worker

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        prefix: str = "",
    ) -> Module | None:
        """Load the model via rfork; fall back to the default loader on error.

        Flow: acquire seed -> initialize empty model on the target device ->
        register memory / transfer weights / release seed -> start the local
        seed service -> post-process weights. Any exception triggers cleanup
        and a retry with load_format="auto".
        """
        device_config = vllm_config.device_config
        load_config = self.load_config
        load_device = device_config.device if load_config.device is None else load_config.device
        target_device = torch.device(load_device)

        with set_default_torch_dtype(model_config.dtype):
            # Tracks whether `model` was allocated, so the error path knows
            # whether there is NPU memory to free.
            need_del = False
            rfork_worker = self._ensure_rfork_worker(vllm_config)
            try:
                if not rfork_worker.is_seed_available():
                    raise RuntimeError("seed is not available.")

                with target_device:
                    model = initialize_model(
                        vllm_config=vllm_config,
                        model_config=model_config,
                        prefix=prefix,
                    )
                need_del = True

                weight_load_start_time = time.time()
                if not rfork_worker.pre_transfer(model):
                    raise RuntimeError("pre_transfer failed.")
                if not rfork_worker.transfer(model):
                    raise RuntimeError("transfer failed.")
                # Releases the seed back to the scheduler.
                if not rfork_worker.post_transfer():
                    raise RuntimeError("post_transfer failed.")
                logger.info(
                    "Loading model weights took %.2f seconds",
                    time.time() - weight_load_start_time,
                )

                # This rank can now act as a seed for future rfork clients.
                rfork_worker.start_seed_service(model)
                process_weights_after_loading(model, model_config, target_device)

                return model.eval()
            except Exception as e:
                logger.warning(f"RFork transfer failed: {e}, clean up and fall back to default loader")

                # Best-effort seed release (no-op when no seed is held).
                rfork_worker.post_transfer()

                if need_del:
                    del model
                    gc.collect()
                    torch.npu.empty_cache()
                # Repeat GC + cache purge to make sure the partially built
                # model's NPU memory is actually returned before reloading.
                for _ in range(3):
                    gc.collect()
                    torch.npu.empty_cache()

                # Reset the (shared) load config so get_model() below uses
                # the default loader instead of recursing into rfork.
                self.load_config.load_format = "auto"
                self.load_config.model_loader_extra_config = {}

                from vllm.model_executor.model_loader import get_model

                model = get_model(
                    vllm_config=vllm_config,
                    model_config=model_config,
                    prefix=prefix,
                )
                # Even after a fallback load, try to serve weights to others.
                try:
                    rfork_worker.start_seed_service(model)
                except Exception as e:
                    logger.warning(
                        "Fallback model loaded, but start_seed_service failed: %s",
                        e,
                    )
                return model
|
||||
119
vllm_ascend/model_loader/rfork/rfork_worker.py
Normal file
119
vllm_ascend/model_loader/rfork/rfork_worker.py
Normal file
@@ -0,0 +1,119 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import threading
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
from vllm_ascend.model_loader.rfork.seed_protocol import RForkSeedProtocol
|
||||
from vllm_ascend.model_loader.rfork.seed_server import start_rfork_server
|
||||
from vllm_ascend.model_loader.rfork.transfer_backend import (
|
||||
RForkTransferBackend,
|
||||
)
|
||||
|
||||
|
||||
class RForkWorker:
    """Per-rank worker that coordinates seed discovery, weight transfer and
    the local seed (weight-sharing) service for the rfork loader.
    """

    def __init__(
        self,
        disaggregation_mode: str,
        node_rank: int,
        tp_rank: int,
        device_id: int,
        scheduler_url: str,
        model_url: str,
        model_deploy_strategy_name: str,
        seed_timeout_sec: float = 30.0,
        seed_key_separator: str = "$",
        is_draft_model: bool = False,
    ):
        self.device_id = device_id
        # Seed metadata dict returned by the scheduler; None until acquired.
        self.rfork_seed = None
        self.transfer_backend = RForkTransferBackend()
        # True once the model's weight memory is registered with the backend.
        self.ready_to_start_seed_service = False
        self.seed_service_started = False
        self.seed_timeout_sec = seed_timeout_sec
        self.seed_protocol = RForkSeedProtocol(
            disaggregation_mode=disaggregation_mode,
            node_rank=node_rank,
            tp_rank=tp_rank,
            scheduler_url=scheduler_url,
            model_url=model_url,
            model_deploy_strategy_name=model_deploy_strategy_name,
            seed_key_separator=seed_key_separator,
            is_draft_worker=is_draft_model,
        )

    def is_seed_available(self) -> bool:
        """Ask the scheduler for a seed instance and cache the result."""
        self.rfork_seed = self.seed_protocol.get_seed()
        return self.rfork_seed is not None

    def pre_transfer(self, model) -> bool:
        """Register the model's weight memory with the transfer backend.

        Uses explicit checks instead of ``assert`` so the validation is not
        stripped under ``python -O``.

        Returns:
            True on success, False otherwise (already logged).
        """
        if not self.transfer_backend.is_initialized():
            logger.error("transfer_backend is not initialized, cannot pre_transfer.")
            return False
        result = self.transfer_backend.register_memory_region(model)
        self.ready_to_start_seed_service = result
        return result

    def transfer(self, model) -> bool:
        """Pull weights from the acquired seed instance into ``model``.

        Returns:
            True on success, False otherwise (already logged).
        """
        if not self.transfer_backend.is_initialized():
            logger.error("transfer_backend is not initialized, cannot transfer.")
            return False
        if self.rfork_seed is None:
            logger.error("rfork seed is None, cannot transfer.")
            return False
        return self.transfer_backend.recv_from_source(
            model=model,
            seed_instance_ip=self.rfork_seed["seed_ip"],
            seed_instance_service_port=self.rfork_seed["seed_port"],
            local_seed_key=self.seed_protocol.get_local_seed_key(),
        )

    def post_transfer(self):
        """Release the acquired seed back to the scheduler.

        Idempotent: the loader's error path may call this a second time after
        a successful release, so the cached seed is cleared to prevent the
        same seed from being released twice.
        """
        if self.rfork_seed is None:
            logger.info("rfork seed is None, no need to release.")
            return True
        self.seed_protocol.release_seed(self.rfork_seed)
        # Drop the cached seed so repeated calls become no-ops.
        self.rfork_seed = None
        return True

    def start_seed_service(self, model):
        """Start the local seed HTTP service and the heartbeat thread so this
        rank can serve its weights to future rfork clients.

        Safe to call multiple times; only the first successful call has an
        effect.
        """
        if self.seed_service_started:
            logger.info("Seed service already started, skipping.")
            return

        # The weights must be registered with the transfer backend before we
        # can advertise them; register lazily when loading took another path.
        if not self.ready_to_start_seed_service:
            if not self.pre_transfer(model):
                return

        port = start_rfork_server(
            self.seed_protocol.get_local_seed_key(),
            (
                self.transfer_backend.rfork_transfer_engine_session_id,
                self.transfer_backend.rfork_transfer_engine_weights_info_dict,
            ),
            health_timeout_sec=self.seed_timeout_sec,
        )
        if port > 0:
            # Daemon thread: must not keep the process alive at shutdown.
            self.rfork_heartbeat_thread = threading.Thread(
                target=self.seed_protocol.report_seed,
                args=(port,),
                daemon=True,
                name="RForkHeartbeat",
            )
            self.rfork_heartbeat_thread.start()
            self.seed_service_started = True
|
||||
208
vllm_ascend/model_loader/rfork/seed_protocol.py
Normal file
208
vllm_ascend/model_loader/rfork/seed_protocol.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import time
|
||||
from urllib.error import HTTPError
|
||||
|
||||
import requests
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
REQUEST_TIMEOUT_SEC = 10.0
|
||||
HEARTBEAT_LOG_EVERY_N = 4
|
||||
|
||||
|
||||
def get_local_seed_key(
    disaggregation_mode: str,
    node_rank: int,
    tp_rank: int,
    model_url: str,
    model_deploy_strategy_name: str,
    seed_key_separator: str = "$",
    is_draft_worker: bool = False,
) -> str:
    """Build the separator-joined seed key identifying this rank.

    Layout: model_url, strategy name, disaggregation mode, node rank,
    tp rank, and a trailing "draft" marker for draft workers.

    Raises:
        RuntimeError: if model_url or model_deploy_strategy_name is empty.
    """
    if not model_url or not model_deploy_strategy_name:
        raise RuntimeError(
            "RFork seed key is not set. Ensure model_loader_extra_config contains "
            "`model_url` and `model_deploy_strategy_name`."
        )

    key_parts = [
        model_url,
        model_deploy_strategy_name,
        disaggregation_mode,
        str(node_rank),
        str(tp_rank),
    ]
    if is_draft_worker:
        key_parts.append("draft")
    return seed_key_separator.join(key_parts)
|
||||
|
||||
|
||||
class RForkSeedProtocol:
    """HTTP client for the rfork scheduler ("planner").

    Acquires (``get_seed``), releases (``release_seed``) and heart-beats
    (``report_seed``) seed instances for the key derived from
    (model_url, deployment strategy, disaggregation mode, ranks).
    """

    def __init__(
        self,
        *,
        disaggregation_mode: str,
        node_rank: int,
        tp_rank: int,
        scheduler_url: str,
        model_url: str,
        model_deploy_strategy_name: str,
        seed_key_separator: str = "$",
        is_draft_worker: bool = False,
    ):
        self.disaggregation_mode = disaggregation_mode
        self.node_rank = node_rank
        self.tp_rank = tp_rank
        self.scheduler_url = scheduler_url
        self.model_url = model_url
        self.model_deploy_strategy_name = model_deploy_strategy_name
        self.seed_key_separator = seed_key_separator
        self.is_draft_worker = is_draft_worker

        # Computed once at construction; raises RuntimeError when
        # model_url / model_deploy_strategy_name are missing.
        self._local_seed_key = get_local_seed_key(
            disaggregation_mode=self.disaggregation_mode,
            node_rank=self.node_rank,
            tp_rank=self.tp_rank,
            model_url=self.model_url,
            model_deploy_strategy_name=self.model_deploy_strategy_name,
            seed_key_separator=self.seed_key_separator,
            is_draft_worker=self.is_draft_worker,
        )

    def get_local_seed_key(self) -> str:
        """Return this rank's precomputed seed key."""
        return self._local_seed_key

    @staticmethod
    def _request_timeout_sec() -> float:
        # Single knob shared by all scheduler HTTP calls.
        return REQUEST_TIMEOUT_SEC

    def _ensure_scheduler_url_set(self) -> None:
        # Raise (rather than return a flag) so every caller fails uniformly
        # through its RuntimeError handler.
        if not self.scheduler_url:
            raise RuntimeError("rfork_scheduler_url is not set. Cannot interact with the scheduler.")

    def get_seed(self):
        """Request a seed instance from the scheduler.

        Returns:
            A dict with seed_ip / seed_port / user_id / seed_rank taken from
            the response headers (strings, or None for absent headers), or
            None on any failure.
        """
        try:
            self._ensure_scheduler_url_set()
            response = requests.get(
                f"{self.scheduler_url}/get_seed",
                headers={
                    "SEED_KEY": self.get_local_seed_key(),
                },
                timeout=self._request_timeout_sec(),
            )
            if response.status_code != 200:
                raise RuntimeError(f"Failed to get seed from the planner, {response.status_code}")

            seed_ip = response.headers.get("SEED_IP")
            seed_port = response.headers.get("SEED_PORT")
            user_id = response.headers.get("USER_ID")
            seed_rank = response.headers.get("SEED_RANK")
            logger.debug(
                "seed_ip: %s, seed_port: %s, user_id: %s, seed_rank: %s",
                seed_ip,
                seed_port,
                user_id,
                seed_rank,
            )
            return {
                "seed_ip": seed_ip,
                "seed_port": seed_port,
                "user_id": user_id,
                "seed_rank": seed_rank,
            }

        except RuntimeError as e:
            # Expected failure (no seed yet / bad status); warn, don't spam.
            logger.warning("get_seed from planner RuntimeError: %s", e)
            return None
        except HTTPError as e:
            # NOTE(review): HTTPError here is urllib.error.HTTPError, but
            # `requests` raises its own exception types — this branch looks
            # unreachable; verify intent.
            logger.exception("get_seed from planner HTTPError: %s", e)
            return None
        except Exception as e:
            logger.exception("get_seed from planner Exception: %s", e)
            return None

    def release_seed(self, seed) -> bool:
        """Return a previously acquired seed to the scheduler.

        Args:
            seed: The dict previously returned by ``get_seed()``.

        Returns:
            True on HTTP 200, False on any failure (already logged).
        """
        try:
            self._ensure_scheduler_url_set()
            user_id = seed["user_id"]
            seed_ip = seed["seed_ip"]
            seed_port = str(seed["seed_port"])
            seed_rank = str(seed["seed_rank"])

            response = requests.post(
                f"{self.scheduler_url}/put_seed",
                headers={
                    "SEED_IP": seed_ip,
                    "SEED_PORT": seed_port,
                    "USER_ID": user_id,
                    "SEED_RANK": seed_rank,
                },
                timeout=self._request_timeout_sec(),
            )

            if response.status_code != 200:
                raise RuntimeError(f"Failed to release seed to the planner, {response.status_code}")
            return True
        except RuntimeError as e:
            logger.exception("release_seed to planner RuntimeError: %s", e)
            return False
        except HTTPError as e:
            # NOTE(review): likely-dead urllib HTTPError branch — see get_seed.
            logger.exception("release_seed to planner HTTPError: %s", e)
            return False
        except Exception as e:
            logger.exception("release_seed to planner Exception: %s", e)
            return False

    def report_seed(self, port: int, sleep_interval: int = 30):
        """Heartbeat loop advertising this rank's seed service to the planner.

        Runs forever; intended to be the target of a daemon thread.

        Args:
            port: Local seed-service port to advertise.
            sleep_interval: Seconds to sleep between heartbeats.
        """
        heartbeat_idx = 0
        log_every_n = HEARTBEAT_LOG_EVERY_N
        try:
            self._ensure_scheduler_url_set()
            seed_ip = get_ip()
            seed_key = self.get_local_seed_key()
        except Exception as e:
            # Setup failure is fatal for the heartbeat thread: exit quietly.
            logger.exception("report_seed setup Exception: %s", e)
            return

        while True:
            heartbeat_idx += 1
            result = False
            try:
                response = requests.post(
                    f"{self.scheduler_url}/add_seed",
                    headers={
                        "SEED_KEY": seed_key,
                        "SEED_IP": seed_ip,
                        "SEED_PORT": str(port),
                        "SEED_RANK": str(self.tp_rank),
                        "SEED_REFCNT": str(0),
                    },
                    timeout=self._request_timeout_sec(),
                )
                if response.status_code == 200:
                    result = True
            except HTTPError as e:
                logger.exception("report_seed to planner HTTPError: %s", e)
            except Exception as e:
                logger.exception("report_seed to planner Exception: %s", e)

            # Keep heartbeat frequency unchanged, but reduce log noise.
            # Always print failures immediately; print success once every N times.
            if (not result) or (heartbeat_idx % log_every_n == 0):
                logger.info(
                    "[rfork_heartbeat] report seed to planner result: %s (%d/%d)",
                    result,
                    heartbeat_idx % log_every_n if heartbeat_idx % log_every_n != 0 else log_every_n,
                    log_every_n,
                )
            time.sleep(sleep_interval)
|
||||
126
vllm_ascend/model_loader/rfork/seed_server.py
Normal file
126
vllm_ascend/model_loader/rfork/seed_server.py
Normal file
@@ -0,0 +1,126 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import queue
|
||||
import socket
|
||||
import threading
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import requests
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.responses import Response
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
def start_fastapi_server(
    port_queue: queue.Queue[int],
    local_seed_key,
    info,
):
    """Run the seed-side FastAPI service on a dynamically assigned port.

    Intended to run in a daemon thread; blocks in ``server.run`` for the
    lifetime of the process.

    Args:
        port_queue: Receives the bound port so the caller can discover it.
        local_seed_key: Key a client must present to read the transfer-engine
            info or pass the keyed health check.
        info: Opaque transfer-engine payload served to matching clients.
    """
    logger.info("[RFork Seed] Preparing socket with dynamic port...")

    # Bind port 0 so the OS picks a free port; the pre-bound socket is handed
    # to uvicorn below, avoiding a find-port/bind race.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("0.0.0.0", 0))
    _, port = sock.getsockname()
    logger.info("[RFork Seed] Assigned dynamic port: %s", port)

    app = FastAPI()

    @app.get("/get_rfork_transfer_engine_info")
    def get_rfork_transfer_engine_info(seed_key: str):
        # Only reveal transfer-engine details to a client with a matching key.
        if seed_key == local_seed_key:
            return {"rfork_transfer_engine_info": info}
        return {"rfork_transfer_engine_info": None}

    @app.get("/rfork_fetch_seed")
    def rfork_fetch_seed():
        # Simple liveness probe.
        return {"status": "ok"}

    @app.get("/health_check_with_key")
    def health_check_with_key(seed_key: str):
        # Keyed health check: 200 only when the caller knows this seed's key.
        if seed_key == local_seed_key:
            return Response(status_code=HTTPStatus.OK)
        return Response(status_code=HTTPStatus.BAD_REQUEST)

    # host/port are None because the already-bound socket is passed to run().
    config = uvicorn.Config(app, host=None, port=None, log_level="warning")
    server = uvicorn.Server(config)

    try:
        # Hand the chosen port back to start_rfork_server before blocking.
        port_queue.put(port)
    except Exception as e:
        logger.error("[RFork Seed] Failed to send port via queue: %s", e)
        sock.close()
        return

    logger.info("[RFork Seed] FastAPI server starting on port %s...", port)
    server.run(sockets=[sock])
    sock.close()
|
||||
|
||||
|
||||
def start_rfork_server(local_seed_key, rfork_transfer_engine_info, health_timeout_sec: float = 30.0) -> int:
    """Start the seed HTTP service in a background thread and wait until it
    answers health checks.

    Args:
        local_seed_key: Key used by the keyed endpoints.
        rfork_transfer_engine_info: (session_id, weights_info_dict) to serve.
        health_timeout_sec: How long to poll the health endpoint before
            giving up.

    Returns:
        The bound port on success, -1 on any failure.
    """
    port_queue: queue.Queue[int] = queue.Queue()
    # NOTE(review): despite the name, this is a daemon *thread*, not a process.
    process = threading.Thread(
        target=start_fastapi_server,
        args=(port_queue, local_seed_key, rfork_transfer_engine_info),
        daemon=True,
    )
    process.start()

    try:
        # Wait for the server thread to report which port it bound.
        port = port_queue.get(timeout=15)
        if port == -1:
            raise RuntimeError("Child process failed to start server")
    except Exception as e:
        logger.error("[RFork Seed] start server error: %s", e)
        return -1

    # The socket is bound before uvicorn's event loop is serving, so poll the
    # health endpoint until it actually responds (or the deadline passes).
    deadline = time.time() + health_timeout_sec
    healthy = False
    retry_count = 0
    last_error = None
    while time.time() < deadline:
        time.sleep(0.01)
        url = f"http://127.0.0.1:{port}/health_check_with_key"
        try:
            response = requests.get(
                url,
                params={"seed_key": local_seed_key},
                timeout=10,
            )
            if response.status_code == 200:
                healthy = True
                break
            last_error = f"unexpected status code {response.status_code} from health check"
        except Exception as e:
            # Connection refused etc. while uvicorn is still starting up.
            last_error = str(e)
        retry_count += 1
    if healthy:
        if retry_count > 1:
            logger.info(
                "[RFork Seed] health check passed after %d retries for port %s",
                retry_count - 1,
                port,
            )
        return port
    logger.error(
        "[RFork Seed] health check timed out after %.1fs for port %s, last error: %s",
        health_timeout_sec,
        port,
        last_error,
    )
    return -1
|
||||
212
vllm_ascend/model_loader/rfork/transfer_backend.py
Normal file
212
vllm_ascend/model_loader/rfork/transfer_backend.py
Normal file
@@ -0,0 +1,212 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import torch
|
||||
from vllm.logger import logger
|
||||
from vllm.utils.network_utils import get_ip, get_open_port, join_host_port
|
||||
|
||||
|
||||
class RForkTransferBackend:
    """Thin wrapper over the yuanrong-datasystem TransferEngine.

    Registers the model's weight memory as RDMA-capable regions and performs
    the bulk read of weights from a remote seed instance.
    """

    def __init__(self):
        # Opaque TransferEngine handle (project/third-party type).
        self.rfork_transfer_engine: Any | None = None
        # "host:port" string used as the engine session id for this rank.
        self.rfork_transfer_engine_session_id = None
        # name -> (data_ptr, numel, element_size) for every parameter.
        self.rfork_transfer_engine_weights_info_dict = None
        # (address, size) regions registered with the engine.
        self.registered_weight_blocks = []
        self._is_initialized = False
        # Eager init: raises ImportError/RuntimeError on failure, so a
        # constructed backend is always usable.
        self.init_transfer_engine()

    def init_transfer_engine(self):
        """Create and initialize the TransferEngine for the current NPU.

        Raises:
            ImportError: if yuanrong-datasystem is not installed.
            RuntimeError: if engine initialization fails.
        """
        try:
            from yr.datasystem import TransferEngine  # type: ignore[import-not-found]
        except ImportError as e:
            raise ImportError("Please install @yuanrong-datasystem/transfer_engine first.") from e

        transfer_engine = TransferEngine()
        local_hostname = join_host_port(get_ip(), get_open_port())
        ret = transfer_engine.initialize(local_hostname, "ascend", f"npu:{torch.npu.current_device()}")
        if ret.is_error():
            # NOTE(review): the message below is missing a separator between
            # "ascend" and "npu:..." — cosmetic only.
            raise RuntimeError(
                "TransferEngine initialization failed: "
                f"initialize({local_hostname}, ascend"
                f"npu:{int(torch.npu.current_device())}) -> {ret.to_string()}"
            )

        self.rfork_transfer_engine = transfer_engine
        self.rfork_transfer_engine_session_id = local_hostname
        self._is_initialized = True

    def is_initialized(self) -> bool:
        """Whether init_transfer_engine() completed successfully."""
        return self._is_initialized

    def _get_transfer_engine(self) -> Any:
        # Narrow the Optional once so every caller gets a non-None engine.
        if self.rfork_transfer_engine is None:
            raise RuntimeError("TransferEngine is not initialized.")
        return self.rfork_transfer_engine

    def register_memory_region(self, model):
        """Register the model's weight memory with the transfer engine.

        Walks the NPU allocator snapshot, keeps the active blocks whose start
        address matches some parameter's data_ptr, coalesces address-adjacent
        blocks, and batch-registers the result.

        Returns:
            True on success, False if batch registration failed.
        """
        transfer_engine = self._get_transfer_engine()
        start_reg_mr_tic = time.time()

        # Per-parameter layout published to clients via the seed service.
        weight_mr_dict = {}
        weight_addr_set = set()
        for name, weight in model.named_parameters():
            weight_mr_dict[name] = (
                weight.data_ptr(),
                weight.numel(),
                weight.element_size(),
            )
            weight_addr_set.add(weight.data_ptr())

        memory_snapshot = torch.npu.memory.memory_snapshot()
        weight_blocks_for_reg_mr = []
        for segment in memory_snapshot:
            current_weight_block = None
            for block in segment.get("blocks", []):
                address = block.get("address", -1)
                size = block.get("size", -1)
                state = block.get("state", "")
                if address < 0 or size < 0 or state == "":
                    continue
                # Only blocks whose start address is exactly a parameter's
                # data_ptr are taken; assumes each weight starts its own
                # allocator block — TODO confirm for all allocator configs.
                if state == "active_allocated" and address in weight_addr_set:
                    if current_weight_block is None:
                        current_weight_block = (address, size)
                    elif current_weight_block[0] + current_weight_block[1] == address:
                        # Contiguous with the running block: merge to reduce
                        # the number of registered regions.
                        current_weight_block = (
                            current_weight_block[0],
                            current_weight_block[1] + size,
                        )
                    else:
                        weight_blocks_for_reg_mr.append(current_weight_block)
                        current_weight_block = (address, size)
            # Flush the trailing block of this segment, if any.
            if current_weight_block is not None:
                weight_blocks_for_reg_mr.append(current_weight_block)

        addresses, sizes = zip(*weight_blocks_for_reg_mr) if weight_blocks_for_reg_mr else ((), ())
        ret = transfer_engine.batch_register_memory(addresses, sizes)
        if ret.is_error():
            logger.error(
                "batch_register_memory failed for %d blocks, ret: %s",
                len(weight_blocks_for_reg_mr),
                ret.to_string(),
            )
            return False

        self.rfork_transfer_engine_weights_info_dict = weight_mr_dict
        self.registered_weight_blocks = weight_blocks_for_reg_mr

        logger.info(
            "register_memory_region time: %.4fs",
            time.time() - start_reg_mr_tic,
        )
        return True

    def unregister_memory_region(self) -> bool:
        """Undo register_memory_region(); returns True on success."""
        transfer_engine = self._get_transfer_engine()
        start_unreg_mr_tic = time.time()
        ret = transfer_engine.batch_unregister_memory([address for address, _ in self.registered_weight_blocks])
        if ret.is_error():
            logger.error(
                "batch_unregister_memory failed for %d blocks, ret: %s",
                len(self.registered_weight_blocks),
                ret.to_string(),
            )
            return False
        self.rfork_transfer_engine_weights_info_dict = None
        self.registered_weight_blocks = []
        logger.info(
            "unregister_memory_region time: %.4fs",
            time.time() - start_unreg_mr_tic,
        )
        return True

    def recv_from_source(
        self,
        model,
        seed_instance_ip,
        seed_instance_service_port,
        local_seed_key,
    ):
        """Bulk-read all parameter tensors from the seed instance into `model`.

        Fetches the seed's per-parameter layout over HTTP, validates that each
        local parameter matches the remote (numel and element size), then
        issues one batched synchronous RDMA read.

        Returns:
            True on success, False on any mismatch or transfer error.
        """
        transfer_engine = self._get_transfer_engine()
        seed_url = f"http://{seed_instance_ip}:{seed_instance_service_port}"
        seed_session_id, seed_weight_info = get_remote_instance_transfer_engine_info(seed_url, local_seed_key)
        if seed_session_id is None or seed_weight_info is None:
            logger.error("Cannot get transfer engine session or weight info.")
            return False

        seed_ptr_list = []
        client_ptr_list = []
        client_len_list = []
        for name, tensor in model.named_parameters():
            weight_info = seed_weight_info.get(name, None)
            if weight_info is None:
                logger.error("Cannot find weight info for %s.", name)
                return False

            seed_ptr, seed_len, seed_size = weight_info
            # Reject any shape/dtype divergence between seed and client.
            if seed_len != tensor.numel() or seed_size != tensor.element_size():
                logger.error(
                    "Weight info mismatch for %s, expected (%s, %s), got (%s, %s)",
                    name,
                    seed_len,
                    seed_size,
                    tensor.numel(),
                    tensor.element_size(),
                )
                return False

            seed_ptr_list.append(seed_ptr)
            client_ptr_list.append(tensor.data_ptr())
            client_len_list.append(tensor.numel() * tensor.element_size())

        start_transfer_tic = time.time()
        ret = transfer_engine.batch_transfer_sync_read(
            seed_session_id,
            client_ptr_list,
            seed_ptr_list,
            client_len_list,
        )
        if ret.is_error():
            logger.error("Failed to transfer weights from remote instance, ret=%s", ret.to_string())
            return False

        logger.info("transfer weights time: %.4fs", time.time() - start_transfer_tic)
        return True
|
||||
|
||||
|
||||
def get_remote_instance_transfer_engine_info(seed_url: str, local_seed_key: str):
    """Fetch the seed instance's transfer-engine (session_id, weight_info).

    Args:
        seed_url: Base URL of the seed instance's rfork HTTP service.
        local_seed_key: This rank's seed key; the seed only answers a
            matching key.

    Returns:
        (session_id, weight_info) on success, (None, None) on any failure
        (non-200 status, malformed payload, network error, timeout).
    """
    try:
        response = requests.get(
            f"{seed_url}/get_rfork_transfer_engine_info",
            params={"seed_key": local_seed_key},
            # Without a timeout this call could hang forever on an
            # unreachable seed; with one, a stalled seed becomes the normal
            # (None, None) failure path handled by the caller.
            timeout=10.0,
        )
        if response.status_code != 200:
            logger.error("request.get failed: %s", response.status_code)
            return None, None

        data = response.json()
        info = data.get("rfork_transfer_engine_info", None)
        # The seed serves the payload as a [session_id, weight_info] pair.
        if info is not None and isinstance(info, list) and len(info) == 2:
            return info[0], info[1]

        logger.error("Failed to get `rfork_transfer_engine_info` in response.")
        return None, None
    except Exception as e:
        logger.error("Exception: %s", e)
        return None, None
|
||||
Reference in New Issue
Block a user