[Misc] Add a model loader that utilizes HCCL for weight loading (#2888)
### What this PR does / why we need it? This PR introduces a new model loader called Netloader, which leverages high-bandwidth P2P direct transfer between NPU cards to achieve weight loading. Netloader is implemented as a plugin through the newly added 'register_model_loader' function in vLLM 0.10. It facilitates the process of weight loading by sending weights from a pre-loaded model (server) to an empty model of a newly started instance (client). The server operates concurrently with normal inference tasks through sub-threads and the 'stateless_init_torch_distributed_process_group' in vLLM. The client initiates a transfer request after verifying that the model and partitioning method are the same as the server's, and uses HCCL's collective communication (send/recv) to load the weights in the order they are stored in the model. Application Scenarios: 1. Significantly Reduces Inference Instance Startup Time By reusing the weights of already loaded instances and performing high-speed transfers directly between computing cards, this method reduces model loading latency compared to traditional remote/local pull methods. 2. Reduces Network and Storage Pressure Avoids the need to repeatedly download weight files from remote repositories, reducing the impact on centralized storage and network traffic, thereby enhancing overall system stability and service quality. 3. Improves Resource Utilization and Reduces Costs Accelerating the loading process reduces reliance on redundant computing pools, allowing computing resources to be elastically scaled and reclaimed as needed. 4. Enhances Business Continuity and High Availability In fault recovery scenarios, new instances can quickly take over existing services, avoiding prolonged business interruptions and improving the system's high availability and user experience. ### Does this PR introduce _any_ user-facing change? Netloader utilizes the existing --load-format=netloader and --model-loader-extra-config to be activated. The model-loader-extra-config needs to be input as a JSON string (as it is now) Afterwards, you can check whether the outputs for the same sentence are consistent when the temperature is set to 0. Signed-off-by: destinysky <kangrui10@126.com> - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: destinysky <kangrui10@126.com>
This commit is contained in:
432
tests/ut/model_loader/netloader/test_netloader_elastic.py
Normal file
432
tests/ut/model_loader/netloader/test_netloader_elastic.py
Normal file
@@ -0,0 +1,432 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import vllm.logger
|
||||
|
||||
from vllm_ascend.model_loader.netloader.interaction.elastic import (
|
||||
ElasticClient, ElasticServer)
|
||||
|
||||
|
||||
# Simulate server's normal response
|
||||
def mock_server_response(data):
|
||||
return json.dumps({
|
||||
"label": "JOIN_ACK",
|
||||
"content": {
|
||||
"name": "mocked_name"
|
||||
}
|
||||
}).encode("utf-8")
|
||||
|
||||
|
||||
# Simulate server's error response
|
||||
def mock_server_error_response(data):
|
||||
return json.dumps({"label": "JOIN_ACK", "content": None}).encode("utf-8")
|
||||
|
||||
|
||||
# Simulated server's abnormal response
|
||||
def mock_server_exception_response(data):
|
||||
raise Exception("Mocked server exception")
|
||||
|
||||
|
||||
# Test the initialization of ElasticClient
|
||||
def test_elastic_client_init():
|
||||
sources = ["127.0.0.1:12345"]
|
||||
device_id = 0
|
||||
model_path = "mocked_model_path"
|
||||
tp = 1
|
||||
pp = 1
|
||||
|
||||
with patch('socket.socket') as mock_socket:
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.recv.return_value = mock_server_response(None)
|
||||
|
||||
mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346)
|
||||
mock_socket_instance.__enter__.return_value = mock_socket_instance
|
||||
|
||||
with ElasticClient(sources, device_id, model_path, tp, pp) as client:
|
||||
assert client.server_addr == "127.0.0.1"
|
||||
assert client.server_port == 12345
|
||||
assert client.ack == ("mocked_name", 12346)
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
|
||||
# Test the register method of ElasticClient
|
||||
def test_elastic_client_register():
|
||||
sources = ["127.0.0.1:12345"]
|
||||
device_id = 0
|
||||
model_path = "mocked_model_path"
|
||||
tp = 1
|
||||
pp = 1
|
||||
|
||||
with patch('socket.socket') as mock_socket:
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.connect.return_value = None
|
||||
mock_socket_instance.recv.return_value = mock_server_response(None)
|
||||
|
||||
mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346)
|
||||
mock_socket_instance.__enter__.return_value = mock_socket_instance
|
||||
|
||||
client = ElasticClient(sources, device_id, model_path, tp, pp)
|
||||
assert client.register(device_id, model_path, tp,
|
||||
pp) == ("mocked_name", 12346)
|
||||
|
||||
|
||||
# Test the behavior of the `register` method of ElasticClient when the server returns an error response.
|
||||
def test_elastic_client_register_error_response():
|
||||
sources = ["127.0.0.1:12345"]
|
||||
device_id = 0
|
||||
model_path = "mocked_model_path"
|
||||
tp = 1
|
||||
pp = 1
|
||||
|
||||
with patch('socket.socket') as mock_socket:
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.connect.return_value = None
|
||||
mock_socket_instance.recv.return_value = mock_server_error_response(
|
||||
None)
|
||||
|
||||
with ElasticClient(sources, device_id, model_path, tp, pp) as client:
|
||||
with pytest.raises(RuntimeError):
|
||||
client.register(device_id, model_path, tp, pp)
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
|
||||
# Test the behavior of the `register` method of ElasticClient when an exception is thrown on the server.
|
||||
def test_elastic_client_register_exception():
|
||||
sources = ["127.0.0.1:12345"]
|
||||
device_id = 0
|
||||
model_path = "mocked_model_path"
|
||||
tp = 1
|
||||
pp = 1
|
||||
|
||||
with patch('socket.socket') as mock_socket:
|
||||
mock_socket_instance = MagicMock()
|
||||
mock_socket.return_value = mock_socket_instance
|
||||
mock_socket_instance.connect.return_value = None
|
||||
mock_socket_instance.recv.side_effect = mock_server_exception_response
|
||||
mock_socket_instance.__enter__.return_value = mock_socket_instance
|
||||
mock_socket_instance.__exit__.return_value = None
|
||||
|
||||
with ElasticClient(sources, device_id, model_path, tp, pp) as client:
|
||||
with pytest.raises(RuntimeError):
|
||||
client.register(device_id, model_path, tp, pp)
|
||||
mock_socket_instance.close.assert_called_once()
|
||||
|
||||
|
||||
class FakeInt8Param:
|
||||
|
||||
def __init__(self, name="param", device="npu", dtype=torch.int8):
|
||||
self.dtype = dtype
|
||||
self.device = torch.device(device)
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self # Simulate .data returning self so .cpu() etc. can be chained
|
||||
|
||||
def clone(self):
|
||||
return self
|
||||
|
||||
def detach(self):
|
||||
return self
|
||||
|
||||
def cpu(self):
|
||||
self.device = torch.device("cpu")
|
||||
return self
|
||||
|
||||
|
||||
class FakeModel:
|
||||
|
||||
def __init__(self):
|
||||
self.params = {
|
||||
"param1": MagicMock(dtype=torch.float32), # This will be ignored
|
||||
"param2": FakeInt8Param(), # This simulates a real int8 param
|
||||
}
|
||||
|
||||
def named_parameters(self):
|
||||
return self.params.items()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model():
|
||||
return FakeModel()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_config():
|
||||
return {
|
||||
"addr": "127.0.0.1",
|
||||
"port": 8080,
|
||||
"model": MagicMock(),
|
||||
"device_id": 0,
|
||||
"model_path": "/test/model",
|
||||
"tp": 1,
|
||||
"pp": 1,
|
||||
"int8_cache": "dram",
|
||||
'int8_cache_name': None
|
||||
}
|
||||
|
||||
|
||||
# Test server initialization
|
||||
def test_server_initialization(server_config, mock_model):
|
||||
server_config["model"] = mock_model
|
||||
with patch("socket.socket") as mock_socket:
|
||||
log_capture_string = io.StringIO()
|
||||
ch = logging.StreamHandler(log_capture_string)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
vllm.logger.logger.addHandler(ch)
|
||||
|
||||
server = ElasticServer(**server_config)
|
||||
|
||||
# Check the socket configuration
|
||||
mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM)
|
||||
mock_socket.return_value.setsockopt.assert_called_with(
|
||||
socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
mock_socket.return_value.bind.assert_called_with(("127.0.0.1", 8080))
|
||||
mock_socket.return_value.listen.assert_called_with(256)
|
||||
|
||||
# Check int8 cache
|
||||
assert "param2" in server.original_int8
|
||||
assert server.original_int8[
|
||||
"param2"].device.type == "cpu" # Verifying DRAM Cache
|
||||
|
||||
assert server.addr == server_config['addr']
|
||||
assert server.port == server_config['port']
|
||||
assert server.device_id == server_config['device_id']
|
||||
assert server.model_path == server_config['model_path']
|
||||
assert server.tp == server_config['tp']
|
||||
assert server.pp == server_config['pp']
|
||||
|
||||
# Get captured logs
|
||||
log_output = log_capture_string.getvalue()
|
||||
vllm.logger.logger.removeHandler(ch)
|
||||
log_capture_string.close()
|
||||
|
||||
# Check output
|
||||
assert "Server 127.0.0.1:8080 starts" in log_output
|
||||
|
||||
|
||||
# Test the int8 cache option
|
||||
@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"),
|
||||
("no", None),
|
||||
("invalid", None)])
|
||||
def test_int8_cache_handling(server_config, mock_model, cache_option,
|
||||
expected_device, caplog):
|
||||
server_config["int8_cache"] = cache_option
|
||||
server_config["model"] = mock_model
|
||||
|
||||
with patch("socket.socket"):
|
||||
log_capture_string = io.StringIO()
|
||||
ch = logging.StreamHandler(log_capture_string)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
vllm.logger.logger.addHandler(ch)
|
||||
|
||||
server = ElasticServer(**server_config)
|
||||
|
||||
log_output = log_capture_string.getvalue()
|
||||
vllm.logger.logger.removeHandler(ch)
|
||||
log_capture_string.close()
|
||||
|
||||
if cache_option == "invalid":
|
||||
assert "int8_cache should be selected in [HBM, DRAM]" in log_output
|
||||
|
||||
if expected_device is None:
|
||||
assert len(server.original_int8) == 0
|
||||
else:
|
||||
assert server.original_int8[
|
||||
"param2"].device.type == expected_device
|
||||
|
||||
|
||||
# Test client processing
|
||||
def test_client_handler_valid_join(server_config, mock_model):
|
||||
server_config["model"] = mock_model
|
||||
with patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend"
|
||||
) as mock_p2p_send:
|
||||
|
||||
# Create a simulated connection
|
||||
mock_conn = MagicMock()
|
||||
mock_addr = ("192.168.1.1", 12345)
|
||||
|
||||
# Configuring Client Data
|
||||
valid_data = {
|
||||
"label": "JOIN",
|
||||
"content": {
|
||||
"device_id": 0,
|
||||
"model_path": "/test/model",
|
||||
"tp": 1,
|
||||
"pp": 1,
|
||||
"port": 9090
|
||||
}
|
||||
}
|
||||
mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8")
|
||||
|
||||
# Start the server
|
||||
server = ElasticServer(**server_config)
|
||||
server.register_handler(mock_conn, mock_addr)
|
||||
|
||||
# Verify response
|
||||
expected_ack = {
|
||||
"label": "JOIN_ACK",
|
||||
"content": {
|
||||
"name": "192.168.1.1:12345"
|
||||
}
|
||||
}
|
||||
mock_conn.send.assert_called_once_with(
|
||||
json.dumps(expected_ack).encode("utf-8"))
|
||||
mock_p2p_send.assert_called_once_with("127.0.0.1", 9090,
|
||||
"192.168.1.1:12345")
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
|
||||
# Test mismatched JOIN requests
|
||||
def test_client_handler_mismatch(server_config):
|
||||
with patch("socket.socket"):
|
||||
server = ElasticServer(**server_config)
|
||||
mock_conn = MagicMock()
|
||||
mock_addr = ("192.168.1.1", 12345)
|
||||
|
||||
# Send mismatched data
|
||||
mismatch_data = {
|
||||
"label": "JOIN",
|
||||
"content": {
|
||||
"device_id": 1, # 不匹配的ID
|
||||
"model_path": "/wrong/model",
|
||||
"tp": 2,
|
||||
"pp": 2,
|
||||
"port": 9090
|
||||
}
|
||||
}
|
||||
mock_conn.recv.return_value = json.dumps(mismatch_data).encode("utf-8")
|
||||
|
||||
server.register_handler(mock_conn, mock_addr)
|
||||
|
||||
assert isinstance(mismatch_data["content"], dict)
|
||||
|
||||
# Verify response
|
||||
expected_ack = {
|
||||
"label":
|
||||
"JOIN_NACK",
|
||||
"content":
|
||||
f"Received data {(mismatch_data['content']['device_id'], mismatch_data['content']['model_path'], mismatch_data['content']['tp'], mismatch_data['content']['pp'])} does not consist with this server {(server_config['device_id'], server_config['model_path'], server_config['tp'], server_config['pp'])}"
|
||||
}
|
||||
mock_conn.send.assert_called_once_with(
|
||||
json.dumps(expected_ack).encode("utf-8"))
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
|
||||
# Test Invalid Request
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_data,should_send",
|
||||
[
|
||||
(
|
||||
{
|
||||
"label": "WRONG_LABEL"
|
||||
}, True
|
||||
), # Incorrect label, can be decoded as JSON, but the content is invalid.
|
||||
(
|
||||
{
|
||||
"content": {
|
||||
"missing_fields": True
|
||||
}
|
||||
}, True
|
||||
), # Missing field, can be decoded as JSON, but the content is invalid.
|
||||
("plain text", False), # Non-JSON data, json.loads failed
|
||||
(b"invalid_bytes", False) # Invalid byte, decode or json.loads failed
|
||||
])
|
||||
def test_client_handler_invalid_requests(server_config, invalid_data,
|
||||
should_send):
|
||||
with patch("socket.socket"):
|
||||
log_capture_string = io.StringIO()
|
||||
ch = logging.StreamHandler(log_capture_string)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
vllm.logger.logger.addHandler(ch)
|
||||
|
||||
with patch("socket.socket"):
|
||||
server = ElasticServer(**server_config)
|
||||
mock_conn = MagicMock()
|
||||
mock_addr = ("192.168.1.1", 12345)
|
||||
|
||||
if isinstance(invalid_data, (str, bytes)):
|
||||
mock_conn.recv.return_value = invalid_data if isinstance(
|
||||
invalid_data, bytes) else invalid_data.encode()
|
||||
else:
|
||||
mock_conn.recv.return_value = json.dumps(invalid_data).encode(
|
||||
"utf-8")
|
||||
|
||||
server.register_handler(mock_conn, mock_addr)
|
||||
|
||||
if should_send:
|
||||
expected_ack = {
|
||||
"label":
|
||||
"JOIN_NACK",
|
||||
"content":
|
||||
f"Received data does not contain required fields: {invalid_data}"
|
||||
}
|
||||
mock_conn.send.assert_called_once_with(
|
||||
json.dumps(expected_ack).encode("utf-8"))
|
||||
else:
|
||||
mock_conn.send.assert_not_called()
|
||||
|
||||
log_output = log_capture_string.getvalue()
|
||||
vllm.logger.logger.removeHandler(ch)
|
||||
log_capture_string.close()
|
||||
|
||||
# Any warning in the log is acceptable
|
||||
assert "Failed to load" in log_output or "does not contain" in log_output
|
||||
mock_conn.close.assert_called_once()
|
||||
|
||||
|
||||
# Test the thread startup.
|
||||
def test_server_start(server_config):
|
||||
with patch("socket.socket"), \
|
||||
patch("threading.Thread") as mock_thread:
|
||||
|
||||
handler_thread_instance = mock_thread.return_value
|
||||
|
||||
server = ElasticServer(**server_config)
|
||||
server.start()
|
||||
|
||||
# Assert that the correct target parameter was passed when instantiating the Thread instance.
|
||||
mock_thread.assert_called_once()
|
||||
args, kwargs = mock_thread.call_args
|
||||
assert kwargs['target'] == server.elastic_client_handler
|
||||
|
||||
# Check that the daemon attribute is set to True (the attribute value will be recorded after MagicMock assignment).
|
||||
assert handler_thread_instance.daemon is True
|
||||
|
||||
# Check if the start() method is called.
|
||||
handler_thread_instance.start.assert_called_once()
|
||||
|
||||
|
||||
# Test resource clearing
|
||||
def test_server_cleanup(server_config):
|
||||
with patch("socket.socket") as mock_socket:
|
||||
server = ElasticServer(**server_config)
|
||||
del server
|
||||
mock_socket.return_value.close.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
Reference in New Issue
Block a user