[Misc] Add a model loader that utilizes HCCL for weight loading (#2888)
### What this PR does / why we need it?

This PR introduces a new model loader, Netloader, which leverages high-bandwidth P2P direct transfer between NPU cards to load weights. Netloader is implemented as a plugin through the `register_model_loader` function newly added in vLLM 0.10. It loads weights by sending them from an already-loaded model (the server) to the empty model of a newly started instance (the client). The server runs concurrently with normal inference tasks via a sub-thread and vLLM's `stateless_init_torch_distributed_process_group`. The client initiates a transfer request after verifying that its model and partitioning method match the server's, then uses HCCL send/recv (point-to-point) communication to load the weights in the order they are stored in the model.

Application scenarios:

1. **Significantly reduces inference instance startup time.** By reusing the weights of already-loaded instances and transferring them at high speed directly between compute cards, this method cuts model loading latency compared with traditional remote/local pull methods.
2. **Reduces network and storage pressure.** Avoids repeatedly downloading weight files from remote repositories, reducing the load on centralized storage and on network traffic, thereby improving overall system stability and service quality.
3. **Improves resource utilization and reduces costs.** Faster loading reduces reliance on redundant compute pools, allowing compute resources to be elastically scaled and reclaimed as needed.
4. **Enhances business continuity and high availability.** In fault-recovery scenarios, new instances can quickly take over existing services, avoiding prolonged business interruptions and improving the system's availability and user experience.

### Does this PR introduce _any_ user-facing change?

Netloader is activated through the existing `--load-format netloader` and `--model-loader-extra-config` options. The `--model-loader-extra-config` value must be passed as a JSON string (as it is today). Afterwards, you can check that the outputs for the same prompt are consistent when the temperature is set to 0.

Signed-off-by: destinysky <kangrui10@126.com>

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: destinysky <kangrui10@126.com>
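As a rough illustration (not taken from this PR: the model name, IP address, and port below are placeholders), the loader can also be enabled through the vLLM Python API, where `model_loader_extra_config` may be passed as a dict rather than a JSON string:

```python
# Minimal sketch of enabling Netloader via the vLLM Python API.
# Placeholders: model name, IP address, and port. The SOURCE list maps each
# global rank (device_id) to candidate server addresses holding the weights.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    load_format="netloader",
    model_loader_extra_config={
        "SOURCE": [{"device_id": 0, "sources": ["192.0.2.10:6000"]}],
        "INT8_CACHE": "no",                  # one of: hbm, dram, no
        "OUTPUT_PREFIX": "/tmp/netloader_",  # where this instance advertises itself
    },
)
```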
vllm_ascend/model_loader/__init__.py (new file, 0 lines)

vllm_ascend/model_loader/netloader/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


def register_netloader():
    """Register the NetLoader plugin."""
    from .netloader import ModelNetLoaderElastic  # noqa

vllm_ascend/model_loader/netloader/executor/elastic_load.py (new file, 170 lines)
@@ -0,0 +1,170 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import torch_npu
from vllm.distributed.utils import (
    stateless_destroy_torch_distributed_process_group,
    stateless_init_torch_distributed_process_group)
from vllm.logger import logger


class P2PLoad:
    """
    Class for receiving model parameters in a distributed manner using the HCCL backend.
    """

    def __init__(
        self,
        world_name: str,
        source_ip: str,
        source_port: int,
    ):
        """
        Initializes the P2PLoad instance.

        Parameters:
        - world_name: The name of the distributed group.
        - source_ip: The IP address of the source node.
        - source_port: The port number for the source node.
        """
        self.world_name = world_name
        self.source_ip = source_ip
        self.source_port = source_port

    def load(self, model):
        """
        Loads the model parameters using the HCCL backend.

        Parameters:
        - model: The model whose parameters are to be loaded.

        Returns:
        - The model if loading is successful, otherwise None.
        """
        model_device = next(model.parameters()).device
        logger.info(
            f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
        )
        receiver_pg = None
        loaded_model = None
        try:
            receiver_pg = stateless_init_torch_distributed_process_group(
                host=self.world_name.split(":")[0],
                port=self.source_port,
                rank=0,
                world_size=2,
                backend='hccl',
            )
            logger.info(
                f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
            )

            logger.info(
                f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
            )
            logger.info(f"Model device: {model_device}")

            trans_stream = torch_npu.npu.Stream()
            with torch_npu.npu.stream(trans_stream):
                for name, param in model.named_parameters():
                    if len(param.shape) == 0:
                        continue
                    receiver_pg.recv([param], 1, 0).wait()
                torch.distributed.barrier(group=receiver_pg,
                                          device_ids=[model_device.index])

            torch_npu.npu.synchronize(trans_stream)

            logger.info(
                f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
            )
            loaded_model = model
        except Exception as e:
            logger.error("Failed to recv model: {}".format(e))
        finally:
            if receiver_pg:
                stateless_destroy_torch_distributed_process_group(receiver_pg)
        return loaded_model


class P2PSend:
    """
    Class for sending model parameters in a distributed manner using the HCCL backend.
    """

    def __init__(self, listen_ip: str, listen_port: int, comm_name: str):
        """
        Initializes the P2PSend instance.

        Parameters:
        - listen_ip: The IP address to listen on.
        - listen_port: The port number to listen on.
        - comm_name: The name of the communication group.
        """
        self.listen_ip = listen_ip
        self.listen_port = listen_port
        self.comm_name = comm_name

    def send(self, model, int8_params: dict):
        """
        Sends the model parameters using the HCCL backend.

        Parameters:
        - model: The model whose parameters are to be sent.
        - int8_params: Dictionary of parameters that are in int8 format.
        """
        model_device = next(model.parameters()).device
        torch.npu.set_device(model_device)
        logger.info(
            f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
        )
        sender_pg = None
        try:
            sender_pg = stateless_init_torch_distributed_process_group(
                host=self.comm_name.split(":")[0],
                port=self.listen_port,
                rank=1,
                world_size=2,
                backend='hccl',
            )
            logger.info(
                f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
            )
            logger.info(
                f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
            )
            logger.info(f"Model device: {model_device}")

            trans_stream = torch_npu.npu.Stream()
            with torch_npu.npu.stream(trans_stream):
                for name, param in model.named_parameters():
                    if "aclnn_input_scale" in name:
                        continue
                    if name in int8_params:
                        sender_pg.send([int8_params[name].to(model_device)], 0,
                                       0).wait()
                    else:
                        sender_pg.send([param.contiguous()], 0, 0).wait()
                torch.distributed.barrier(group=sender_pg,
                                          device_ids=[model_device.index])
            torch_npu.npu.synchronize(trans_stream)
            logger.info(
                f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
            )
        finally:
            if sender_pg:
                stateless_destroy_torch_distributed_process_group(sender_pg)
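For orientation, the two endpoints above pair up as a temporary two-rank HCCL group: the receiver joins as rank 0, the sender as rank 1, and both walk `model.named_parameters()` in the same order, so each `send` matches exactly one `recv`. A toy sketch follows (not part of the PR; addresses, ports, and the group name are placeholders, and the blocking calls are commented out because they require two NPU processes):

```python
# Toy pairing sketch for P2PLoad/P2PSend with placeholder addresses.
from vllm_ascend.model_loader.netloader.executor.elastic_load import (P2PLoad,
                                                                      P2PSend)

comm_name = "192.0.2.20:45000"  # "client_ip:client_port", assigned by the server

# Client side (newly started instance with an empty model), rank 0:
receiver = P2PLoad(world_name=comm_name, source_ip="192.0.2.10",
                   source_port=29500)
# model = receiver.load(empty_model)  # blocks until every parameter arrives

# Server side (instance that already holds the weights), rank 1:
sender = P2PSend(listen_ip="192.0.2.10", listen_port=29500,
                 comm_name=comm_name)
# sender.send(loaded_model, int8_params={})  # streams parameters in order
```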

vllm_ascend/model_loader/netloader/interaction/elastic.py (new file, 408 lines)
@@ -0,0 +1,408 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json
import re
import socket
import threading
from typing import List, Optional, Tuple

import torch
from vllm.logger import logger

from ..executor.elastic_load import P2PSend
from ..utils import find_free_port


class ElasticClient:
    """
    Client-side handler for Netloader model weight transfer.
    """

    def __init__(self, sources: list[str], device_id: int, model_path: str,
                 tp: int, pp: int):
        """
        Initializes the ElasticClient instance.

        Parameters:
        - sources: List of source addresses in the format IP:port.
        - device_id: The ID of the current device.
        - model_path: The path to the model.
        - tp: Tensor parallel size.
        - pp: Pipeline parallel size.
        """
        self.sources = sources
        self.device_id = device_id
        self.model_path = model_path
        self.tp = tp
        self.pp = pp

        self.s: Optional[socket.socket] = None
        self.ack: Optional[Tuple[str, int]] = None
        self.server_addr: Optional[str] = None
        self.server_port: Optional[int] = None

        for source in self.sources:
            try:
                ip, port_str = source.split(':')
                port = int(port_str)
            except Exception as e:
                logger.error(f"IP format error: {source}, detail: {e}")
                continue

            self.server_addr = ip
            self.server_port = port

            sock = None
            try:
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                logger.info(
                    f"Start connection to server: {self.server_addr}:{self.server_port}"
                )
                sock.connect((self.server_addr, self.server_port))
                logger.info(
                    f"Finish connection to server: {self.server_addr}:{self.server_port}"
                )
                sock.settimeout(60)

                self.s = sock
                self.ack = self.register(device_id, model_path, tp, pp)
                break
            except Exception as e:
                logger.error(f"Connection to {source} failed, detail: {e}")
                if sock is not None:
                    try:
                        sock.close()
                    except Exception:
                        pass
                self.s = None
                self.ack = None
                self.server_addr = None
                self.server_port = None

    def close(self) -> None:
        """
        Closes the socket connection.
        """
        if self.s is not None:
            try:
                self.s.close()
            except Exception as e:
                logger.error(f"Error closing socket: {e}")
            finally:
                self.s = None

    def __enter__(self) -> "ElasticClient":
        """
        Context manager enter method.
        """
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Context manager exit method.
        """
        self.close()

    def __del__(self):
        """
        Destructor method to ensure the socket is closed.
        """
        try:
            self.close()
        except Exception:
            pass

    def send_str(self, data_str: str) -> None:
        """
        Sends a string over the socket connection.

        Parameters:
        - data_str: The string to be sent.
        """
        if self.s is None:
            raise RuntimeError("Socket was not created correctly.")
        self.s.send(data_str.encode("utf-8"))

    def recv_str(self, buffer_size: int = 1024) -> str:
        """
        Receives a string over the socket connection.

        Parameters:
        - buffer_size: The size of the buffer for receiving data.

        Returns:
        - The received string.
        """
        if self.s is None:
            raise RuntimeError("Socket was not created correctly.")
        data_str = self.s.recv(buffer_size).decode("utf-8")
        return data_str

    def register(self, device_id: int, model_path: str, tp: int,
                 pp: int) -> Tuple[str, int]:
        """
        Registers the client with the server.

        Parameters:
        - device_id: The ID of the current device.
        - model_path: The path to the model.
        - tp: Tensor parallel size.
        - pp: Pipeline parallel size.

        Returns:
        - A tuple containing the communication name and port.
        """
        free_port = find_free_port()
        data = {
            "label": "JOIN",
            "content": {
                'device_id': device_id,
                'model_path': model_path,
                'tp': tp,
                'pp': pp,
                'port': free_port
            }
        }

        try:
            self.send_str(json.dumps(data))
        except Exception as e:
            raise RuntimeError(
                f"Sending data {data} to the server failed, detail: {e}")

        try:
            ack_str = self.recv_str()
        except Exception as e:
            raise RuntimeError(
                f"Receiving data from the server failed, detail: {e}")

        try:
            ack = json.loads(ack_str)
        except Exception as e:
            raise RuntimeError(
                f"Received data {ack_str} cannot be converted to JSON format, detail: {e}"
            )

        logger.info(f"Receive ack: {ack}")

        if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack
                and ack["content"] is not None and "name" in ack["content"]):
            return (ack["content"]["name"], free_port)
        elif ("label" in ack and ack["label"] == 'JOIN_NACK'
              and "content" in ack):
            raise RuntimeError(
                f"Received nack from server, reason: {ack['content']}")
        else:
            raise RuntimeError(
                f"Received ack {ack} from server does not contain the required fields"
            )


class ElasticServer:
    """
    Server-side handler for Netloader model weight transfer.
    """

    def __init__(self, addr: str, port: int, model, device_id: int,
                 model_path: str, tp: int, pp: int, int8_cache: str,
                 int8_cache_name: Optional[List[str]]):
        """
        Initializes the ElasticServer instance.

        Parameters:
        - addr: The IP address to listen on.
        - port: The port number to listen on.
        - model: The model to be served.
        - device_id: The ID of the current device (i.e. global rank).
        - model_path: The path to the model.
        - tp: Tensor parallel size.
        - pp: Pipeline parallel size.
        - int8_cache: The type of caching for int8 parameters (hbm, dram, or no).
        - int8_cache_name: List of parameter names to be cached.
        """
        self.addr = addr
        self.port = port
        self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.s.bind((self.addr, self.port))
        self.s.listen(256)

        self.model = model
        self.device_id = device_id
        self.model_path = model_path
        self.tp = tp
        self.pp = pp

        self.original_int8 = {}
        int8_pattern = "|".join(
            map(re.escape,
                int8_cache_name)) if int8_cache_name is not None else "(?:)"
        for name, param in self.model.named_parameters():
            if param.dtype == torch.int8:
                if int8_cache == 'hbm':
                    if int8_cache_name is None or re.search(
                            int8_pattern, name) is not None:
                        try:
                            self.original_int8[name] = param.data.clone(
                            ).detach()
                        except RuntimeError as e:
                            logger.error(
                                f"Failed to cache int8 tensor {name} to HBM, falling back to DRAM, due to {e}"
                            )
                            self.original_int8[name] = param.data.cpu()

                elif int8_cache == 'dram':
                    if int8_cache_name is None or re.search(
                            int8_pattern, name) is not None:
                        self.original_int8[name] = param.data.cpu()
                elif int8_cache == 'no':
                    pass
                else:
                    logger.warning(
                        f"int8_cache should be one of [hbm, dram, no], but got {int8_cache}; using no cache"
                    )

        logger.info(
            f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}"
        )

    def __del__(self):
        """
        Destructor method to ensure the socket is closed.
        """
        self.s.close()

    def start(self):
        """
        Starts the server to handle incoming connections.
        """
        handler_thread = threading.Thread(target=self.elastic_client_handler)
        handler_thread.daemon = True
        handler_thread.start()

    def elastic_client_handler(self):
        """
        Handles incoming client connections.
        """
        while True:
            conn, addr = self.s.accept()
            logger.info("Accept new connection from {}:{}...".format(*addr))
            self.register_handler(conn, addr)

    def register_handler(self, conn, addr, buffer_size=1024):
        """
        Handles the registration of a client.

        Parameters:
        - conn: The connection socket.
        - addr: The address of the client.
        - buffer_size: The size of the buffer for receiving data.
        """
        data_str = conn.recv(buffer_size).decode("utf-8")
        if not data_str:
            return
        try:
            data = json.loads(data_str)
        except Exception:
            logger.error(f"Failed to load {data_str} as a JSON string")
            conn.close()
            return

        def is_valid_data(data):
            """
            Validates the received data.

            Parameters:
            - data: The data to be validated.

            Returns:
            - True if the data is valid, otherwise False.
            """
            if not isinstance(data, dict):
                return False
            if data.get("label") != "JOIN":
                return False
            content = data.get("content")
            if not isinstance(content, dict):
                return False
            required_keys = ["device_id", "model_path", "tp", "pp", "port"]
            if not all(k in content for k in required_keys):
                return False
            port = content["port"]
            if not (isinstance(port, int) or
                    (isinstance(port, str) and port.isdigit())):
                return False
            return True

        comm_name = None
        if is_valid_data(data):
            device_id = int(data["content"]["device_id"])
            model_path = data["content"]["model_path"]
            tp = int(data["content"]["tp"])
            pp = int(data["content"]["pp"])

            if (int(self.device_id) == device_id
                    and self.model_path == model_path and int(self.tp) == tp
                    and int(self.pp) == pp):
                comm_name = str(addr[0]) + ":" + str(addr[1])
                ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
            else:
                logger.warning(
                    f"Received data {(device_id, model_path, tp, pp)} is not consistent with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
                )
                ack = {
                    "label": "JOIN_NACK",
                    "content":
                    f"Received data {(device_id, model_path, tp, pp)} is not consistent with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
                }
        else:
            logger.warning(
                f"Received data does not contain the required fields: {data}")
            ack = {
                "label": "JOIN_NACK",
                "content":
                f"Received data does not contain the required fields: {data}"
            }

        try:
            ack_str = json.dumps(ack).encode("utf-8")
        except Exception as e:
            logger.error(
                f"Failed to convert {ack} to JSON format, details: {e}")
            conn.close()
            return

        try:
            conn.send(ack_str)
        except Exception as e:
            logger.error(f"Failed to send {ack} to {addr}, details: {e}")
            conn.close()
            return

        if ack["content"] and isinstance(ack["content"],
                                         dict) and 'name' in ack["content"]:
            try:
                p2psend = P2PSend(self.addr, data["content"]["port"],
                                  ack["content"]["name"])
                p2psend.send(self.model, self.original_int8)
            except Exception as e:
                logger.error(
                    f"P2PSend failed to send model to {self.addr}, details: {e}"
                )
        conn.close()
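For reference, the JOIN handshake implemented above is a small JSON-over-TCP exchange. A round trip looks roughly like this (all values are placeholders):

```python
# Illustrative JOIN handshake messages (placeholder values).
# Client -> Server: 'port' is a free port the client reserved for the
# subsequent HCCL transfer.
join = {
    "label": "JOIN",
    "content": {
        "device_id": 0,
        "model_path": "/models/placeholder",
        "tp": 1,
        "pp": 1,
        "port": 45321,
    },
}
# Server -> Client when (device_id, model_path, tp, pp) all match:
join_ack = {"label": "JOIN_ACK", "content": {"name": "192.0.2.20:53124"}}
# Server -> Client on a mismatch or malformed request:
join_nack = {"label": "JOIN_NACK", "content": "human-readable reason"}
```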

vllm_ascend/model_loader/netloader/load.py (new file, 84 lines)
@@ -0,0 +1,84 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time

from vllm.logger import logger

from .executor.elastic_load import P2PLoad
from .interaction.elastic import ElasticClient


def elastic_load(
    model,
    device_id: int,
    model_path: str,
    sources: list,
    tp: int,
    pp: int,
):
    """
    Loads a model using elastic loading across multiple devices.

    Parameters:
    - model: The model instance to be loaded.
    - device_id: The ID of the current device (i.e. global rank).
    - model_path: The path to the model file.
    - sources: A list of source configurations, each containing device_id and sources.
    - tp: Tensor parallel size, indicating the number of devices for tensor parallelism.
    - pp: Pipeline parallel size, indicating the number of devices for pipeline parallelism.

    Returns:
    - The loaded model if successful, otherwise None.
    """

    # Filter sources for the current device
    sources_this_device = []
    for s in sources:
        if (isinstance(s, dict) and "device_id" in s
                and s["device_id"] == device_id
                and isinstance(s.get("sources"), list)):
            sources_this_device += s["sources"]
    if len(sources_this_device) == 0:
        return None

    try:
        # Initialize the interaction layer with the ElasticClient
        with ElasticClient(sources_this_device, device_id, model_path, tp,
                           pp) as client_interaction_layer:
            if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
                raise RuntimeError(
                    "Failed to initialize ElasticClient: socket or server_addr is None"
                )
            ack = client_interaction_layer.ack
            if ack is None:
                raise RuntimeError("ElasticClient.register did not return ack")

            t0 = time.perf_counter()
            elastic_loader = P2PLoad(ack[0],
                                     client_interaction_layer.server_addr,
                                     ack[1])
            model_loaded = elastic_loader.load(model=model)
            if model_loaded is None:
                logger.error("Failed to load model")
                return None
            logger.info("Finish elastic load (duration: {}s)".format(
                time.perf_counter() - t0))
            return model_loaded
    except Exception as e:
        logger.error(f"elastic_load error: {e}")
        return None

vllm_ascend/model_loader/netloader/netloader.py (new file, 324 lines)
@@ -0,0 +1,324 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import gc
import json
import time
from copy import deepcopy
from typing import List, Optional, Tuple

import torch
from torch import nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import logger
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
    initialize_model, process_weights_after_loading, set_default_torch_dtype)

from .interaction.elastic import ElasticServer
from .load import elastic_load
from .utils import find_free_port, is_valid_path_prefix


@register_model_loader("netloader")
class ModelNetLoaderElastic(BaseModelLoader):
    """
    A model loader that uses elastic loading for loading weights.
    """
    source: Optional[List[dict]]
    model_path: Optional[str]
    listen_port: Optional[int]
    int8_cache: str
    int8_cache_name: Optional[List[str]]
    output_prefix: Optional[str]

    def __init__(self, load_config: LoadConfig):
        """
        Initializes the ModelNetLoaderElastic with configuration.

        Parameters:
        - load_config: Configuration for loading the model.
        """
        super().__init__(load_config)

        config = None

        # Try to read the config file first
        extra = load_config.model_loader_extra_config
        if extra and "CONFIG_FILE" in extra:
            try:
                logger.info(
                    f"Reading configs from file {extra['CONFIG_FILE']} ...")
                with open(extra["CONFIG_FILE"], 'r') as f:
                    config = json.load(f)
            except FileNotFoundError:
                logger.error("CONFIG_FILE not found")
            except json.JSONDecodeError:
                logger.error("CONFIG_FILE is not a valid JSON file")
            except Exception as e:
                logger.error(
                    f"Unexpected error while reading CONFIG_FILE: {e}")

        if config is None and extra:
            logger.info("Reading configs from model_loader_extra_config ...")
            config = extra
        config = config or {}

        for key, attr, checker, caster, default in [
            ("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v,
             None),
            ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v,
             None),
            ("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or
             (isinstance(v, str) and v.isdigit()), lambda v: int(v), None),
            ("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and
             v.lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'),
            ("INT8_CACHE_NAME", "int8_cache_name",
             lambda v: isinstance(v, list), lambda v: v, None),
            ("OUTPUT_PREFIX", "output_prefix",
             lambda v: isinstance(v, str) and is_valid_path_prefix(v),
             lambda v: v, None),
        ]:
            v = config.get(key, default)
            if not checker(v):
                v = default
            else:
                v = caster(v)
            setattr(self, attr, v)

        logger.info(
            "Initializing elastic Netloader with config: "
            "MODEL=%s, LISTEN_PORT=%s, "
            "SOURCE=%s, INT8_CACHE=%s, INT8_CACHE_NAME=%s, "
            "OUTPUT_PREFIX=%s",
            self.model_path,
            self.listen_port,
            self.source,
            self.int8_cache,
            self.int8_cache_name,
            self.output_prefix,
        )

    def load_model(self, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        """
        Loads the model using the specified configuration.

        Parameters:
        - vllm_config: Configuration for vLLM.
        - model_config: Configuration for the model.

        Returns:
        - The loaded model.
        """

        device_config = vllm_config.device_config
        parallel_config = vllm_config.parallel_config

        need_process_weights_after_loading = False

        if self.model_path is None:
            self.model_path = model_config.model
            logger.info(f"model_path is set to {self.model_path}")

        device_id = torch.distributed.get_rank()

        if (self.source is None or not isinstance(self.source, list)
                or device_id not in [
                    one_device["device_id"] for one_device in self.source if
                    isinstance(one_device, dict) and "device_id" in one_device
                ]):
            logger.warning(
                "Did not get valid source info, using DefaultModelLoader")
            model, need_process_weights_after_loading = self.revert_to_default(
                model_config, vllm_config, device_config)

        else:
            target_device = torch.device(device_config.device)

            vllm_config_backup = deepcopy(vllm_config)
            model_config_backup = deepcopy(model_config)

            with set_default_torch_dtype(model_config.dtype):
                with target_device:
                    model = initialize_model(vllm_config=vllm_config,
                                             model_config=model_config)

            start_elastic_load = time.perf_counter()
            model = elastic_load(
                model=model,
                device_id=device_id,
                model_path=self.model_path,
                sources=self.source,
                tp=parallel_config.tensor_parallel_size,
                pp=parallel_config.pipeline_parallel_size,
            )
            end_elastic_load = time.perf_counter()
            logger.info(
                f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}"
            )
            need_process_weights_after_loading = True

            if model is None:
                logger.warning(
                    "Netloader elastic loading failed, falling back to DefaultModelLoader"
                )

                vllm_config = vllm_config_backup
                model_config = model_config_backup

                del model
                gc.collect()
                if device_config.device_type == 'npu':
                    logger.info("Emptying NPU cache")
                    torch.npu.empty_cache()
                elif device_config.device_type == 'cuda':
                    logger.info("Emptying CUDA cache")
                    torch.cuda.empty_cache()

                model, need_process_weights_after_loading = self.revert_to_default(
                    model_config, vllm_config, device_config)

        start_elastic_server = time.perf_counter()
        # Start the elastic server
        if model is not None and (
            (self.listen_port and self.listen_port in range(1024, 65535)) or
            (self.listen_port is None)):
            from vllm.utils import get_ip
            driver_ip = get_ip()

            if driver_ip == '0.0.0.0':
                logger.error(
                    "Driver IP is not set, skipping Netloader server start")
            else:
                if self.listen_port is None:
                    self.listen_port = find_free_port()
                else:
                    self.listen_port += device_id

                logger.info(
                    f"Start elastic Netloader server, rank: {device_id}, listen port: {driver_ip}:{self.listen_port}"
                )

                if self.output_prefix is not None:
                    try:
                        with open(self.output_prefix + str(device_id) + '.txt',
                                  'w') as file:
                            file.write(f"{driver_ip}:{self.listen_port}")
                        logger.info(
                            f"Successfully wrote the server address to file: {self.output_prefix + str(device_id)}"
                        )
                    except FileNotFoundError:
                        logger.error(
                            f"File path {self.output_prefix + str(device_id)} does not exist."
                        )
                    except PermissionError:
                        logger.error(
                            f"No permission to write to file {self.output_prefix + str(device_id)}."
                        )
                    except IOError as e:
                        logger.error(
                            f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}"
                        )
                    except Exception as e:
                        logger.error(f"Unknown error: {e}")

                try:
                    assert isinstance(
                        self.listen_port, int
                    ), f"listen port should be int but got {self.listen_port}"

                    elastic_server = ElasticServer(
                        driver_ip, self.listen_port, model, device_id,
                        self.model_path, parallel_config.tensor_parallel_size,
                        parallel_config.pipeline_parallel_size,
                        self.int8_cache, self.int8_cache_name)
                    elastic_server.start()
                except Exception as e:
                    logger.error(
                        f"Failed to start Netloader server for rank: {device_id}, details: {e}"
                    )
        else:
            logger.info("Skipping Netloader server start")

        end_elastic_server = time.perf_counter()
        logger.info(
            f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}"
        )

        if need_process_weights_after_loading:
            process_weights_after_loading(model, model_config,
                                          torch.device(device_config.device))

        if model is None:
            logger.error("NetLoader elastic model loading failed")
            return None

        return model.eval()

    def revert_to_default(self, model_config, vllm_config,
                          device_config) -> Tuple[nn.Module, bool]:
        """
        Reverts to the default model loading logic when elastic loading fails or is not applicable.

        This method resets the loader's extra config and load format to defaults,
        then delegates model loading to a DefaultModelLoader.
        If quantization is enabled, the model is initialized and its weights are
        loaded here, while the returned flag signals that weight post-processing
        (i.e. applying quantization adjustments) still needs to be run by the caller.

        Parameters:
        - model_config: Configuration describing model architecture, quantization, etc.
        - vllm_config: Configuration for vLLM (device, parallelism, dtype, etc).
        - device_config: Configuration for the target device (device type, device id, etc).

        Returns:
        - A tuple (model, need_process_weights_after_loading):
          * model: The loaded `nn.Module` under default loading logic.
          * need_process_weights_after_loading: A boolean flag indicating whether
            weights post-processing (e.g. quantization adjustments) still needs to be applied.
        """
        self.load_config.model_loader_extra_config = {}
        self.load_config.load_format = "auto"
        default_model_loader = DefaultModelLoader(self.load_config)

        if model_config.quantization is None:
            model = default_model_loader.load_model(vllm_config=vllm_config,
                                                    model_config=model_config)
            need_process_weights_after_loading = False
        else:
            logger.warning(
                "Quantization is set, netloader uses DefaultModelLoader with process_weights_after_loading"
            )
            need_process_weights_after_loading = True
            target_device = torch.device(device_config.device)
            with set_default_torch_dtype(model_config.dtype):
                with target_device:
                    model = initialize_model(vllm_config=vllm_config,
                                             model_config=model_config)
            default_model_loader.load_weights(model, model_config)
            model = model.eval()

        return model, need_process_weights_after_loading

    def download_model(self, model_config: ModelConfig) -> None:
        pass

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        pass
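When `OUTPUT_PREFIX` is set, each rank writes its advertised address (`ip:port`) to `<prefix><rank>.txt`, which a later client can collect into the `SOURCE` entry of its own `model_loader_extra_config`. A hypothetical helper for that step (not part of this PR):

```python
# Hypothetical helper (not part of this PR): assemble a SOURCE list from the
# per-rank address files written by a serving instance under OUTPUT_PREFIX.
from pathlib import Path
from typing import List


def sources_from_prefix(prefix: str, world_size: int) -> List[dict]:
    source = []
    for rank in range(world_size):
        addr_file = Path(f"{prefix}{rank}.txt")
        if addr_file.exists():
            # Each file holds "ip:port" for that rank's elastic server.
            source.append({
                "device_id": rank,
                "sources": [addr_file.read_text().strip()],
            })
    return source
```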

vllm_ascend/model_loader/netloader/utils.py (new file, 66 lines)
@@ -0,0 +1,66 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import socket

from vllm.logger import logger


def find_free_port():
    """
    Finds a free port on the local machine.

    Returns:
    - A free port number.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        return s.getsockname()[1]


def is_valid_path_prefix(path_prefix):
    """
    Checks if the provided path prefix is valid.

    Parameters:
    - path_prefix: The path prefix to validate.

    Returns:
    - True if the path prefix is valid, otherwise False.
    """
    if not path_prefix:
        return False

    if re.search(r'[<>:"|?*]', path_prefix):
        logger.warning(
            f'The path prefix {path_prefix} contains illegal characters.')
        return False

    if path_prefix.startswith('/') or path_prefix.startswith('\\'):
        if not os.path.exists(os.path.dirname(path_prefix)):
            logger.warning(
                f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.'
            )
            return False
    else:
        if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
            logger.warning(
                f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.'
            )
            return False
    return True