[Misc] Add a model loader that utilizes HCCL for weight loading (#2888)

### What this PR does / why we need it?

This PR introduces a new model loader called Netloader, which leverages
high-bandwidth P2P direct transfer between NPU cards to achieve weight
loading. Netloader is implemented as a plugin through the newly added
'register_model_loader' function in vLLM 0.10. It facilitates the
process of weight loading by sending weights from a pre-loaded model
(server) to an empty model of a newly started instance (client). The
server operates concurrently with normal inference tasks through
sub-threads and the 'stateless_init_torch_distributed_process_group' in
vLLM. The client initiates a transfer request after verifying that the
model and partitioning method are the same as the server's, and uses
HCCL's collective communication (send/recv) to load the weights in the
order they are stored in the model.

Application Scenarios:
1. Significantly Reduces Inference Instance Startup Time By reusing the
weights of already loaded instances and performing high-speed transfers
directly between computing cards, this method reduces model loading
latency compared to traditional remote/local pull methods.
2. Reduces Network and Storage Pressure Avoids the need to repeatedly
download weight files from remote repositories, reducing the impact on
centralized storage and network traffic, thereby enhancing overall
system stability and service quality.
3. Improves Resource Utilization and Reduces Costs Accelerating the
loading process reduces reliance on redundant computing pools, allowing
computing resources to be elastically scaled and reclaimed as needed.
4. Enhances Business Continuity and High Availability In fault recovery
scenarios, new instances can quickly take over existing services,
avoiding prolonged business interruptions and improving the system's
high availability and user experience.

### Does this PR introduce _any_ user-facing change?

Netloader utilizes the existing --load-format=netloader and
--model-loader-extra-config to be activated. The
model-loader-extra-config needs to be input as a JSON string (as it is
now)

Afterwards, you can check whether the outputs for the same sentence are
consistent when the temperature is set to 0.

Signed-off-by: destinysky <kangrui10@126.com>

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: destinysky <kangrui10@126.com>
This commit is contained in:
Rui Kang
2025-10-23 15:56:07 +08:00
committed by GitHub
parent 807686dec9
commit 427b17e2da
19 changed files with 1999 additions and 1 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 39 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

View File

@@ -11,5 +11,6 @@ sleep_mode
structured_output
lora
eplb_swift_balancer
netloader
dynamic_batch
:::

View File

@@ -0,0 +1,97 @@
# Netloader Guide
This guide provides instructions for using **Netloader** as a weight-loader plugin for acceleration in **vLLM Ascend**.
---
## Overview
Netloader leverages high-bandwidth peer-to-peer (P2P) transfers between NPU cards to load model weights. It is implemented as a plugin (via the `register_model_loader` API added in vLLM 0.10). The workflow is:
1. A **server** preloads a model.
2. A new **client** instance requests weight transfer.
3. After validating that the model and partitioning match, the client uses HCCL collective communication (send/recv) to receive weights in the same order as stored in the model.
The server runs alongside normal inference tasks via sub-threads and via `stateless_init_torch_distributed_process_group` in vLLM. The client thus takes over weight initialization without needing to load from storage.
### Flowchart
![netloader flowchart](./images/netloader_flowchart.png)
### Timing Diagram
![netloader timing diagram](./images/netloader_timing_diagram.png)
### Application Scenarios
- **Reduce startup latency**: By reusing already loaded weights and transferring them directly between NPU cards, Netloader cuts down model loading time vs conventional remote/local pull strategies.
- **Relieve network & storage load**: Avoid repeated downloads of weight files from remote repositories, thus reducing pressure on central storage and network traffic.
- **Improve resource utilization & lower cost**: Faster loading allows less reliance on standby compute nodes; resources can be scaled up/down more flexibly.
- **Enhance business continuity & high availability**: In failure recovery, new instances can quickly take over without long downtime, improving system reliability and user experience.
---
## Usage
To enable Netloader, pass `--load-format=netloader` and provide configuration via `--model-loader-extra-config` (as a JSON string). Below are the supported configuration fields:
| Field Name | Type | Description | Allowed Values / Notes |
|--------------------|---------|------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------|
| **SOURCE** | List | Weighted data sources. Each item is a map with `device_id` and `sources`, specifying the rank and its endpoints (IP:port). <br>Example: `{"SOURCE": [{"device_id": 0, "sources": ["10.170.22.152:19374"]}, {"device_id": 1, "sources": ["10.170.22.152:11228"]}]}` <br>If omitted or empty, fallback to default loader. The SOURCE here is second priority. | A list of objects with keys `device_id: int` and `sources: List[str]` |
| **MODEL** | String | The model name, used to verify consistency between client and server. | Defaults to the `--model` argument if not specified. |
| **LISTEN_PORT** | Integer | Base port for the server listener. | The actual port = `LISTEN_PORT + RANK`. If omitted, a random valid port is chosen. Valid range: 102465535. If out of range, that server instance wont open a listener. |
| **INT8_CACHE** | String | Behavior for handling int8 parameters in quantized models. | One of `["hbm", "dram", "no"]`. <br> - `hbm`: copy original int8 parameters to high-bandwidth memory (HBM) (may cost a lot of HBM). <br> - `dram`: copy to DRAM. <br> - `no`: no special handling (may lead to divergence or unpredictable behavior). Default: `"no"`. |
| **INT8_CACHE_NAME** | List | Names of parameters to which `INT8_CACHE` is applied (i.e. filtering). | Default: `None` (means no filtering—all parameters). |
| **OUTPUT_PREFIX** | String | Prefix for writing per-rank listener address/port files in server mode. | If set, each rank writes to `{OUTPUT_PREFIX}{RANK}.txt` (text), content = `IP:Port`. |
| **CONFIG_FILE** | String | Path to a JSON file specifying the above configuration. | If provided, the SOURCE inside this file has **first priority** (overrides SOURCE in other configs). |
---
## Example Commands & Placeholders
> Replace parts in `` `<...>` `` before running.
### Server
```shell
VLLM_SLEEP_WHEN_IDLE=1 vllm serve `<model_file>` \
--tensor-parallel-size 1 \
--served-model-name `<model_name>` \
--enforce-eager \
--port `<port>` \
--load-format netloader
```
### Client
```shell
export NETLOADER_CONFIG='{"SOURCE":[{"device_id":0, "sources": ["`<server_IP>`:`<server_Port>`"]}]}'
VLLM_SLEEP_WHEN_IDLE=1 ASCEND_RT_VISIBLE_DEVICES=`<device_id_diff_from_server>` \
vllm serve `<model_file>` \
--tensor-parallel-size 1 \
--served-model-name `<model_name>` \
--enforce-eager \
--port `<client_port>` \
--load-format netloader \
--model-loader-extra-config="${NETLOADER_CONFIG}"
```
#### Placeholder Descriptions
- `<model_file>`: Path to the model file
- `<model_name>`: Model name (must match between server & client)
- `<port>`: Base listening port on server
- `<server_IP>` + `<server_Port>`: IP and port of the Netloader server (from server log)
- `<device_id_diff_from_server>`: Client device ID (must differ from servers)
- `<client_port>`: Port on which client listens
After startup, you can test consistency by issuing inference requests with temperature = 0 and comparing outputs.
---
## Note & Caveats
- If Netloader is used, **each worker process** must bind a listening port. That port may be user-specified or assigned randomly. If user-specified, ensure it is available.
- Netloader requires extra HBM memory to establish HCCL connections (i.e. `HCCL_BUFFERSIZE`, default ~200 MB). Users should reserve sufficient capacity (e.g. via `--gpu-memory-utilization`).
- It is recommended to set `VLLM_SLEEP_WHEN_IDLE=1` to mitigate unstable or slow connections/transmissions. Related info: [vLLM Issue #16660](https://github.com/vllm-project/vllm/issues/16660), [vLLM PR #16226](https://github.com/vllm-project/vllm/pull/16226).

View File

@@ -393,7 +393,8 @@ setup(
"vllm.platform_plugins": ["ascend = vllm_ascend:register"],
"vllm.general_plugins": [
"ascend_enhanced_model = vllm_ascend:register_model",
"ascend_kv_connector = vllm_ascend:register_connector"
"ascend_kv_connector = vllm_ascend:register_connector",
"ascend_model_loader = vllm_ascend:register_model_loader"
],
},
)

View File

@@ -0,0 +1,215 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch import nn
from vllm_ascend.model_loader.netloader.netloader import ModelNetLoaderElastic
class DummyDeviceConfig:
device = 'cuda'
device_type = 'cuda'
class DummyParallelConfig:
tensor_parallel_size = 1
pipeline_parallel_size = 1
class DummyVllmConfig:
device_config = DummyDeviceConfig()
parallel_config = DummyParallelConfig()
additional_config = None
class DummyModelConfig:
model = 'dummy-model'
dtype = torch.float32
@pytest.fixture
def default_load_config():
class DummyLoadConfig:
model_loader_extra_config = None
load_format = "default"
return DummyLoadConfig()
def make_loader_with_config(extra):
class DummyLoadConfig:
model_loader_extra_config = extra
load_format = "default"
return ModelNetLoaderElastic(DummyLoadConfig())
def test_init_with_extra_config_file(tmp_path, monkeypatch):
# Generate test JSON file
config_content = {
"SOURCE": [{
"device_id": 0
}],
"MODEL": "foo-model",
"LISTEN_PORT": 5001,
"INT8_CACHE": "hbm",
"OUTPUT_PREFIX": str(tmp_path),
}
config_file = tmp_path / "config.json"
config_file.write_text(json.dumps(config_content))
dummy_logger = MagicMock()
monkeypatch.setattr("vllm.logger.logger", dummy_logger)
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
lambda x: True)
extra = {"CONFIG_FILE": str(config_file)}
loader = make_loader_with_config(extra)
assert loader.model_path == "foo-model"
assert loader.source == [{"device_id": 0}]
assert loader.listen_port == 5001
assert loader.int8_cache == "hbm"
assert loader.output_prefix == str(tmp_path)
def test_init_with_extra_config(monkeypatch):
dummy_logger = MagicMock()
monkeypatch.setattr("vllm.logger.logger", dummy_logger)
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
lambda x: True)
extra = {
"SOURCE": [{
"device_id": 0
}],
"MODEL": "foo",
"LISTEN_PORT": "4000",
"INT8_CACHE": "dram",
"OUTPUT_PREFIX": "/tmp/"
}
loader = make_loader_with_config(extra)
assert loader.model_path == "foo"
assert loader.listen_port == 4000
assert loader.int8_cache == "dram"
assert loader.output_prefix == "/tmp/"
assert loader.source == [{"device_id": 0}]
def test_init_with_invalid_config(monkeypatch):
dummy_logger = MagicMock()
monkeypatch.setattr("vllm.logger.logger", dummy_logger)
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
lambda x: False)
# c
extra = {
"SOURCE": None,
"MODEL": None,
"LISTEN_PORT": None,
"INT8_CACHE": "something",
"OUTPUT_PREFIX": None,
}
loader = make_loader_with_config(extra)
assert loader.model_path is None
assert loader.listen_port is None
assert loader.int8_cache == "no"
assert loader.output_prefix is None
@patch("vllm_ascend.model_loader.netloader.netloader.logger")
def test_load_model_elastic_success(mock_logger, monkeypatch, tmp_path):
monkeypatch.setattr("torch.distributed.get_rank", lambda: 0)
class FakeContext:
def __enter__(self):
pass
def __exit__(self, a, b, c):
pass
monkeypatch.setattr("torch.device", lambda d: FakeContext())
# patch deep copy
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.deepcopy", lambda x: x)
# patch set_default_torch_dtype
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.set_default_torch_dtype",
lambda dtype: FakeContext())
# patch initialize_model
dummy_model = MagicMock(spec=nn.Module)
dummy_model.eval.return_value = dummy_model
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.initialize_model",
lambda **kwargs: dummy_model)
# patch elastic_load
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.elastic_load",
lambda **kwargs: dummy_model)
# patch process_weights_after_loading
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.process_weights_after_loading",
lambda *a, **k: None)
# patch get_ip
monkeypatch.setattr("vllm.utils.get_ip", lambda: "127.0.0.1")
# patch find_free_port
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.find_free_port",
lambda: 8888)
# patch ElasticServer
class DummyElasticServer:
def __init__(*a, **k):
pass
def start(self):
pass
monkeypatch.setattr(
"vllm_ascend.model_loader.netloader.netloader.ElasticServer",
DummyElasticServer)
# write output_prefix to the temporary directory
extra = {
"SOURCE": [{
"device_id": 0
}],
"MODEL": "foo",
"LISTEN_PORT": 5555,
"OUTPUT_PREFIX": str(tmp_path) + "/output_",
"INT8_CACHE": "no"
}
loader = make_loader_with_config(extra)
vllm_config = DummyVllmConfig()
model_config = DummyModelConfig()
result = loader.load_model(vllm_config, model_config)
assert isinstance(result, nn.Module)
# Check file
written_file = tmp_path / "output_0.txt"
assert written_file.exists()
if __name__ == "__main__":
pytest.main()

View File

@@ -0,0 +1,432 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import json
import logging
import socket
from unittest.mock import MagicMock, patch
import pytest
import torch
import vllm.logger
from vllm_ascend.model_loader.netloader.interaction.elastic import (
ElasticClient, ElasticServer)
# Simulate server's normal response
def mock_server_response(data):
return json.dumps({
"label": "JOIN_ACK",
"content": {
"name": "mocked_name"
}
}).encode("utf-8")
# Simulate server's error response
def mock_server_error_response(data):
return json.dumps({"label": "JOIN_ACK", "content": None}).encode("utf-8")
# Simulated server's abnormal response
def mock_server_exception_response(data):
raise Exception("Mocked server exception")
# Test the initialization of ElasticClient
def test_elastic_client_init():
sources = ["127.0.0.1:12345"]
device_id = 0
model_path = "mocked_model_path"
tp = 1
pp = 1
with patch('socket.socket') as mock_socket:
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.recv.return_value = mock_server_response(None)
mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346)
mock_socket_instance.__enter__.return_value = mock_socket_instance
with ElasticClient(sources, device_id, model_path, tp, pp) as client:
assert client.server_addr == "127.0.0.1"
assert client.server_port == 12345
assert client.ack == ("mocked_name", 12346)
mock_socket_instance.close.assert_called_once()
# Test the register method of ElasticClient
def test_elastic_client_register():
sources = ["127.0.0.1:12345"]
device_id = 0
model_path = "mocked_model_path"
tp = 1
pp = 1
with patch('socket.socket') as mock_socket:
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.connect.return_value = None
mock_socket_instance.recv.return_value = mock_server_response(None)
mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346)
mock_socket_instance.__enter__.return_value = mock_socket_instance
client = ElasticClient(sources, device_id, model_path, tp, pp)
assert client.register(device_id, model_path, tp,
pp) == ("mocked_name", 12346)
# Test the behavior of the `register` method of ElasticClient when the server returns an error response.
def test_elastic_client_register_error_response():
sources = ["127.0.0.1:12345"]
device_id = 0
model_path = "mocked_model_path"
tp = 1
pp = 1
with patch('socket.socket') as mock_socket:
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.connect.return_value = None
mock_socket_instance.recv.return_value = mock_server_error_response(
None)
with ElasticClient(sources, device_id, model_path, tp, pp) as client:
with pytest.raises(RuntimeError):
client.register(device_id, model_path, tp, pp)
mock_socket_instance.close.assert_called_once()
# Test the behavior of the `register` method of ElasticClient when an exception is thrown on the server.
def test_elastic_client_register_exception():
sources = ["127.0.0.1:12345"]
device_id = 0
model_path = "mocked_model_path"
tp = 1
pp = 1
with patch('socket.socket') as mock_socket:
mock_socket_instance = MagicMock()
mock_socket.return_value = mock_socket_instance
mock_socket_instance.connect.return_value = None
mock_socket_instance.recv.side_effect = mock_server_exception_response
mock_socket_instance.__enter__.return_value = mock_socket_instance
mock_socket_instance.__exit__.return_value = None
with ElasticClient(sources, device_id, model_path, tp, pp) as client:
with pytest.raises(RuntimeError):
client.register(device_id, model_path, tp, pp)
mock_socket_instance.close.assert_called_once()
class FakeInt8Param:
def __init__(self, name="param", device="npu", dtype=torch.int8):
self.dtype = dtype
self.device = torch.device(device)
@property
def data(self):
return self # Simulate .data returning self so .cpu() etc. can be chained
def clone(self):
return self
def detach(self):
return self
def cpu(self):
self.device = torch.device("cpu")
return self
class FakeModel:
def __init__(self):
self.params = {
"param1": MagicMock(dtype=torch.float32), # This will be ignored
"param2": FakeInt8Param(), # This simulates a real int8 param
}
def named_parameters(self):
return self.params.items()
@pytest.fixture
def mock_model():
return FakeModel()
@pytest.fixture
def server_config():
return {
"addr": "127.0.0.1",
"port": 8080,
"model": MagicMock(),
"device_id": 0,
"model_path": "/test/model",
"tp": 1,
"pp": 1,
"int8_cache": "dram",
'int8_cache_name': None
}
# Test server initialization
def test_server_initialization(server_config, mock_model):
server_config["model"] = mock_model
with patch("socket.socket") as mock_socket:
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.DEBUG)
vllm.logger.logger.addHandler(ch)
server = ElasticServer(**server_config)
# Check the socket configuration
mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM)
mock_socket.return_value.setsockopt.assert_called_with(
socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
mock_socket.return_value.bind.assert_called_with(("127.0.0.1", 8080))
mock_socket.return_value.listen.assert_called_with(256)
# Check int8 cache
assert "param2" in server.original_int8
assert server.original_int8[
"param2"].device.type == "cpu" # Verifying DRAM Cache
assert server.addr == server_config['addr']
assert server.port == server_config['port']
assert server.device_id == server_config['device_id']
assert server.model_path == server_config['model_path']
assert server.tp == server_config['tp']
assert server.pp == server_config['pp']
# Get captured logs
log_output = log_capture_string.getvalue()
vllm.logger.logger.removeHandler(ch)
log_capture_string.close()
# Check output
assert "Server 127.0.0.1:8080 starts" in log_output
# Test the int8 cache option
@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"),
("no", None),
("invalid", None)])
def test_int8_cache_handling(server_config, mock_model, cache_option,
expected_device, caplog):
server_config["int8_cache"] = cache_option
server_config["model"] = mock_model
with patch("socket.socket"):
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.DEBUG)
vllm.logger.logger.addHandler(ch)
server = ElasticServer(**server_config)
log_output = log_capture_string.getvalue()
vllm.logger.logger.removeHandler(ch)
log_capture_string.close()
if cache_option == "invalid":
assert "int8_cache should be selected in [HBM, DRAM]" in log_output
if expected_device is None:
assert len(server.original_int8) == 0
else:
assert server.original_int8[
"param2"].device.type == expected_device
# Test client processing
def test_client_handler_valid_join(server_config, mock_model):
server_config["model"] = mock_model
with patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend"
) as mock_p2p_send:
# Create a simulated connection
mock_conn = MagicMock()
mock_addr = ("192.168.1.1", 12345)
# Configuring Client Data
valid_data = {
"label": "JOIN",
"content": {
"device_id": 0,
"model_path": "/test/model",
"tp": 1,
"pp": 1,
"port": 9090
}
}
mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8")
# Start the server
server = ElasticServer(**server_config)
server.register_handler(mock_conn, mock_addr)
# Verify response
expected_ack = {
"label": "JOIN_ACK",
"content": {
"name": "192.168.1.1:12345"
}
}
mock_conn.send.assert_called_once_with(
json.dumps(expected_ack).encode("utf-8"))
mock_p2p_send.assert_called_once_with("127.0.0.1", 9090,
"192.168.1.1:12345")
mock_conn.close.assert_called_once()
# Test mismatched JOIN requests
def test_client_handler_mismatch(server_config):
with patch("socket.socket"):
server = ElasticServer(**server_config)
mock_conn = MagicMock()
mock_addr = ("192.168.1.1", 12345)
# Send mismatched data
mismatch_data = {
"label": "JOIN",
"content": {
"device_id": 1, # 不匹配的ID
"model_path": "/wrong/model",
"tp": 2,
"pp": 2,
"port": 9090
}
}
mock_conn.recv.return_value = json.dumps(mismatch_data).encode("utf-8")
server.register_handler(mock_conn, mock_addr)
assert isinstance(mismatch_data["content"], dict)
# Verify response
expected_ack = {
"label":
"JOIN_NACK",
"content":
f"Received data {(mismatch_data['content']['device_id'], mismatch_data['content']['model_path'], mismatch_data['content']['tp'], mismatch_data['content']['pp'])} does not consist with this server {(server_config['device_id'], server_config['model_path'], server_config['tp'], server_config['pp'])}"
}
mock_conn.send.assert_called_once_with(
json.dumps(expected_ack).encode("utf-8"))
mock_conn.close.assert_called_once()
# Test Invalid Request
@pytest.mark.parametrize(
"invalid_data,should_send",
[
(
{
"label": "WRONG_LABEL"
}, True
), # Incorrect label, can be decoded as JSON, but the content is invalid.
(
{
"content": {
"missing_fields": True
}
}, True
), # Missing field, can be decoded as JSON, but the content is invalid.
("plain text", False), # Non-JSON data, json.loads failed
(b"invalid_bytes", False) # Invalid byte, decode or json.loads failed
])
def test_client_handler_invalid_requests(server_config, invalid_data,
should_send):
with patch("socket.socket"):
log_capture_string = io.StringIO()
ch = logging.StreamHandler(log_capture_string)
ch.setLevel(logging.DEBUG)
vllm.logger.logger.addHandler(ch)
with patch("socket.socket"):
server = ElasticServer(**server_config)
mock_conn = MagicMock()
mock_addr = ("192.168.1.1", 12345)
if isinstance(invalid_data, (str, bytes)):
mock_conn.recv.return_value = invalid_data if isinstance(
invalid_data, bytes) else invalid_data.encode()
else:
mock_conn.recv.return_value = json.dumps(invalid_data).encode(
"utf-8")
server.register_handler(mock_conn, mock_addr)
if should_send:
expected_ack = {
"label":
"JOIN_NACK",
"content":
f"Received data does not contain required fields: {invalid_data}"
}
mock_conn.send.assert_called_once_with(
json.dumps(expected_ack).encode("utf-8"))
else:
mock_conn.send.assert_not_called()
log_output = log_capture_string.getvalue()
vllm.logger.logger.removeHandler(ch)
log_capture_string.close()
# Any warning in the log is acceptable
assert "Failed to load" in log_output or "does not contain" in log_output
mock_conn.close.assert_called_once()
# Test the thread startup.
def test_server_start(server_config):
with patch("socket.socket"), \
patch("threading.Thread") as mock_thread:
handler_thread_instance = mock_thread.return_value
server = ElasticServer(**server_config)
server.start()
# Assert that the correct target parameter was passed when instantiating the Thread instance.
mock_thread.assert_called_once()
args, kwargs = mock_thread.call_args
assert kwargs['target'] == server.elastic_client_handler
# Check that the daemon attribute is set to True (the attribute value will be recorded after MagicMock assignment).
assert handler_thread_instance.daemon is True
# Check if the start() method is called.
handler_thread_instance.start.assert_called_once()
# Test resource clearing
def test_server_cleanup(server_config):
with patch("socket.socket") as mock_socket:
server = ElasticServer(**server_config)
del server
mock_socket.return_value.close.assert_called_once()
if __name__ == "__main__":
pytest.main()

View File

@@ -0,0 +1,114 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock, patch
import pytest
from vllm_ascend.model_loader.netloader.load import elastic_load
@pytest.fixture
def mock_sources():
return [
{
"device_id": 0,
"sources": ["a", "b"]
},
{
"device_id": 1,
"sources": ["c"]
},
]
@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
def test_sources_this_device_empty(mock_p2p, mock_client):
sources = [{"device_id": 1, "sources": ["c"]}]
result = elastic_load("model", 0, "model_path", sources, 1, 1)
assert result is None
mock_client.assert_not_called()
mock_p2p.assert_not_called()
@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
def test_client_s_none(mock_p2p, mock_client, mock_sources):
# Simulate ElasticClient.s as None
mock_instance = MagicMock()
mock_instance.s = None
mock_client.return_value = mock_instance
result = elastic_load("model", 0, "model_path", mock_sources, 1, 1)
assert result is None
@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
def test_client_ack_none(mock_p2p, mock_client, mock_sources):
# Simulate ElasticClient.ack as None
mock_instance = MagicMock()
mock_instance.s = True
mock_instance.ack = None
mock_client.return_value = mock_instance
result = elastic_load("model", 0, "model_path", mock_sources, 1, 1)
assert result is None
@patch("vllm_ascend.model_loader.netloader.load.P2PLoad")
@patch("vllm_ascend.model_loader.netloader.load.logger")
def test_model_load_fail(mock_logger, mock_p2p):
mock_client = MagicMock()
mock_client.s = True
mock_client.ack = ["foo", "bar"]
mock_client.server_addr = "addr"
with patch("vllm_ascend.model_loader.netloader.load.ElasticClient",
return_value=mock_client):
# P2PLoad.load returns None
mock_p2p_instance = MagicMock()
mock_p2p_instance.load.return_value = None
mock_p2p.return_value = mock_p2p_instance
sources = [{"device_id": 0, "sources": ["whatever"]}]
result = elastic_load("model", 0, "model_path", sources, 1, 1)
assert result is None
mock_logger.error.assert_called_once()
@patch("vllm_ascend.model_loader.netloader.load.P2PLoad")
@patch("vllm_ascend.model_loader.netloader.load.logger")
def test_model_load_success(mock_logger, mock_p2p):
mock_client = MagicMock()
mock_client.s = True
mock_client.ack = ["foo", "bar"]
mock_client.server_addr = "addr"
with patch("vllm_ascend.model_loader.netloader.load.ElasticClient",
return_value=mock_client):
expected_model = object()
mock_p2p_instance = MagicMock()
mock_p2p_instance.load.return_value = expected_model
mock_p2p.return_value = mock_p2p_instance
sources = [{"device_id": 0, "sources": ["whatever"]}]
result = elastic_load("model", 0, "model_path", sources, 1, 1)
assert result is expected_model
mock_logger.info.assert_called_once()
if __name__ == "__main__":
pytest.main()

View File

@@ -0,0 +1,61 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile
from unittest.mock import patch
import pytest
from vllm_ascend.model_loader.netloader.utils import (find_free_port,
is_valid_path_prefix)
def test_find_free_port():
port = find_free_port()
assert isinstance(port, int)
assert port > 0
def test_is_valid_path_prefix_empty():
assert not is_valid_path_prefix('')
def test_is_valid_path_prefixIllegal_characters():
assert not is_valid_path_prefix('test<>:"|?*')
def test_is_valid_path_prefixRelative_path():
assert is_valid_path_prefix('test')
def test_is_valid_path_prefixAbsolute_path():
with tempfile.TemporaryDirectory() as tmpdir:
assert is_valid_path_prefix(os.path.join(tmpdir, 'test'))
@patch('os.path.exists', return_value=False)
def test_is_valid_path_prefix_no_directory(mock_exists):
assert not is_valid_path_prefix('/nonexistent_dir/test')
@patch('os.path.exists', return_value=True)
def test_is_valid_path_prefix_directory_exists(mock_exists):
assert is_valid_path_prefix('/existing_dir/test')
if __name__ == "__main__":
pytest.main()

View File

@@ -31,3 +31,8 @@ def register_model():
def register_connector():
from vllm_ascend.distributed import register_connector
register_connector()
def register_model_loader():
from .model_loader.netloader import register_netloader
register_netloader()

View File

View File

@@ -0,0 +1,20 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
def register_netloader():
"""Register the NetLoader plugin."""
from .netloader import ModelNetLoaderElastic # noqa

View File

@@ -0,0 +1,170 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import torch
import torch_npu
from vllm.distributed.utils import (
stateless_destroy_torch_distributed_process_group,
stateless_init_torch_distributed_process_group)
from vllm.logger import logger
class P2PLoad:
"""
Class for receiving model parameters in a distributed manner using HCCL backend.
"""
def __init__(
self,
world_name: str,
source_ip: str,
source_port: int,
):
"""
Initializes the P2PLoad instance.
Parameters:
- world_name: The name of the distributed group.
- source_ip: The IP address of the source node.
- source_port: The port number for the source node.
"""
self.world_name = world_name
self.source_ip = source_ip
self.source_port = source_port
def load(self, model):
"""
Loads the model parameters using HCCL backend.
Parameters:
- model: The model whose parameters are to be loaded.
Returns:
- The model if loading is successful, otherwise None.
"""
model_device = next(model.parameters()).device
logger.info(
f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
receiver_pg = None
loaded_model = None
try:
receiver_pg = stateless_init_torch_distributed_process_group(
host=self.world_name.split(":")[0],
port=self.source_port,
rank=0,
world_size=2,
backend='hccl',
)
logger.info(
f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
logger.info(
f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
logger.info(f"Model device: {model_device}")
trans_stream = torch_npu.npu.Stream()
with torch_npu.npu.stream(trans_stream):
for name, param in model.named_parameters():
if len(param.shape) == 0:
continue
receiver_pg.recv([param], 1, 0).wait()
torch.distributed.barrier(group=receiver_pg,
device_ids=[model_device.index])
torch_npu.npu.synchronize(trans_stream)
logger.info(
f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
)
loaded_model = model
except Exception as e:
logger.error("Failed to recv model: {}".format(e))
finally:
if receiver_pg:
stateless_destroy_torch_distributed_process_group(receiver_pg)
return loaded_model
class P2PSend:
"""
Class for sending model parameters in a distributed manner using HCCL backend.
"""
def __init__(self, listen_ip: str, listen_port: int, comm_name: str):
"""
Initializes the P2PSend instance.
Parameters:
- listen_ip: The IP address to listen on.
- listen_port: The port number to listen on.
- comm_name: The name of the communication group.
"""
self.listen_ip = listen_ip
self.listen_port = listen_port
self.comm_name = comm_name
def send(self, model, int8_params: dict):
"""
Sends the model parameters using HCCL backend.
Parameters:
- model: The model whose parameters are to be sent.
- int8_params: Dictionary of parameters that are in int8 format.
"""
model_device = next(model.parameters()).device
torch.npu.set_device(model_device)
logger.info(
f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
sender_pg = None
try:
sender_pg = stateless_init_torch_distributed_process_group(
host=self.comm_name.split(":")[0],
port=self.listen_port,
rank=1,
world_size=2,
backend='hccl',
)
logger.info(
f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
logger.info(
f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
logger.info(f"Model device: {model_device}")
trans_stream = torch_npu.npu.Stream()
with torch_npu.npu.stream(trans_stream):
for name, param in model.named_parameters():
if "aclnn_input_scale" in name:
continue
if name in int8_params:
sender_pg.send([int8_params[name].to(model_device)], 0,
0).wait()
else:
sender_pg.send([param.contiguous()], 0, 0).wait()
torch.distributed.barrier(group=sender_pg,
device_ids=[model_device.index])
torch_npu.npu.synchronize(trans_stream)
logger.info(
f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
)
finally:
if sender_pg:
stateless_destroy_torch_distributed_process_group(sender_pg)

View File

@@ -0,0 +1,408 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
import re
import socket
import threading
from typing import List, Optional, Tuple
import torch
from vllm.logger import logger
from ..executor.elastic_load import P2PSend
from ..utils import find_free_port
class ElasticClient:
"""
Class for handling the client-side logic of Netloader of models.
"""
def __init__(self, sources: list[str], device_id: int, model_path: str,
tp: int, pp: int):
"""
Initializes the ElasticClient instance.
Parameters:
- sources: List of source addresses in the format IP:port.
- device_id: The ID of the current device.
- model_path: The path to the model.
- tp: Tensor parallel size.
- pp: Pipeline parallel size.
"""
self.sources = sources
self.device_id = device_id
self.model_path = model_path
self.tp = tp
self.pp = pp
self.s: Optional[socket.socket] = None
self.ack: Optional[Tuple[str, int]] = None
self.server_addr: Optional[str] = None
self.server_port: Optional[int] = None
for source in self.sources:
try:
ip, port_str = source.split(':')
port = int(port_str)
except Exception as e:
logger.error(f"IP format error: {source}, detail: {e}")
continue
self.server_addr = ip
self.server_port = port
try:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
logger.info(
f"Start connection to server: {self.server_addr}:{self.server_port}"
)
sock.connect((self.server_addr, self.server_port))
logger.info(
f"Finish connection to server: {self.server_addr}:{self.server_port}"
)
sock.settimeout(60)
self.s = sock
self.ack = self.register(device_id, model_path, tp, pp)
break
except Exception as e:
logger.error(f"Connect to {source} fails, detail: {e}")
if sock is not None:
try:
sock.close()
except Exception:
pass
self.s = None
self.ack = None
self.server_addr = None
self.server_port = None
def close(self) -> None:
"""
Closes the socket connection.
"""
if self.s is not None:
try:
self.s.close()
except Exception as e:
logger.error(f"Error closing socket: {e}")
finally:
self.s = None
def __enter__(self) -> "ElasticClient":
"""
Context manager enter method.
"""
return self
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
"""
Context manager exit method.
"""
self.close()
def __del__(self):
"""
Destructor method to ensure socket is closed.
"""
try:
self.close()
except Exception:
pass
def send_str(self, data_str: str) -> None:
"""
Sends a string over the socket connection.
Parameters:
- data_str: The string to be sent.
"""
if self.s is None:
raise RuntimeError("Socket was not created correctly.")
self.s.send(data_str.encode("utf-8"))
def recv_str(self, buffer_size: int = 1024) -> str:
"""
Receives a string over the socket connection.
Parameters:
- buffer_size: The size of the buffer for receiving data.
Returns:
- The received string.
"""
if self.s is None:
raise RuntimeError("Socket was not created correctly.")
data_str = self.s.recv(buffer_size).decode("utf-8")
return data_str
def register(self, device_id: int, model_path: str, tp: int,
pp: int) -> Tuple[str, int]:
"""
Registers the client with the server.
Parameters:
- device_id: The ID of the current device.
- model_path: The path to the model.
- tp: Tensor parallel size.
- pp: Pipeline parallel size.
Returns:
- A tuple containing the communication name and port.
"""
free_port = find_free_port()
data = {
"label": "JOIN",
"content": {
'device_id': device_id,
'model_path': model_path,
'tp': tp,
'pp': pp,
'port': free_port
}
}
try:
self.send_str(json.dumps(data))
except Exception as e:
raise RuntimeError(
f"Send data {data} to server fails, detail: {e}")
try:
ack_str = self.recv_str()
except Exception as e:
raise RuntimeError(f"Receive data from server fails, detail: {e}")
try:
ack = json.loads(ack_str)
except Exception as e:
raise RuntimeError(
f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}"
)
logger.info(f"Receive ack: {ack}")
if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack
and ack["content"] is not None and "name" in ack["content"]):
return (ack["content"]["name"], free_port)
elif ("label" in ack and ack["label"] == 'JOIN_NACK'
and "content" in ack):
raise RuntimeError(
f"Receive nack from server, reason: {ack['content']}")
else:
raise RuntimeError(
f"Receive ack {ack} from server does not contain required fields"
)
class ElasticServer:
"""
Class for handling the server-side logic of Netloader of models.
"""
def __init__(self, addr: str, port: int, model, device_id: int,
model_path: str, tp: int, pp: int, int8_cache: str,
int8_cache_name: Optional[List[str]]):
"""
Initializes the ElasticServer instance.
Parameters:
- addr: The IP address to listen on.
- port: The port number to listen on.
- model: The model to be served.
- device_id: The ID of the current device (i.e. global rank).
- model_path: The path to the model.
- tp: Tensor parallel size.
- pp: Pipeline parallel size.
- int8_cache: The type of caching for int8 parameters (HBM, DRAM, or no).
- int8_cache_name: List of parameter names to be cached.
"""
self.addr = addr
self.port = port
self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.s.bind((self.addr, self.port))
self.s.listen(256)
self.model = model
self.device_id = device_id
self.model_path = model_path
self.tp = tp
self.pp = pp
self.original_int8 = {}
int8_pattern = "|".join(
map(re.escape,
int8_cache_name)) if int8_cache_name is not None else "(?:)"
for name, param in self.model.named_parameters():
if param.dtype == torch.int8:
if int8_cache == 'hbm':
if int8_cache_name is None or (
int8_cache_name is not None
and re.search(int8_pattern, name) is not None):
try:
self.original_int8[name] = param.data.clone(
).detach()
except RuntimeError as e:
logger.error(
f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}"
)
self.original_int8[name] = param.data.cpu()
elif int8_cache == 'dram':
if int8_cache_name is None or (
int8_cache_name is not None
and re.search(int8_pattern, name) is not None):
self.original_int8[name] = param.data.cpu()
elif int8_cache == 'no':
pass
else:
logger.warning(
f"int8_cache should be selected in [HBM, DRAM], but got {int8_cache}, change to no cache"
)
logger.info(
f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}"
)
def __del__(self):
"""
Destructor method to ensure socket is closed.
"""
self.s.close()
def start(self):
"""
Starts the server to handle incoming connections.
"""
handler_thread = threading.Thread(target=self.elastic_client_handler)
handler_thread.daemon = True
handler_thread.start()
def elastic_client_handler(self):
"""
Handles incoming client connections.
"""
while True:
conn, addr = self.s.accept()
logger.info("Accept new connection from {}:{}...".format(*addr))
self.register_handler(conn, addr)
def register_handler(self, conn, addr, buffer_size=1024):
"""
Handles the registration of a client.
Parameters:
- conn: The connection socket.
- addr: The address of the client.
- buffer_size: The size of the buffer for receiving data.
"""
data_str = conn.recv(buffer_size).decode("utf-8")
if not data_str:
return
try:
data = json.loads(data_str)
except Exception:
logger.error(f"Failed to load {data_str} as JSON string")
conn.close()
return
def is_valid_data(data):
"""
Validates the received data.
Parameters:
- data: The data to be validated.
Returns:
- True if the data is valid, otherwise False.
"""
if not isinstance(data, dict):
return False
if data.get("label") != "JOIN":
return False
content = data.get("content")
if not isinstance(content, dict):
return False
required_keys = ["device_id", "model_path", "tp", "pp", "port"]
if not all(k in content for k in required_keys):
return False
port = content["port"]
if not (isinstance(port, int) or
(isinstance(port, str) and port.isdigit())):
return False
return True
comm_name = None
if is_valid_data(data):
device_id = int(data["content"]["device_id"])
model_path = data["content"]["model_path"]
tp = int(data["content"]["tp"])
pp = int(data["content"]["pp"])
if int(self.device_id
) == device_id and self.model_path == model_path and int(
self.tp) == tp and int(self.pp) == pp:
comm_name = str(addr[0]) + ":" + str(addr[1])
ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
else:
logger.warning(
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
)
ack = {
"label":
"JOIN_NACK",
"content":
f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
}
else:
logger.warning(
f"Received data does not contain required fields: {data}")
ack = {
"label":
"JOIN_NACK",
"content":
f"Received data does not contain required fields: {data}"
}
try:
ack_str = json.dumps(ack).encode("utf-8")
except Exception as e:
logger.error(
f"Failed to convert {ack} to JSON format, details: {e}")
conn.close()
return
try:
conn.send(ack_str)
except Exception as e:
logger.error(f"Failed to send {ack} to {addr}, details: {e}")
conn.close()
return
if ack["content"] and isinstance(ack["content"],
dict) and 'name' in ack["content"]:
try:
p2psend = P2PSend(self.addr, data["content"]["port"],
ack["content"]["name"])
p2psend.send(self.model, self.original_int8)
except Exception as e:
logger.error(
f"P2PSend Failed to send model to {self.addr}, details: {e}"
)
conn.close()

View File

@@ -0,0 +1,84 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from vllm.logger import logger
from .executor.elastic_load import P2PLoad
from .interaction.elastic import ElasticClient
def elastic_load(
model,
device_id: int,
model_path: str,
sources: list,
tp: int,
pp: int,
):
"""
Loads a model using elastic loading across multiple devices.
Parameters:
- model: The model instance to be loaded.
- device_id: The ID of the current device (i.e. global rank).
- model_path: The path to the model file.
- sources: A list of source configurations, each containing device_id and sources.
- tp: Tensor parallel size, indicating the number of devices for tensor parallelism.
- pp: Pipeline parallel size, indicating the number of devices for pipeline parallelism.
Returns:
- The loaded model if successful, otherwise None.
"""
# Filter sources for the current device
sources_this_device = []
for s in sources:
if isinstance(
s, dict
) and "device_id" in s and s["device_id"] == device_id and isinstance(
s["sources"], list):
sources_this_device += s["sources"]
if len(sources_this_device) == 0:
return None
try:
# Initialize the interaction layer with the ElasticClient
with ElasticClient(sources_this_device, device_id, model_path, tp,
pp) as client_interaction_layer:
if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
raise RuntimeError(
"Failed to initialize ElasticClient: socket or server_addr is None"
)
ack = client_interaction_layer.ack
if ack is None:
raise RuntimeError("ElasticClient.register did not return ack")
t0 = time.perf_counter()
elastic_loader = P2PLoad(ack[0],
client_interaction_layer.server_addr,
ack[1])
model_loaded = elastic_loader.load(model=model)
if model_loaded is None:
logger.error("Failed to load model")
return None
logger.info("Finish elastic load (duration: {}s)".format(
time.perf_counter() - t0))
return model_loaded
except Exception as e:
logger.error(f"elastic_load error: {e}")
return None

View File

@@ -0,0 +1,324 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import gc
import json
import time
from copy import deepcopy
from typing import List, Optional, Tuple
import torch
from torch import nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.logger import logger
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from .interaction.elastic import ElasticServer
from .load import elastic_load
from .utils import find_free_port, is_valid_path_prefix
@register_model_loader("netloader")
class ModelNetLoaderElastic(BaseModelLoader):
"""
A model loader that uses elastic loading for loading weights.
"""
source: Optional[List[dict]]
model_path: Optional[str]
listen_port: Optional[int]
int8_cache: str
int8_cache_name: Optional[List[str]]
output_prefix: Optional[str]
def __init__(self, load_config: LoadConfig):
"""
Initializes the ModelNetLoaderElastic with configuration.
Parameters:
- load_config: Configuration for loading the model.
"""
super().__init__(load_config)
config = None
# Try to read config file at first
extra = load_config.model_loader_extra_config
if extra and "CONFIG_FILE" in extra:
try:
logger.info(
f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..."
)
with open(extra["CONFIG_FILE"], 'r') as f:
config = json.load(f)
except FileNotFoundError:
logger.error("CONFIG_FILE not found")
except json.JSONDecodeError:
logger.error("CONFIG_FILE is not a valid JSON file")
except Exception as e:
logger.error(
f"Unexpected error while reading CONFIG_FILE: {e}")
if config is None and extra:
logger.info("Reading configs in model_loader_extra_config ...")
config = extra
config = config or {}
for key, attr, checker, caster, default in [
("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v,
None),
("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v,
None),
("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or
(isinstance(v, str) and v.isdigit()), lambda v: int(v), None),
("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v.
lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'),
("INT8_CACHE_NAME", "int8_cache_name",
lambda v: isinstance(v, list), lambda v: v, None),
("OUTPUT_PREFIX", "output_prefix",
lambda v: isinstance(v, str) and is_valid_path_prefix(v),
lambda v: v, None),
]:
v = config.get(key, default)
if not checker(v):
v = default
else:
v = caster(v)
setattr(self, attr, v)
logger.info(
"Initializing elastic Netloader with config: "
"MODEL=%s, LISTEN_PORT=%s,"
"SOURCE=%s, INT8_CACHE=%s, INT8_CACHE_NAME=%s,"
"OUTPUT_PREFIX=%s)",
self.model_path,
self.listen_port,
self.source,
self.int8_cache,
self.int8_cache_name,
self.output_prefix,
)
def load_model(self, vllm_config: VllmConfig,
model_config: ModelConfig) -> nn.Module:
"""
Loads the model using the specified configuration.
Parameters:
- vllm_config: Configuration for the VLLM.
- model_config: Configuration for the model.
Returns:
- The loaded model.
"""
device_config = vllm_config.device_config
parallel_config = vllm_config.parallel_config
need_process_weights_after_loading = False
if self.model_path is None:
self.model_path = model_config.model
logger.info(f"model_path is set to {self.model_path}")
device_id = torch.distributed.get_rank()
if (self.source is None or not isinstance(self.source, list)
or device_id not in [
one_device["device_id"] for one_device in self.source if
isinstance(one_device, dict) and "device_id" in one_device
]):
logger.warning(
"Did not get valid source info, use DefaultModelLoader")
model, need_process_weights_after_loading = self.revert_to_default(
model_config, vllm_config, device_config)
else:
target_device = torch.device(device_config.device)
vllm_config_backup = deepcopy(vllm_config)
model_config_backup = deepcopy(model_config)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
start_elastic_load = time.perf_counter()
model = elastic_load(
model=model,
device_id=device_id,
model_path=self.model_path,
sources=self.source,
tp=parallel_config.tensor_parallel_size,
pp=parallel_config.pipeline_parallel_size,
)
end_elastic_load = time.perf_counter()
logger.info(
f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}"
)
need_process_weights_after_loading = True
if model is None:
logger.warning(
"Netloader elastic loading fails, use load format DefaultModelLoader"
)
vllm_config = vllm_config_backup
model_config = model_config_backup
del model
gc.collect()
if device_config.device_type == 'npu':
logger.info("Empty NPU cache")
torch.npu.empty_cache()
elif device_config.device_type == 'cuda':
logger.info("Empty CUDA cache")
torch.cuda.empty_cache()
model, need_process_weights_after_loading = self.revert_to_default(
model_config, vllm_config, device_config)
start_elastic_server = time.perf_counter()
# start elastic server
if model is not None and (
(self.listen_port and self.listen_port in range(1024, 65535)) or
(self.listen_port is None)):
from vllm.utils import get_ip
driver_ip = get_ip()
if driver_ip == '0.0.0.0':
logger.error(
"Driver IP is not set, skip to start Netloader server")
else:
if self.listen_port is None:
self.listen_port = find_free_port()
else:
self.listen_port += device_id
logger.info(
f"Start elastic Netloader server, rank: {device_id}, listen port: {driver_ip}:{self.listen_port}"
)
if self.output_prefix is not None:
try:
with open(self.output_prefix + str(device_id) + '.txt',
'w') as file:
file.write(f"{driver_ip}:{self.listen_port}")
logger.info(
f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}"
)
except FileNotFoundError:
logger.error(
f"File path {self.output_prefix + str(device_id)} does not exist."
)
except PermissionError:
logger.error(
f"No permission to write to file {self.output_prefix + str(device_id)}."
)
except IOError as e:
logger.error(
f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}"
)
except Exception as e:
logger.error(f"Unknown error: {e}")
try:
assert isinstance(
self.listen_port, int
), f"listen port should be int but get {self.listen_port}"
elastic_server = ElasticServer(
driver_ip, self.listen_port, model, device_id,
self.model_path, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
self.int8_cache, self.int8_cache_name)
elastic_server.start()
except Exception as e:
logger.error(
f"Failed to start Netloader server for rank: {device_id}, details: {e}"
)
else:
logger.info("Skip to start Netloader server")
end_elastic_server = time.perf_counter()
logger.info(
f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}"
)
if need_process_weights_after_loading:
process_weights_after_loading(model, model_config,
torch.device(device_config.device))
if model is None:
logger.error("NetLoader elastic loads model fails")
return None
return model.eval()
def revert_to_default(self, model_config, vllm_config,
device_config) -> Tuple[nn.Module, bool]:
"""
Reverts to the default model loading logic when elastic loading fails or is not applicable.
This method resets the loader's extra config and load format to defaults,
then delegates model loading to a DefaultModelLoader.
If quantization is enabled, it will load the model and then run the
processing of weights (i.e. applying quantization adjustments) before returning.
Parameters:
- model_config: Configuration describing model architecture, quantization, etc.
- vllm_config: Configuration for vLLM (device, parallelism, dtype, etc).
- device_config: Configuration for the target device (device type, device id, etc).
Returns:
- A tuple (model, need_process_weights_after_loading):
* model: The loaded `nn.Module` under default loading logic.
* need_process_weights_after_loading: A boolean flag indicating whether
weights post-processing (e.g. quantization adjustments) still needs to be applied.
"""
self.load_config.model_loader_extra_config = {}
self.load_config.load_format = "auto"
default_model_loader = DefaultModelLoader(self.load_config)
if model_config.quantization is None:
model = default_model_loader.load_model(vllm_config=vllm_config,
model_config=model_config)
need_process_weights_after_loading = False
else:
logger.warning(
"Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading "
)
need_process_weights_after_loading = True
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config,
model_config=model_config)
default_model_loader.load_weights(model, model_config)
model = model.eval()
return model, need_process_weights_after_loading
def download_model(self, model_config: ModelConfig) -> None:
pass
def load_weights(self, model: nn.Module,
model_config: ModelConfig) -> None:
pass

View File

@@ -0,0 +1,66 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import socket
from vllm.logger import logger
def find_free_port():
"""
Finds a free port on the local machine.
Returns:
- A free port number.
"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
def is_valid_path_prefix(path_prefix):
"""
Checks if the provided path prefix is valid.
Parameters:
- path_prefix: The path prefix to validate.
Returns:
- True if the path prefix is valid, otherwise False.
"""
if not path_prefix:
return False
if re.search(r'[<>:"|?*]', path_prefix):
logger.warning(
f'The path prefix {path_prefix} contains illegal characters.')
return False
if path_prefix.startswith('/') or path_prefix.startswith('\\'):
if not os.path.exists(os.path.dirname(path_prefix)):
logger.warning(
f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.'
)
return False
else:
if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
logger.warning(
f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.'
)
return False
return True