### What this PR does / why we need it?
Support an new load format: RFORK
For implementation details of this feature, please refer to #7441
### Does this PR introduce _any_ user-facing change?
add an new options for load-format: rfork
e.g.
```bash
vllm serve /workspace/models/Qwen3-8B --load-format rfork
```
### How was this patch tested?
- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
Signed-off-by: Marck <1412354149@qq.com>
213 lines
8.1 KiB
Python
213 lines
8.1 KiB
Python
#
|
|
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import time
|
|
from typing import Any
|
|
|
|
import requests
|
|
import torch
|
|
from vllm.logger import logger
|
|
from vllm.utils.network_utils import get_ip, get_open_port, join_host_port
|
|
|
|
|
|
class RForkTransferBackend:
|
|
def __init__(self):
|
|
self.rfork_transfer_engine: Any | None = None
|
|
self.rfork_transfer_engine_session_id = None
|
|
self.rfork_transfer_engine_weights_info_dict = None
|
|
self.registered_weight_blocks = []
|
|
self._is_initialized = False
|
|
self.init_transfer_engine()
|
|
|
|
def init_transfer_engine(self):
|
|
try:
|
|
from yr.datasystem import TransferEngine # type: ignore[import-not-found]
|
|
except ImportError as e:
|
|
raise ImportError("Please install @yuanrong-datasystem/transfer_engine first.") from e
|
|
|
|
transfer_engine = TransferEngine()
|
|
local_hostname = join_host_port(get_ip(), get_open_port())
|
|
ret = transfer_engine.initialize(local_hostname, "ascend", f"npu:{torch.npu.current_device()}")
|
|
if ret.is_error():
|
|
raise RuntimeError(
|
|
"TransferEngine initialization failed: "
|
|
f"initialize({local_hostname}, ascend"
|
|
f"npu:{int(torch.npu.current_device())}) -> {ret.to_string()}"
|
|
)
|
|
|
|
self.rfork_transfer_engine = transfer_engine
|
|
self.rfork_transfer_engine_session_id = local_hostname
|
|
self._is_initialized = True
|
|
|
|
def is_initialized(self) -> bool:
|
|
return self._is_initialized
|
|
|
|
def _get_transfer_engine(self) -> Any:
|
|
if self.rfork_transfer_engine is None:
|
|
raise RuntimeError("TransferEngine is not initialized.")
|
|
return self.rfork_transfer_engine
|
|
|
|
def register_memory_region(self, model):
|
|
transfer_engine = self._get_transfer_engine()
|
|
start_reg_mr_tic = time.time()
|
|
|
|
weight_mr_dict = {}
|
|
weight_addr_set = set()
|
|
for name, weight in model.named_parameters():
|
|
weight_mr_dict[name] = (
|
|
weight.data_ptr(),
|
|
weight.numel(),
|
|
weight.element_size(),
|
|
)
|
|
weight_addr_set.add(weight.data_ptr())
|
|
|
|
memory_snapshot = torch.npu.memory.memory_snapshot()
|
|
weight_blocks_for_reg_mr = []
|
|
for segment in memory_snapshot:
|
|
current_weight_block = None
|
|
for block in segment.get("blocks", []):
|
|
address = block.get("address", -1)
|
|
size = block.get("size", -1)
|
|
state = block.get("state", "")
|
|
if address < 0 or size < 0 or state == "":
|
|
continue
|
|
if state == "active_allocated" and address in weight_addr_set:
|
|
if current_weight_block is None:
|
|
current_weight_block = (address, size)
|
|
elif current_weight_block[0] + current_weight_block[1] == address:
|
|
current_weight_block = (
|
|
current_weight_block[0],
|
|
current_weight_block[1] + size,
|
|
)
|
|
else:
|
|
weight_blocks_for_reg_mr.append(current_weight_block)
|
|
current_weight_block = (address, size)
|
|
if current_weight_block is not None:
|
|
weight_blocks_for_reg_mr.append(current_weight_block)
|
|
|
|
addresses, sizes = zip(*weight_blocks_for_reg_mr) if weight_blocks_for_reg_mr else ((), ())
|
|
ret = transfer_engine.batch_register_memory(addresses, sizes)
|
|
if ret.is_error():
|
|
logger.error(
|
|
"batch_register_memory failed for %d blocks, ret: %s",
|
|
len(weight_blocks_for_reg_mr),
|
|
ret.to_string(),
|
|
)
|
|
return False
|
|
|
|
self.rfork_transfer_engine_weights_info_dict = weight_mr_dict
|
|
self.registered_weight_blocks = weight_blocks_for_reg_mr
|
|
|
|
logger.info(
|
|
"register_memory_region time: %.4fs",
|
|
time.time() - start_reg_mr_tic,
|
|
)
|
|
return True
|
|
|
|
def unregister_memory_region(self) -> bool:
|
|
transfer_engine = self._get_transfer_engine()
|
|
start_unreg_mr_tic = time.time()
|
|
ret = transfer_engine.batch_unregister_memory([address for address, _ in self.registered_weight_blocks])
|
|
if ret.is_error():
|
|
logger.error(
|
|
"batch_unregister_memory failed for %d blocks, ret: %s",
|
|
len(self.registered_weight_blocks),
|
|
ret.to_string(),
|
|
)
|
|
return False
|
|
self.rfork_transfer_engine_weights_info_dict = None
|
|
self.registered_weight_blocks = []
|
|
logger.info(
|
|
"unregister_memory_region time: %.4fs",
|
|
time.time() - start_unreg_mr_tic,
|
|
)
|
|
return True
|
|
|
|
def recv_from_source(
|
|
self,
|
|
model,
|
|
seed_instance_ip,
|
|
seed_instance_service_port,
|
|
local_seed_key,
|
|
):
|
|
transfer_engine = self._get_transfer_engine()
|
|
seed_url = f"http://{seed_instance_ip}:{seed_instance_service_port}"
|
|
seed_session_id, seed_weight_info = get_remote_instance_transfer_engine_info(seed_url, local_seed_key)
|
|
if seed_session_id is None or seed_weight_info is None:
|
|
logger.error("Cannot get transfer engine session or weight info.")
|
|
return False
|
|
|
|
seed_ptr_list = []
|
|
client_ptr_list = []
|
|
client_len_list = []
|
|
for name, tensor in model.named_parameters():
|
|
weight_info = seed_weight_info.get(name, None)
|
|
if weight_info is None:
|
|
logger.error("Cannot find weight info for %s.", name)
|
|
return False
|
|
|
|
seed_ptr, seed_len, seed_size = weight_info
|
|
if seed_len != tensor.numel() or seed_size != tensor.element_size():
|
|
logger.error(
|
|
"Weight info mismatch for %s, expected (%s, %s), got (%s, %s)",
|
|
name,
|
|
seed_len,
|
|
seed_size,
|
|
tensor.numel(),
|
|
tensor.element_size(),
|
|
)
|
|
return False
|
|
|
|
seed_ptr_list.append(seed_ptr)
|
|
client_ptr_list.append(tensor.data_ptr())
|
|
client_len_list.append(tensor.numel() * tensor.element_size())
|
|
|
|
start_transfer_tic = time.time()
|
|
ret = transfer_engine.batch_transfer_sync_read(
|
|
seed_session_id,
|
|
client_ptr_list,
|
|
seed_ptr_list,
|
|
client_len_list,
|
|
)
|
|
if ret.is_error():
|
|
logger.error("Failed to transfer weights from remote instance, ret=%s", ret.to_string())
|
|
return False
|
|
|
|
logger.info("transfer weights time: %.4fs", time.time() - start_transfer_tic)
|
|
return True
|
|
|
|
|
|
def get_remote_instance_transfer_engine_info(seed_url: str, local_seed_key: str):
|
|
try:
|
|
response = requests.get(
|
|
f"{seed_url}/get_rfork_transfer_engine_info",
|
|
params={"seed_key": local_seed_key},
|
|
)
|
|
if response.status_code != 200:
|
|
logger.error("request.get failed: %s", response.status_code)
|
|
return None, None
|
|
|
|
data = response.json()
|
|
info = data.get("rfork_transfer_engine_info", None)
|
|
if info is not None and isinstance(info, list) and len(info) == 2:
|
|
return info[0], info[1]
|
|
|
|
logger.error("Failed to get `rfork_transfer_engine_info` in response.")
|
|
return None, None
|
|
except Exception as e:
|
|
logger.error("Exception: %s", e)
|
|
return None, None
|