[Misc] Add a model loader that utilizes HCCL for weight loading (#2888)

### What this PR does / why we need it?

This PR introduces a new model loader, Netloader, which loads weights
over high-bandwidth P2P direct transfer between NPU cards. Netloader is
implemented as a plugin through the 'register_model_loader' function
added in vLLM 0.10. It loads weights by sending them from an
already-loaded model (the server) to the empty model of a newly started
instance (the client). The server runs alongside normal inference tasks,
using sub-threads and vLLM's
'stateless_init_torch_distributed_process_group'. The client issues a
transfer request after verifying that its model and partitioning method
match the server's, then receives the weights through HCCL send/recv in
the order they are stored in the model.
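
The ordering contract described above can be illustrated with a minimal, hardware-free sketch. It is not the Netloader implementation: a `queue.Queue` stands in for the HCCL send/recv channel, plain lists stand in for tensors, and the names `send_weights`, `recv_weights`, and `transfer` are hypothetical. It only shows why both sides must walk the parameters in the same order, since no names travel with the data.

```python
import queue
import threading


def send_weights(weights: dict, channel: queue.Queue) -> None:
    """Server side: send every tensor in the model's storage order."""
    for values in weights.values():
        channel.put(list(values))


def recv_weights(empty_model: dict, channel: queue.Queue) -> None:
    """Client side: fill pre-allocated buffers in the same order."""
    for name in empty_model:
        # Only position in the stream identifies a tensor, not its name.
        empty_model[name][:] = channel.get(timeout=5)


def transfer(server_weights: dict, client_model: dict) -> dict:
    """Run the sender in a sub-thread while the caller receives."""
    channel: queue.Queue = queue.Queue()
    sender = threading.Thread(target=send_weights,
                              args=(server_weights, channel))
    sender.start()
    recv_weights(client_model, channel)
    sender.join()
    return client_model


# Stand-in data: the client has the same layout but zeroed buffers.
server = {"embed.weight": [0.1, 0.2], "layer0.weight": [1.0, 2.0, 3.0]}
client = {name: [0.0] * len(values) for name, values in server.items()}
transfer(server, client)
```

After `transfer` returns, `client` holds the same values as `server`. If the two sides iterated in different orders, the data would land in the wrong buffers, which is why the client verifies the model and partitioning method first.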

Application Scenarios:
1. Significantly Reduces Inference Instance Startup Time: By reusing the
weights of already loaded instances and performing high-speed transfers
directly between computing cards, this method reduces model loading
latency compared to traditional remote/local pull methods.
2. Reduces Network and Storage Pressure: It avoids repeatedly
downloading weight files from remote repositories, reducing the load on
centralized storage and network traffic and thereby improving overall
system stability and service quality.
3. Improves Resource Utilization and Reduces Costs: Faster loading
reduces reliance on redundant computing pools, allowing computing
resources to be elastically scaled and reclaimed as needed.
4. Enhances Business Continuity and High Availability: In fault recovery
scenarios, new instances can quickly take over existing services,
avoiding prolonged business interruptions and improving the system's
high availability and user experience.

### Does this PR introduce _any_ user-facing change?

Netloader is activated through the existing --load-format=netloader and
--model-loader-extra-config options. As before,
--model-loader-extra-config must be passed as a JSON string.
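
As a rough sketch of how the two options might be assembled, the snippet below builds an argument list and serializes the extra config as a JSON string. The config keys (`source_address`, `listen_port`) and the model name are hypothetical placeholders, not Netloader's actual schema; only `--load-format=netloader` and `--model-loader-extra-config` come from this PR.

```python
import json
import shlex


def build_args(model: str, extra_config: dict) -> list:
    """Build a vLLM command line that selects the netloader format."""
    return [
        "vllm", "serve", model,
        "--load-format=netloader",
        # The extra config must be a single JSON string argument.
        "--model-loader-extra-config", json.dumps(extra_config),
    ]


# Hypothetical keys, for illustration only.
extra_config = {"source_address": "192.0.2.10", "listen_port": 29500}
args = build_args("my-model", extra_config)

# shlex.join quotes the JSON so it survives a shell as one argument.
command = shlex.join(args)
print(command)
```

Serializing with `json.dumps` and quoting with `shlex.join` avoids hand-escaping the JSON on the command line.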

To verify the loaded weights, check that the outputs for the same
sentence are consistent when the temperature is set to 0.
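
A consistency check of this kind could be sketched as follows. The `generate_*` callables are hypothetical stand-ins for whatever client queries each instance with the temperature set to 0; the stubs here exist only so the sketch runs on its own.

```python
from typing import Callable, List


def find_mismatches(generate_reference: Callable[[str], str],
                    generate_candidate: Callable[[str], str],
                    prompts: List[str]) -> List[str]:
    """Return the prompts whose greedy outputs differ between instances."""
    return [
        prompt for prompt in prompts
        if generate_reference(prompt) != generate_candidate(prompt)
    ]


# Hypothetical stubs standing in for requests to two running instances.
def reference_instance(prompt: str) -> str:
    return prompt.upper()


def netloader_instance(prompt: str) -> str:
    return prompt.upper()


prompts = ["hello", "what is an NPU?"]
mismatches = find_mismatches(reference_instance, netloader_instance, prompts)
```

An empty `mismatches` list means every prompt produced identical output on both instances.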

Signed-off-by: destinysky <kangrui10@126.com>

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Rui Kang
2025-10-23 15:56:07 +08:00
committed by GitHub
parent 807686dec9
commit 427b17e2da
19 changed files with 1999 additions and 1 deletions

@@ -0,0 +1,170 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch_npu
from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)
from vllm.logger import logger


class P2PLoad:
    """
    Class for receiving model parameters in a distributed manner using HCCL backend.
    """

    def __init__(
        self,
        world_name: str,
        source_ip: str,
        source_port: int,
    ):
        """
        Initializes the P2PLoad instance.

        Parameters:
        - world_name: The name of the distributed group.
        - source_ip: The IP address of the source node.
        - source_port: The port number for the source node.
        """
        self.world_name = world_name
        self.source_ip = source_ip
        self.source_port = source_port

    def load(self, model):
        """
        Loads the model parameters using HCCL backend.

        Parameters:
        - model: The model whose parameters are to be loaded.

        Returns:
        - The model if loading is successful, otherwise None.
        """
        model_device = next(model.parameters()).device
        logger.info(
            f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
        )
        receiver_pg = None
        loaded_model = None
        try:
            receiver_pg = stateless_init_torch_distributed_process_group(
                host=self.world_name.split(":")[0],
                port=self.source_port,
                rank=0,
                world_size=2,
                backend='hccl',
            )
            logger.info(
                f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
            )
            logger.info(
                f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
            )
            logger.info(f"Model device: {model_device}")

            trans_stream = torch_npu.npu.Stream()
            with torch_npu.npu.stream(trans_stream):
                for name, param in model.named_parameters():
                    if len(param.shape) == 0:
                        continue
                    receiver_pg.recv([param], 1, 0).wait()
                torch.distributed.barrier(group=receiver_pg,
                                          device_ids=[model_device.index])

            torch_npu.npu.synchronize(trans_stream)
            logger.info(
                f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
            )
            loaded_model = model
        except Exception as e:
            logger.error("Failed to recv model: {}".format(e))
        finally:
            if receiver_pg:
                stateless_destroy_torch_distributed_process_group(receiver_pg)
        return loaded_model


class P2PSend:
    """
    Class for sending model parameters in a distributed manner using HCCL backend.
    """

    def __init__(self, listen_ip: str, listen_port: int, comm_name: str):
        """
        Initializes the P2PSend instance.

        Parameters:
        - listen_ip: The IP address to listen on.
        - listen_port: The port number to listen on.
        - comm_name: The name of the communication group.
        """
        self.listen_ip = listen_ip
        self.listen_port = listen_port
        self.comm_name = comm_name

    def send(self, model, int8_params: dict):
        """
        Sends the model parameters using HCCL backend.

        Parameters:
        - model: The model whose parameters are to be sent.
        - int8_params: Dictionary of parameters that are in int8 format.
        """
        model_device = next(model.parameters()).device
        torch.npu.set_device(model_device)
        logger.info(
            f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
        )
        sender_pg = None
        try:
            sender_pg = stateless_init_torch_distributed_process_group(
                host=self.comm_name.split(":")[0],
                port=self.listen_port,
                rank=1,
                world_size=2,
                backend='hccl',
            )
            logger.info(
                f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
            )
            logger.info(
                f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
            )
            logger.info(f"Model device: {model_device}")

            trans_stream = torch_npu.npu.Stream()
            with torch_npu.npu.stream(trans_stream):
                for name, param in model.named_parameters():
                    if "aclnn_input_scale" in name:
                        continue
                    if name in int8_params:
                        sender_pg.send([int8_params[name].to(model_device)], 0,
                                       0).wait()
                    else:
                        sender_pg.send([param.contiguous()], 0, 0).wait()
                torch.distributed.barrier(group=sender_pg,
                                          device_ids=[model_device.index])

            torch_npu.npu.synchronize(trans_stream)
            logger.info(
                f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
            )
        finally:
            if sender_pg:
                stateless_destroy_torch_distributed_process_group(sender_pg)
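
One detail in the diff above is that the two loops skip parameters by different criteria: `P2PLoad.load` skips zero-dimensional parameters, while `P2PSend.send` skips names containing "aclnn_input_scale". Because each send must pair with exactly one recv, the transfer only lines up if both filters select the same parameters. The sketch below is a hypothetical pre-flight check, not part of this PR: it works on plain `(name, shape)` pairs, and the helper names are illustrative.

```python
from typing import List, Sequence, Tuple

ParamInfo = Tuple[str, Sequence[int]]


def receiver_names(params: List[ParamInfo]) -> List[str]:
    """Mirror P2PLoad.load: zero-dimensional parameters are skipped."""
    return [name for name, shape in params if len(shape) != 0]


def sender_names(params: List[ParamInfo]) -> List[str]:
    """Mirror P2PSend.send: 'aclnn_input_scale' parameters are skipped."""
    return [name for name, _ in params if "aclnn_input_scale" not in name]


def check_pairing(params: List[ParamInfo]) -> None:
    """Raise if the sender and receiver would transfer different tensors."""
    sent, received = sender_names(params), receiver_names(params)
    if sent != received:
        unmatched = sorted(set(sent) ^ set(received))
        raise ValueError(f"send/recv mismatch for parameters: {unmatched}")


# Stand-in metadata: the scale parameter is zero-dimensional here,
# so both filters skip it and the remaining names line up.
example_params = [
    ("layer0.weight", (4, 4)),
    ("layer0.aclnn_input_scale", ()),
    ("layer0.bias", (4,)),
]
check_pairing(example_params)
```

With real modules, the pairs could be gathered from `model.named_parameters()` on each side before any process group is created, so a mismatch surfaces as an error rather than a hung recv.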