[ModelLoader][Feature] Add rfork support for fast model loading (#7392)

### What this PR does / why we need it?
Support a new load format: RFORK

For implementation details of this feature, please refer to #7441


### Does this PR introduce _any_ user-facing change?

Adds a new option for load-format: rfork

e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```

### How was this patch tested?

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e

Signed-off-by: Marck <1412354149@qq.com>
This commit is contained in:
Marck
2026-03-25 16:40:30 +08:00
committed by GitHub
parent 6ddfc41312
commit 17da96658f
11 changed files with 1510 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def register_rforkloader() -> None:
    """Register the RFork model loader plugin with vLLM.

    The import below is performed purely for its side effect: importing
    ``rfork_loader`` runs its ``@register_model_loader("rfork")`` class
    decorator, which registers the loader under the ``rfork`` format name.
    The imported class itself is unused here (hence the ``noqa``).
    """
    from .rfork_loader import RForkModelLoader  # noqa: F401

View File

@@ -0,0 +1,188 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import time
import torch
import torch.nn as nn
from torch.nn import Module
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import logger
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
from vllm_ascend.model_loader.rfork.rfork_worker import RForkWorker
@register_model_loader("rfork")
class RForkModelLoader(BaseModelLoader):
    """vLLM model loader that pulls weights over the RFork transfer backend.

    Rather than reading weights from disk, the loader asks an external
    scheduler for a "seed" instance that already holds the weights and
    streams them directly into the freshly initialized model.  If the seed
    is unavailable or any transfer step fails, it cleans up and falls back
    to vLLM's default (``auto``) loader.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        config = load_config.model_loader_extra_config
        if not isinstance(config, dict):
            raise RuntimeError("RFork requires --model-loader-extra-config to be a JSON object.")

        def _get_extra_config(key: str, default: str = "") -> str:
            # Accept only non-empty string values; anything else -> default.
            value = config.get(key)
            return value if isinstance(value, str) and value else default

        def _get_extra_config_float(key: str, default: float) -> float:
            # Accept int/float directly or a numeric string; unparsable or
            # non-positive values fall back to the default.
            value = config.get(key)
            parsed_value = default
            if isinstance(value, (int, float)):
                parsed_value = float(value)
            elif isinstance(value, str) and value:
                try:
                    parsed_value = float(value)
                except ValueError:
                    return default
            if parsed_value <= 0:
                return default
            return parsed_value

        # Scheduler/seed identification settings pulled from extra config.
        self.model_url = _get_extra_config("model_url", "")
        self.model_deploy_strategy_name = _get_extra_config("model_deploy_strategy_name", "")
        self.scheduler_url = _get_extra_config("rfork_scheduler_url", "")
        self.seed_timeout_sec = _get_extra_config_float("rfork_seed_timeout_sec", 5.0)
        self.seed_key_separator = _get_extra_config("rfork_seed_key_separator", "$")
        logger.info(
            "Initializing rfork with config: "
            "MODEL_URL=%s, MODEL_DEPLOY_STRATEGY_NAME=%s, "
            "SCHEDULER_URL=%s, SEED_TIMEOUT_SEC=%s, "
            "SEED_KEY_SEPARATOR=%s",
            self.model_url,
            self.model_deploy_strategy_name,
            self.scheduler_url,
            self.seed_timeout_sec,
            self.seed_key_separator,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        # RFork streams weights from a peer; downloading is not supported.
        raise NotImplementedError

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        # Weights are transferred inside load_model(); standalone weight
        # loading is not supported for this format.
        raise NotImplementedError

    def _ensure_rfork_worker(self, vllm_config: VllmConfig) -> RForkWorker:
        """Return the shared RForkWorker, creating it on first use.

        The worker is cached on ``self.load_config`` so repeated
        load_model() calls (e.g. target + draft models) reuse one worker.
        """
        rfork_worker = getattr(self.load_config, "rfork_worker", None)
        if rfork_worker is None:
            kv_transfer_config = vllm_config.kv_transfer_config
            # Without a KV-transfer config this instance acts as both
            # prefill and decode ("kv_both").
            disaggregation_mode = "kv_both" if kv_transfer_config is None else str(kv_transfer_config.kv_role)
            is_draft_model = (
                getattr(vllm_config.model_config, "runner_type", None) == "draft"
                or getattr(vllm_config.scheduler_config, "runner_type", None) == "draft"
            )
            device_id = torch.distributed.get_rank()
            self.load_config.rfork_worker = RForkWorker(
                disaggregation_mode=disaggregation_mode,
                node_rank=vllm_config.parallel_config.node_rank,
                tp_rank=get_tensor_model_parallel_rank(),
                device_id=device_id,
                scheduler_url=self.scheduler_url,
                model_url=self.model_url,
                model_deploy_strategy_name=self.model_deploy_strategy_name,
                seed_timeout_sec=self.seed_timeout_sec,
                seed_key_separator=self.seed_key_separator,
                is_draft_model=is_draft_model,
            )
            logger.info("RFork worker initialized, load_format=rfork")
        rfork_worker = self.load_config.rfork_worker
        return rfork_worker

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        prefix: str = "",
    ) -> Module | None:
        """Initialize the model and fill its weights via RFork.

        On any failure (no seed available, registration or transfer error)
        the partially initialized model is freed and the default loader is
        used instead, so model loading never hard-fails because of RFork.
        """
        device_config = vllm_config.device_config
        load_config = self.load_config
        load_device = device_config.device if load_config.device is None else load_config.device
        target_device = torch.device(load_device)
        with set_default_torch_dtype(model_config.dtype):
            need_del = False  # True once `model` exists and may need freeing.
            rfork_worker = self._ensure_rfork_worker(vllm_config)
            try:
                if not rfork_worker.is_seed_available():
                    raise RuntimeError("seed is not available.")
                with target_device:
                    model = initialize_model(
                        vllm_config=vllm_config,
                        model_config=model_config,
                        prefix=prefix,
                    )
                need_del = True
                weight_load_start_time = time.time()
                # Three-phase transfer: register memory regions, pull the
                # weights, then release the seed reservation.
                if not rfork_worker.pre_transfer(model):
                    raise RuntimeError("pre_transfer failed.")
                if not rfork_worker.transfer(model):
                    raise RuntimeError("transfer failed.")
                if not rfork_worker.post_transfer():
                    raise RuntimeError("post_transfer failed.")
                logger.info(
                    "Loading model weights took %.2f seconds",
                    time.time() - weight_load_start_time,
                )
                # This instance can now serve as a seed for other workers.
                rfork_worker.start_seed_service(model)
                process_weights_after_loading(model, model_config, target_device)
                return model.eval()
            except Exception as e:
                logger.warning(f"RFork transfer failed: {e}, clean up and fall back to default loader")
                # Release the seed reservation (no-op if none was acquired).
                rfork_worker.post_transfer()
                if need_del:
                    # Free the partially initialized model before retrying.
                    del model
                    gc.collect()
                    torch.npu.empty_cache()
                for _ in range(3):
                    gc.collect()
                    torch.npu.empty_cache()
                # Fall back to the default load path with a clean config.
                self.load_config.load_format = "auto"
                self.load_config.model_loader_extra_config = {}
                from vllm.model_executor.model_loader import get_model
                model = get_model(
                    vllm_config=vllm_config,
                    model_config=model_config,
                    prefix=prefix,
                )
                try:
                    # Best effort: still offer this instance as a seed.
                    rfork_worker.start_seed_service(model)
                except Exception as e:
                    logger.warning(
                        "Fallback model loaded, but start_seed_service failed: %s",
                        e,
                    )
                return model

View File

@@ -0,0 +1,119 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import threading
from vllm.logger import logger
from vllm_ascend.model_loader.rfork.seed_protocol import RForkSeedProtocol
from vllm_ascend.model_loader.rfork.seed_server import start_rfork_server
from vllm_ascend.model_loader.rfork.transfer_backend import (
RForkTransferBackend,
)
class RForkWorker:
    """Per-device RFork worker tying the seed protocol to the transfer backend.

    Responsibilities:
      * acquire a seed instance from the scheduler and pull weights from it
        (``is_seed_available`` / ``pre_transfer`` / ``transfer`` /
        ``post_transfer``), and
      * afterwards expose this process as a seed for other workers
        (``start_seed_service``: HTTP server plus heartbeat thread).
    """

    def __init__(
        self,
        disaggregation_mode: str,
        node_rank: int,
        tp_rank: int,
        device_id: int,
        scheduler_url: str,
        model_url: str,
        model_deploy_strategy_name: str,
        seed_timeout_sec: float = 30.0,
        seed_key_separator: str = "$",
        is_draft_model: bool = False,
    ):
        self.device_id = device_id
        self.rfork_seed = None  # seed dict from the scheduler, or None
        self.transfer_backend = RForkTransferBackend()
        self.ready_to_start_seed_service = False
        self.seed_service_started = False
        self.seed_timeout_sec = seed_timeout_sec
        self.seed_protocol = RForkSeedProtocol(
            disaggregation_mode=disaggregation_mode,
            node_rank=node_rank,
            tp_rank=tp_rank,
            scheduler_url=scheduler_url,
            model_url=model_url,
            model_deploy_strategy_name=model_deploy_strategy_name,
            seed_key_separator=seed_key_separator,
            is_draft_worker=is_draft_model,
        )

    def is_seed_available(self) -> bool:
        """Try to acquire a seed from the scheduler; cache it on success."""
        self.rfork_seed = self.seed_protocol.get_seed()
        return self.rfork_seed is not None

    def pre_transfer(self, model) -> bool:
        """Register *model*'s weight memory with the transfer backend.

        Returns True on success and records the result so the seed service
        can skip re-registration later.
        """
        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this validation.
        if not self.transfer_backend.is_initialized():
            logger.error("Pre-transfer failed: transfer_backend is not initialized, cannot pre_transfer.")
            return False
        result = self.transfer_backend.register_memory_region(model)
        self.ready_to_start_seed_service = result
        return result

    def transfer(self, model) -> bool:
        """Pull weights for *model* from the previously acquired seed."""
        # Explicit checks instead of asserts (see pre_transfer).
        if not self.transfer_backend.is_initialized():
            logger.error("Transfer failed: transfer_backend is not initialized, cannot transfer.")
            return False
        if self.rfork_seed is None:
            logger.error("Transfer failed: rfork seed is None, cannot transfer.")
            return False
        return self.transfer_backend.recv_from_source(
            model=model,
            seed_instance_ip=self.rfork_seed["seed_ip"],
            seed_instance_service_port=self.rfork_seed["seed_port"],
            local_seed_key=self.seed_protocol.get_local_seed_key(),
        )

    def post_transfer(self):
        """Release the seed reservation back to the scheduler (idempotent)."""
        if self.rfork_seed is None:
            logger.info("rfork seed is None, no need to release.")
            return True
        self.seed_protocol.release_seed(self.rfork_seed)
        return True

    def start_seed_service(self, model):
        """Start the local seed HTTP server and heartbeat thread (at most once)."""
        if self.seed_service_started:
            logger.info("Seed service already started, skipping.")
            return
        if not self.ready_to_start_seed_service:
            # Memory was never registered (e.g. fallback path); do it now.
            if not self.pre_transfer(model):
                return
        port = start_rfork_server(
            self.seed_protocol.get_local_seed_key(),
            (
                self.transfer_backend.rfork_transfer_engine_session_id,
                self.transfer_backend.rfork_transfer_engine_weights_info_dict,
            ),
            health_timeout_sec=self.seed_timeout_sec,
        )
        if port > 0:
            # Heartbeat keeps advertising this worker as a seed; daemon so
            # it never blocks interpreter shutdown.
            self.rfork_heartbeat_thread = threading.Thread(
                target=self.seed_protocol.report_seed,
                args=(port,),
                daemon=True,
                name="RForkHeartbeat",
            )
            self.rfork_heartbeat_thread.start()
            self.seed_service_started = True

View File

@@ -0,0 +1,208 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from urllib.error import HTTPError
import requests
from vllm.logger import logger
from vllm.utils.network_utils import get_ip
REQUEST_TIMEOUT_SEC = 10.0
HEARTBEAT_LOG_EVERY_N = 4
def get_local_seed_key(
    disaggregation_mode: str,
    node_rank: int,
    tp_rank: int,
    model_url: str,
    model_deploy_strategy_name: str,
    seed_key_separator: str = "$",
    is_draft_worker: bool = False,
) -> str:
    """Build the key identifying this worker's seed slot at the scheduler.

    The key is ``model_url / strategy / disaggregation_mode / node_rank /
    tp_rank`` joined by *seed_key_separator*, with a trailing ``draft``
    component for draft-model workers.

    Raises:
        RuntimeError: if *model_url* or *model_deploy_strategy_name* is empty.
    """
    if not model_url or not model_deploy_strategy_name:
        raise RuntimeError(
            "RFork seed key is not set. Ensure model_loader_extra_config contains "
            "`model_url` and `model_deploy_strategy_name`."
        )
    components = [
        model_url,
        model_deploy_strategy_name,
        disaggregation_mode,
        str(node_rank),
        str(tp_rank),
    ]
    if is_draft_worker:
        components.append("draft")
    return seed_key_separator.join(components)
class RForkSeedProtocol:
    """HTTP client for the RFork scheduler's seed-management endpoints.

    Encapsulates the three scheduler interactions: acquiring a seed
    (``/get_seed``), releasing it (``/put_seed``), and periodically
    advertising this worker as a seed (``/add_seed`` heartbeat).
    """

    def __init__(
        self,
        *,
        disaggregation_mode: str,
        node_rank: int,
        tp_rank: int,
        scheduler_url: str,
        model_url: str,
        model_deploy_strategy_name: str,
        seed_key_separator: str = "$",
        is_draft_worker: bool = False,
    ):
        self.disaggregation_mode = disaggregation_mode
        self.node_rank = node_rank
        self.tp_rank = tp_rank
        self.scheduler_url = scheduler_url
        self.model_url = model_url
        self.model_deploy_strategy_name = model_deploy_strategy_name
        self.seed_key_separator = seed_key_separator
        self.is_draft_worker = is_draft_worker
        # Precompute the seed key once; it is immutable for this worker.
        self._local_seed_key = get_local_seed_key(
            disaggregation_mode=self.disaggregation_mode,
            node_rank=self.node_rank,
            tp_rank=self.tp_rank,
            model_url=self.model_url,
            model_deploy_strategy_name=self.model_deploy_strategy_name,
            seed_key_separator=self.seed_key_separator,
            is_draft_worker=self.is_draft_worker,
        )

    def get_local_seed_key(self) -> str:
        """Return this worker's precomputed seed key."""
        return self._local_seed_key

    @staticmethod
    def _request_timeout_sec() -> float:
        # Single knob for all scheduler HTTP timeouts.
        return REQUEST_TIMEOUT_SEC

    def _ensure_scheduler_url_set(self) -> None:
        # Guard: every scheduler call requires a configured URL.
        if not self.scheduler_url:
            raise RuntimeError("rfork_scheduler_url is not set. Cannot interact with the scheduler.")

    def get_seed(self):
        """Ask the scheduler for a seed instance to pull weights from.

        Returns a dict with ``seed_ip``/``seed_port``/``user_id``/``seed_rank``
        taken from the response headers (individual values may be None if a
        header is missing), or None on any failure.
        """
        try:
            self._ensure_scheduler_url_set()
            response = requests.get(
                f"{self.scheduler_url}/get_seed",
                headers={
                    "SEED_KEY": self.get_local_seed_key(),
                },
                timeout=self._request_timeout_sec(),
            )
            if response.status_code != 200:
                raise RuntimeError(f"Failed to get seed from the planner, {response.status_code}")
            seed_ip = response.headers.get("SEED_IP")
            seed_port = response.headers.get("SEED_PORT")
            user_id = response.headers.get("USER_ID")
            seed_rank = response.headers.get("SEED_RANK")
            logger.debug(
                "seed_ip: %s, seed_port: %s, user_id: %s, seed_rank: %s",
                seed_ip,
                seed_port,
                user_id,
                seed_rank,
            )
            return {
                "seed_ip": seed_ip,
                "seed_port": seed_port,
                "user_id": user_id,
                "seed_rank": seed_rank,
            }
        except RuntimeError as e:
            logger.warning("get_seed from planner RuntimeError: %s", e)
            return None
        except HTTPError as e:
            # NOTE(review): this catches urllib's HTTPError, but `requests`
            # raises its own HTTPError and only via raise_for_status(), which
            # is never called here — this handler looks like dead code; the
            # generic Exception handler below covers failures anyway. Confirm.
            logger.exception("get_seed from planner HTTPError: %s", e)
            return None
        except Exception as e:
            logger.exception("get_seed from planner Exception: %s", e)
            return None

    def release_seed(self, seed) -> bool:
        """Return the seed reservation to the scheduler via ``/put_seed``.

        *seed* is the dict previously returned by get_seed(). Returns True
        on HTTP 200, False on any failure (logged, never raised).
        """
        try:
            self._ensure_scheduler_url_set()
            user_id = seed["user_id"]
            seed_ip = seed["seed_ip"]
            seed_port = str(seed["seed_port"])
            seed_rank = str(seed["seed_rank"])
            response = requests.post(
                f"{self.scheduler_url}/put_seed",
                headers={
                    "SEED_IP": seed_ip,
                    "SEED_PORT": seed_port,
                    "USER_ID": user_id,
                    "SEED_RANK": seed_rank,
                },
                timeout=self._request_timeout_sec(),
            )
            if response.status_code != 200:
                raise RuntimeError(f"Failed to release seed to the planner, {response.status_code}")
            return True
        except RuntimeError as e:
            logger.exception("release_seed to planner RuntimeError: %s", e)
            return False
        except HTTPError as e:
            # NOTE(review): likely dead code — see get_seed().
            logger.exception("release_seed to planner HTTPError: %s", e)
            return False
        except Exception as e:
            logger.exception("release_seed to planner Exception: %s", e)
            return False

    def report_seed(self, port: int, sleep_interval: int = 30):
        """Heartbeat loop advertising this worker as a seed.

        POSTs to ``/add_seed`` every *sleep_interval* seconds, forever.
        Intended to run on a daemon thread (see RForkWorker).
        """
        heartbeat_idx = 0
        log_every_n = HEARTBEAT_LOG_EVERY_N
        try:
            self._ensure_scheduler_url_set()
            seed_ip = get_ip()
            seed_key = self.get_local_seed_key()
        except Exception as e:
            # Setup failure is fatal for the heartbeat; exit the thread.
            logger.exception("report_seed setup Exception: %s", e)
            return
        while True:
            heartbeat_idx += 1
            result = False
            try:
                response = requests.post(
                    f"{self.scheduler_url}/add_seed",
                    headers={
                        "SEED_KEY": seed_key,
                        "SEED_IP": seed_ip,
                        "SEED_PORT": str(port),
                        "SEED_RANK": str(self.tp_rank),
                        "SEED_REFCNT": str(0),
                    },
                    timeout=self._request_timeout_sec(),
                )
                if response.status_code == 200:
                    result = True
            except HTTPError as e:
                # NOTE(review): likely dead code — see get_seed().
                logger.exception("report_seed to planner HTTPError: %s", e)
            except Exception as e:
                logger.exception("report_seed to planner Exception: %s", e)
            # Keep heartbeat frequency unchanged, but reduce log noise:
            # failures are logged immediately, successes once every N beats.
            if (not result) or (heartbeat_idx % log_every_n == 0):
                logger.info(
                    "[rfork_heartbeat] report seed to planner result: %s (%d/%d)",
                    result,
                    heartbeat_idx % log_every_n if heartbeat_idx % log_every_n != 0 else log_every_n,
                    log_every_n,
                )
            time.sleep(sleep_interval)

View File

@@ -0,0 +1,126 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import queue
import socket
import threading
import time
from http import HTTPStatus
import requests
import uvicorn
from fastapi import FastAPI
from fastapi.responses import Response
from vllm.logger import logger
def start_fastapi_server(
    port_queue: queue.Queue[int],
    local_seed_key,
    info,
):
    """Run the RFork seed HTTP server (blocking) on an OS-assigned port.

    Binds a socket to port 0 so the kernel picks a free port, reports that
    port to the caller via *port_queue*, then hands the pre-bound socket to
    uvicorn. Transfer-engine *info* is only revealed to callers presenting
    the matching *local_seed_key*.
    """
    logger.info("[RFork Seed] Preparing socket with dynamic port...")
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.bind(("0.0.0.0", 0))  # port 0 -> kernel assigns a free port
    _, port = sock.getsockname()
    logger.info("[RFork Seed] Assigned dynamic port: %s", port)
    app = FastAPI()

    @app.get("/get_rfork_transfer_engine_info")
    def get_rfork_transfer_engine_info(seed_key: str):
        # Only reveal session/weights info to a caller with the right key.
        if seed_key == local_seed_key:
            return {"rfork_transfer_engine_info": info}
        return {"rfork_transfer_engine_info": None}

    @app.get("/rfork_fetch_seed")
    def rfork_fetch_seed():
        # Simple liveness probe; no key required.
        return {"status": "ok"}

    @app.get("/health_check_with_key")
    def health_check_with_key(seed_key: str):
        # Health check that also validates the seed key.
        if seed_key == local_seed_key:
            return Response(status_code=HTTPStatus.OK)
        return Response(status_code=HTTPStatus.BAD_REQUEST)

    # host/port are None because uvicorn is given the pre-bound socket below.
    config = uvicorn.Config(app, host=None, port=None, log_level="warning")
    server = uvicorn.Server(config)
    try:
        port_queue.put(port)
    except Exception as e:
        # NOTE(review): on this failure path no sentinel (e.g. -1) is ever
        # enqueued, so the caller's queue.get() will simply time out —
        # confirm that is the intended failure signal.
        logger.error("[RFork Seed] Failed to send port via queue: %s", e)
        sock.close()
        return
    logger.info("[RFork Seed] FastAPI server starting on port %s...", port)
    server.run(sockets=[sock])  # blocks until the server shuts down
    sock.close()
def start_rfork_server(local_seed_key, rfork_transfer_engine_info, health_timeout_sec: float = 30.0) -> int:
    """Launch the seed FastAPI server on a background thread and wait until healthy.

    Spawns ``start_fastapi_server`` on a daemon thread, waits for it to
    report its dynamically assigned port, then polls the key-checked health
    endpoint until it answers 200 or *health_timeout_sec* elapses.

    Returns the bound port on success, -1 on failure.
    """
    port_queue: queue.Queue[int] = queue.Queue()
    server_thread = threading.Thread(
        target=start_fastapi_server,
        args=(port_queue, local_seed_key, rfork_transfer_engine_info),
        daemon=True,
    )
    server_thread.start()
    try:
        port = port_queue.get(timeout=15)
        if port == -1:
            raise RuntimeError("Child process failed to start server")
    except Exception as e:
        logger.error("[RFork Seed] start server error: %s", e)
        return -1

    health_url = f"http://127.0.0.1:{port}/health_check_with_key"
    deadline = time.time() + health_timeout_sec
    failed_attempts = 0
    last_error = None
    healthy = False
    while time.time() < deadline:
        time.sleep(0.01)
        try:
            response = requests.get(
                health_url,
                params={"seed_key": local_seed_key},
                timeout=10,
            )
        except Exception as e:
            last_error = str(e)
        else:
            if response.status_code == 200:
                healthy = True
                break
            last_error = f"unexpected status code {response.status_code} from health check"
        failed_attempts += 1

    if healthy:
        if failed_attempts > 1:
            logger.info(
                "[RFork Seed] health check passed after %d retries for port %s",
                failed_attempts - 1,
                port,
            )
        return port
    logger.error(
        "[RFork Seed] health check timed out after %.1fs for port %s, last error: %s",
        health_timeout_sec,
        port,
        last_error,
    )
    return -1

View File

@@ -0,0 +1,212 @@
#
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from typing import Any
import requests
import torch
from vllm.logger import logger
from vllm.utils.network_utils import get_ip, get_open_port, join_host_port
class RForkTransferBackend:
    """Wrapper around the YuanRong TransferEngine for RFork weight transfer.

    Handles engine initialization, (un)registration of the device memory
    blocks backing the model weights, and the bulk read of weights from a
    remote seed instance.
    """

    def __init__(self):
        self.rfork_transfer_engine: Any | None = None
        # "host:port" the engine was initialized with; doubles as session id.
        self.rfork_transfer_engine_session_id = None
        # name -> (data_ptr, numel, element_size) per registered parameter.
        self.rfork_transfer_engine_weights_info_dict = None
        # Coalesced (address, size) blocks passed to batch_register_memory.
        self.registered_weight_blocks = []
        self._is_initialized = False
        self.init_transfer_engine()

    def init_transfer_engine(self):
        """Create and initialize the TransferEngine on the current NPU.

        Raises:
            ImportError: if the yuanrong datasystem package is missing.
            RuntimeError: if engine initialization reports an error.
        """
        try:
            from yr.datasystem import TransferEngine  # type: ignore[import-not-found]
        except ImportError as e:
            raise ImportError("Please install @yuanrong-datasystem/transfer_engine first.") from e
        transfer_engine = TransferEngine()
        local_hostname = join_host_port(get_ip(), get_open_port())
        ret = transfer_engine.initialize(local_hostname, "ascend", f"npu:{torch.npu.current_device()}")
        if ret.is_error():
            # Fix: the original adjacent f-strings were missing a ", "
            # separator, rendering the arguments as "...ascendnpu:0".
            raise RuntimeError(
                "TransferEngine initialization failed: "
                f"initialize({local_hostname}, ascend, "
                f"npu:{int(torch.npu.current_device())}) -> {ret.to_string()}"
            )
        self.rfork_transfer_engine = transfer_engine
        self.rfork_transfer_engine_session_id = local_hostname
        self._is_initialized = True

    def is_initialized(self) -> bool:
        """True once init_transfer_engine() has completed successfully."""
        return self._is_initialized

    def _get_transfer_engine(self) -> Any:
        # Centralized guard so every public method fails loudly when the
        # engine is missing.
        if self.rfork_transfer_engine is None:
            raise RuntimeError("TransferEngine is not initialized.")
        return self.rfork_transfer_engine

    def register_memory_region(self, model):
        """Register the memory backing *model*'s parameters with the engine.

        Walks the NPU allocator snapshot, keeps active blocks whose base
        address matches a parameter's data pointer, and merges physically
        adjacent blocks so the engine registers fewer, larger regions.
        Returns True on success, False if registration fails.
        """
        transfer_engine = self._get_transfer_engine()
        start_reg_mr_tic = time.time()
        weight_mr_dict = {}
        weight_addr_set = set()
        for name, weight in model.named_parameters():
            weight_mr_dict[name] = (
                weight.data_ptr(),
                weight.numel(),
                weight.element_size(),
            )
            weight_addr_set.add(weight.data_ptr())
        memory_snapshot = torch.npu.memory.memory_snapshot()
        weight_blocks_for_reg_mr = []
        for segment in memory_snapshot:
            # Coalescing is per segment; blocks in different segments are
            # never merged.
            current_weight_block = None
            for block in segment.get("blocks", []):
                address = block.get("address", -1)
                size = block.get("size", -1)
                state = block.get("state", "")
                if address < 0 or size < 0 or state == "":
                    continue
                if state == "active_allocated" and address in weight_addr_set:
                    if current_weight_block is None:
                        current_weight_block = (address, size)
                    elif current_weight_block[0] + current_weight_block[1] == address:
                        # Physically contiguous with the running block: extend.
                        current_weight_block = (
                            current_weight_block[0],
                            current_weight_block[1] + size,
                        )
                    else:
                        weight_blocks_for_reg_mr.append(current_weight_block)
                        current_weight_block = (address, size)
            if current_weight_block is not None:
                weight_blocks_for_reg_mr.append(current_weight_block)
        addresses, sizes = zip(*weight_blocks_for_reg_mr) if weight_blocks_for_reg_mr else ((), ())
        ret = transfer_engine.batch_register_memory(addresses, sizes)
        if ret.is_error():
            logger.error(
                "batch_register_memory failed for %d blocks, ret: %s",
                len(weight_blocks_for_reg_mr),
                ret.to_string(),
            )
            return False
        self.rfork_transfer_engine_weights_info_dict = weight_mr_dict
        self.registered_weight_blocks = weight_blocks_for_reg_mr
        logger.info(
            "register_memory_region time: %.4fs",
            time.time() - start_reg_mr_tic,
        )
        return True

    def unregister_memory_region(self) -> bool:
        """Undo register_memory_region(); returns True on success."""
        transfer_engine = self._get_transfer_engine()
        start_unreg_mr_tic = time.time()
        ret = transfer_engine.batch_unregister_memory([address for address, _ in self.registered_weight_blocks])
        if ret.is_error():
            logger.error(
                "batch_unregister_memory failed for %d blocks, ret: %s",
                len(self.registered_weight_blocks),
                ret.to_string(),
            )
            return False
        self.rfork_transfer_engine_weights_info_dict = None
        self.registered_weight_blocks = []
        logger.info(
            "unregister_memory_region time: %.4fs",
            time.time() - start_unreg_mr_tic,
        )
        return True

    def recv_from_source(
        self,
        model,
        seed_instance_ip,
        seed_instance_service_port,
        local_seed_key,
    ):
        """Bulk-read *model*'s weights from a remote seed instance.

        Fetches the seed's session id and per-parameter weight info over
        HTTP, validates that every local parameter matches the remote
        layout, then performs one batched synchronous read. Returns True
        on success, False on any mismatch or transfer error.
        """
        transfer_engine = self._get_transfer_engine()
        seed_url = f"http://{seed_instance_ip}:{seed_instance_service_port}"
        seed_session_id, seed_weight_info = get_remote_instance_transfer_engine_info(seed_url, local_seed_key)
        if seed_session_id is None or seed_weight_info is None:
            logger.error("Cannot get transfer engine session or weight info.")
            return False
        seed_ptr_list = []
        client_ptr_list = []
        client_len_list = []
        for name, tensor in model.named_parameters():
            weight_info = seed_weight_info.get(name, None)
            if weight_info is None:
                logger.error("Cannot find weight info for %s.", name)
                return False
            seed_ptr, seed_len, seed_size = weight_info
            # Remote and local layouts must agree element-for-element.
            if seed_len != tensor.numel() or seed_size != tensor.element_size():
                logger.error(
                    "Weight info mismatch for %s, expected (%s, %s), got (%s, %s)",
                    name,
                    seed_len,
                    seed_size,
                    tensor.numel(),
                    tensor.element_size(),
                )
                return False
            seed_ptr_list.append(seed_ptr)
            client_ptr_list.append(tensor.data_ptr())
            client_len_list.append(tensor.numel() * tensor.element_size())
        start_transfer_tic = time.time()
        ret = transfer_engine.batch_transfer_sync_read(
            seed_session_id,
            client_ptr_list,
            seed_ptr_list,
            client_len_list,
        )
        if ret.is_error():
            logger.error("Failed to transfer weights from remote instance, ret=%s", ret.to_string())
            return False
        logger.info("transfer weights time: %.4fs", time.time() - start_transfer_tic)
        return True
def get_remote_instance_transfer_engine_info(seed_url: str, local_seed_key: str):
    """Fetch the (session_id, weights_info_dict) pair published by a seed server.

    Queries ``{seed_url}/get_rfork_transfer_engine_info`` with the local
    seed key and returns the two-element payload, or ``(None, None)`` on
    any failure (non-200 status, malformed payload, network error).
    """
    try:
        response = requests.get(
            f"{seed_url}/get_rfork_transfer_engine_info",
            params={"seed_key": local_seed_key},
            # Fix: bound the request so a dead seed cannot hang weight
            # loading forever; matches the 10 s timeout used by every other
            # request in the rfork code (REQUEST_TIMEOUT_SEC / health check).
            timeout=10.0,
        )
        if response.status_code != 200:
            logger.error("request.get failed: %s", response.status_code)
            return None, None
        data = response.json()
        info = data.get("rfork_transfer_engine_info", None)
        # The server returns a (session_id, weight_info) pair, serialized
        # as a two-element JSON list.
        if info is not None and isinstance(info, list) and len(info) == 2:
            return info[0], info[1]
        logger.error("Failed to get `rfork_transfer_engine_info` in response.")
        return None, None
    except Exception as e:
        # A timeout now lands here and yields (None, None) instead of hanging.
        logger.error("Exception: %s", e)
        return None, None