diff --git a/docs/source/user_guide/feature_guide/images/netloader_flowchart.png b/docs/source/user_guide/feature_guide/images/netloader_flowchart.png
new file mode 100644
index 00000000..4efe6b2b
Binary files /dev/null and b/docs/source/user_guide/feature_guide/images/netloader_flowchart.png differ
diff --git a/docs/source/user_guide/feature_guide/images/netloader_timing_diagram.png b/docs/source/user_guide/feature_guide/images/netloader_timing_diagram.png
new file mode 100644
index 00000000..0af05812
Binary files /dev/null and b/docs/source/user_guide/feature_guide/images/netloader_timing_diagram.png differ
diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md
index 049e496f..61c333b0 100644
--- a/docs/source/user_guide/feature_guide/index.md
+++ b/docs/source/user_guide/feature_guide/index.md
@@ -11,5 +11,6 @@ sleep_mode
structured_output
lora
eplb_swift_balancer
+netloader
dynamic_batch
:::
diff --git a/docs/source/user_guide/feature_guide/netloader.md b/docs/source/user_guide/feature_guide/netloader.md
new file mode 100644
index 00000000..857f3473
--- /dev/null
+++ b/docs/source/user_guide/feature_guide/netloader.md
@@ -0,0 +1,97 @@
+# Netloader Guide
+
+This guide explains how to use **Netloader**, a weight-loading acceleration plugin for **vLLM Ascend**.
+
+---
+
+## Overview
+
+Netloader leverages high-bandwidth peer-to-peer (P2P) transfers between NPU cards to load model weights. It is implemented as a plugin (via the `register_model_loader` API added in vLLM 0.10). The workflow is:
+
+1. A **server** preloads a model.
+2. A new **client** instance requests weight transfer.
+3. After validating that the model and partitioning match, the client uses HCCL send/recv (point-to-point) communication to receive the weights in the same order as they are stored in the model.
+
+The server runs alongside normal inference in a sub-thread, using vLLM's `stateless_init_torch_distributed_process_group` to set up the transfer process group. The client thus takes over weight initialization without needing to read from storage.
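The registration handshake behind steps 2–3 can be seen in the unit tests added in this change: the client sends a `JOIN` message and the server replies with `JOIN_ACK` (or `JOIN_NACK` on a mismatch). A minimal sketch of the message shapes, with field names taken from the tests; the helper names and exact wire framing here are illustrative, not the actual implementation:

```python
import json


def join_message(device_id: int, model_path: str, tp: int, pp: int,
                 port: int) -> bytes:
    # Client -> server registration request.
    return json.dumps({
        "label": "JOIN",
        "content": {
            "device_id": device_id,
            "model_path": model_path,
            "tp": tp,
            "pp": pp,
            "port": port,
        },
    }).encode("utf-8")


def parse_ack(raw: bytes) -> str:
    # Server -> client reply: JOIN_ACK carries the assigned name;
    # JOIN_NACK (or null content) means the request was rejected.
    msg = json.loads(raw.decode("utf-8"))
    if msg.get("label") != "JOIN_ACK" or msg.get("content") is None:
        raise RuntimeError(f"registration rejected: {msg}")
    return msg["content"]["name"]
```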
+
+### Flowchart
+
+![Netloader flowchart](./images/netloader_flowchart.png)
+
+### Timing Diagram
+
+![Netloader timing diagram](./images/netloader_timing_diagram.png)
+
+### Application Scenarios
+
+- **Reduce startup latency**: By reusing already loaded weights and transferring them directly between NPU cards, Netloader cuts model loading time compared with conventional remote or local pull strategies.
+- **Relieve network & storage load**: Avoid repeated downloads of weight files from remote repositories, thus reducing pressure on central storage and network traffic.
+- **Improve resource utilization & lower cost**: Faster loading allows less reliance on standby compute nodes; resources can be scaled up/down more flexibly.
+- **Enhance business continuity & high availability**: In failure recovery, new instances can quickly take over without long downtime, improving system reliability and user experience.
+
+---
+
+## Usage
+
+To enable Netloader, pass `--load-format=netloader` and provide configuration via `--model-loader-extra-config` (as a JSON string). Below are the supported configuration fields:
+
+| Field Name | Type | Description | Allowed Values / Notes |
+|--------------------|---------|------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------|
+| **SOURCE** | List | Weight data sources. Each item is a map with `device_id` and `sources`, specifying the rank and its endpoints (`IP:port`).<br>Example: `{"SOURCE": [{"device_id": 0, "sources": ["10.170.22.152:19374"]}, {"device_id": 1, "sources": ["10.170.22.152:11228"]}]}`<br>If omitted or empty, Netloader falls back to the default loader. `SOURCE` given here has second priority (a `SOURCE` inside `CONFIG_FILE` overrides it). | A list of objects with keys `device_id: int` and `sources: List[str]` |
+| **MODEL** | String | The model name, used to verify consistency between client and server. | Defaults to the `--model` argument if not specified. |
+| **LISTEN_PORT** | Integer | Base port for the server listener. | The actual port = `LISTEN_PORT + RANK`. If omitted, a random valid port is chosen. Valid range: 1024–65535. If out of range, that server instance won’t open a listener. |
+| **INT8_CACHE** | String | Behavior for handling int8 parameters in quantized models. | One of `["hbm", "dram", "no"]`.<br>- `hbm`: copy the original int8 parameters to high-bandwidth memory (HBM); may consume significant HBM.<br>- `dram`: copy them to DRAM.<br>- `no`: no special handling (may lead to divergence or unpredictable behavior). Default: `"no"`. |
+| **INT8_CACHE_NAME** | List | Names of parameters to which `INT8_CACHE` is applied (i.e. a filter). | Default: `None` (no filtering; all int8 parameters are handled). |
+| **OUTPUT_PREFIX** | String | Prefix for writing per-rank listener address/port files in server mode. | If set, each rank writes to `{OUTPUT_PREFIX}{RANK}.txt` (text), content = `IP:Port`. |
+| **CONFIG_FILE** | String | Path to a JSON file specifying the above configuration. | If provided, the SOURCE inside this file has **first priority** (overrides SOURCE in other configs). |
+
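Putting the fields together, a configuration passed via `--model-loader-extra-config` (or stored in the file named by `CONFIG_FILE`) might be built like this. A sketch only: the endpoints, model name, and paths are placeholder values, not defaults:

```python
import json

# Hypothetical endpoints and model name, for illustration only.
config = {
    "SOURCE": [
        {"device_id": 0, "sources": ["10.170.22.152:19374"]},
        {"device_id": 1, "sources": ["10.170.22.152:11228"]},
    ],
    "MODEL": "my-served-model",
    "LISTEN_PORT": 19000,                # rank r listens on 19000 + r
    "INT8_CACHE": "no",
    "OUTPUT_PREFIX": "/tmp/netloader_",  # rank r writes /tmp/netloader_{r}.txt
}

# --model-loader-extra-config expects a JSON string:
extra_config = json.dumps(config)
print(extra_config)
```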
+---
+
+## Example Commands & Placeholders
+
+> Replace the `<...>` placeholders below before running.
+
+### Server
+
+```shell
+VLLM_SLEEP_WHEN_IDLE=1 vllm serve <model_path> \
+  --tensor-parallel-size 1 \
+  --served-model-name <model_name> \
+  --enforce-eager \
+  --port <server_port> \
+  --load-format netloader
+```
+
+### Client
+
+```shell
+export NETLOADER_CONFIG='{"SOURCE":[{"device_id":0, "sources": ["<server_ip>:<netloader_port>"]}]}'
+
+VLLM_SLEEP_WHEN_IDLE=1 ASCEND_RT_VISIBLE_DEVICES=<device_id> \
+  vllm serve <model_path> \
+  --tensor-parallel-size 1 \
+  --served-model-name <model_name> \
+  --enforce-eager \
+  --port <client_port> \
+  --load-format netloader \
+  --model-loader-extra-config="${NETLOADER_CONFIG}"
+```
+
+#### Placeholder Descriptions
+
+- `<model_path>`: Path to the model files
+- `<model_name>`: Model name (must match between server and client)
+- `<server_port>`: Base listening port on the server
+- `<server_ip>` + `<netloader_port>`: IP and port of the Netloader server (taken from the server log)
+- `<device_id>`: Client device ID (must differ from the server's)
+- `<client_port>`: Port on which the client listens
+
+After startup, you can verify weight consistency by sending identical inference requests with `temperature=0` to both instances and comparing their outputs.
+
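Since `vllm serve` exposes an OpenAI-compatible API, the consistency check above amounts to posting the same greedy request to both `/v1/completions` endpoints and diffing the texts. A hedged sketch of the request payload and comparison (helper names are illustrative; no network I/O shown):

```python
def greedy_payload(model: str, prompt: str, max_tokens: int = 64) -> dict:
    # temperature=0 makes decoding greedy, so identical weights should
    # produce identical completions from the server and the client.
    return {
        "model": model,
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": 0,
    }


def outputs_match(resp_a: dict, resp_b: dict) -> bool:
    # Compare the generated texts of two /v1/completions responses.
    texts_a = [c["text"] for c in resp_a.get("choices", [])]
    texts_b = [c["text"] for c in resp_b.get("choices", [])]
    return bool(texts_a) and texts_a == texts_b
```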
+---
+
+## Notes & Caveats
+
+- If Netloader is used, **each worker process** must bind a listening port. That port may be user-specified or assigned randomly. If user-specified, ensure it is available.
+- Netloader requires extra HBM memory to establish HCCL connections (i.e. `HCCL_BUFFERSIZE`, default ~200 MB). Users should reserve sufficient capacity (e.g. via `--gpu-memory-utilization`).
+- It is recommended to set `VLLM_SLEEP_WHEN_IDLE=1` to mitigate unstable or slow connections/transmissions. Related info: [vLLM Issue #16660](https://github.com/vllm-project/vllm/issues/16660), [vLLM PR #16226](https://github.com/vllm-project/vllm/pull/16226).
diff --git a/setup.py b/setup.py
index 0795304b..5a823e7a 100644
--- a/setup.py
+++ b/setup.py
@@ -393,7 +393,8 @@ setup(
"vllm.platform_plugins": ["ascend = vllm_ascend:register"],
"vllm.general_plugins": [
"ascend_enhanced_model = vllm_ascend:register_model",
- "ascend_kv_connector = vllm_ascend:register_connector"
+ "ascend_kv_connector = vllm_ascend:register_connector",
+ "ascend_model_loader = vllm_ascend:register_model_loader"
],
},
)
diff --git a/tests/ut/model_loader/netloader/test_netloader.py b/tests/ut/model_loader/netloader/test_netloader.py
new file mode 100644
index 00000000..6658823a
--- /dev/null
+++ b/tests/ut/model_loader/netloader/test_netloader.py
@@ -0,0 +1,215 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from torch import nn
+
+from vllm_ascend.model_loader.netloader.netloader import ModelNetLoaderElastic
+
+
+class DummyDeviceConfig:
+ device = 'cuda'
+ device_type = 'cuda'
+
+
+class DummyParallelConfig:
+ tensor_parallel_size = 1
+ pipeline_parallel_size = 1
+
+
+class DummyVllmConfig:
+ device_config = DummyDeviceConfig()
+ parallel_config = DummyParallelConfig()
+ additional_config = None
+
+
+class DummyModelConfig:
+ model = 'dummy-model'
+ dtype = torch.float32
+
+
+@pytest.fixture
+def default_load_config():
+
+ class DummyLoadConfig:
+ model_loader_extra_config = None
+ load_format = "default"
+
+ return DummyLoadConfig()
+
+
+def make_loader_with_config(extra):
+
+ class DummyLoadConfig:
+ model_loader_extra_config = extra
+ load_format = "default"
+
+ return ModelNetLoaderElastic(DummyLoadConfig())
+
+
+def test_init_with_extra_config_file(tmp_path, monkeypatch):
+ # Generate test JSON file
+ config_content = {
+ "SOURCE": [{
+ "device_id": 0
+ }],
+ "MODEL": "foo-model",
+ "LISTEN_PORT": 5001,
+ "INT8_CACHE": "hbm",
+ "OUTPUT_PREFIX": str(tmp_path),
+ }
+ config_file = tmp_path / "config.json"
+ config_file.write_text(json.dumps(config_content))
+
+ dummy_logger = MagicMock()
+ monkeypatch.setattr("vllm.logger.logger", dummy_logger)
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
+ lambda x: True)
+
+ extra = {"CONFIG_FILE": str(config_file)}
+ loader = make_loader_with_config(extra)
+ assert loader.model_path == "foo-model"
+ assert loader.source == [{"device_id": 0}]
+ assert loader.listen_port == 5001
+ assert loader.int8_cache == "hbm"
+ assert loader.output_prefix == str(tmp_path)
+
+
+def test_init_with_extra_config(monkeypatch):
+ dummy_logger = MagicMock()
+ monkeypatch.setattr("vllm.logger.logger", dummy_logger)
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
+ lambda x: True)
+
+ extra = {
+ "SOURCE": [{
+ "device_id": 0
+ }],
+ "MODEL": "foo",
+ "LISTEN_PORT": "4000",
+ "INT8_CACHE": "dram",
+ "OUTPUT_PREFIX": "/tmp/"
+ }
+ loader = make_loader_with_config(extra)
+ assert loader.model_path == "foo"
+ assert loader.listen_port == 4000
+ assert loader.int8_cache == "dram"
+ assert loader.output_prefix == "/tmp/"
+ assert loader.source == [{"device_id": 0}]
+
+
+def test_init_with_invalid_config(monkeypatch):
+ dummy_logger = MagicMock()
+ monkeypatch.setattr("vllm.logger.logger", dummy_logger)
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
+ lambda x: False)
+    # Invalid values should fall back to safe defaults
+ extra = {
+ "SOURCE": None,
+ "MODEL": None,
+ "LISTEN_PORT": None,
+ "INT8_CACHE": "something",
+ "OUTPUT_PREFIX": None,
+ }
+ loader = make_loader_with_config(extra)
+ assert loader.model_path is None
+ assert loader.listen_port is None
+ assert loader.int8_cache == "no"
+ assert loader.output_prefix is None
+
+
+@patch("vllm_ascend.model_loader.netloader.netloader.logger")
+def test_load_model_elastic_success(mock_logger, monkeypatch, tmp_path):
+ monkeypatch.setattr("torch.distributed.get_rank", lambda: 0)
+
+ class FakeContext:
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, a, b, c):
+ pass
+
+ monkeypatch.setattr("torch.device", lambda d: FakeContext())
+ # patch deep copy
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.deepcopy", lambda x: x)
+ # patch set_default_torch_dtype
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.set_default_torch_dtype",
+ lambda dtype: FakeContext())
+ # patch initialize_model
+ dummy_model = MagicMock(spec=nn.Module)
+ dummy_model.eval.return_value = dummy_model
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.initialize_model",
+ lambda **kwargs: dummy_model)
+ # patch elastic_load
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.elastic_load",
+ lambda **kwargs: dummy_model)
+ # patch process_weights_after_loading
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.process_weights_after_loading",
+ lambda *a, **k: None)
+ # patch get_ip
+ monkeypatch.setattr("vllm.utils.get_ip", lambda: "127.0.0.1")
+ # patch find_free_port
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.find_free_port",
+ lambda: 8888)
+
+ # patch ElasticServer
+ class DummyElasticServer:
+
+ def __init__(*a, **k):
+ pass
+
+ def start(self):
+ pass
+
+ monkeypatch.setattr(
+ "vllm_ascend.model_loader.netloader.netloader.ElasticServer",
+ DummyElasticServer)
+ # write output_prefix to the temporary directory
+ extra = {
+ "SOURCE": [{
+ "device_id": 0
+ }],
+ "MODEL": "foo",
+ "LISTEN_PORT": 5555,
+ "OUTPUT_PREFIX": str(tmp_path) + "/output_",
+ "INT8_CACHE": "no"
+ }
+ loader = make_loader_with_config(extra)
+ vllm_config = DummyVllmConfig()
+ model_config = DummyModelConfig()
+ result = loader.load_model(vllm_config, model_config)
+ assert isinstance(result, nn.Module)
+ # Check file
+ written_file = tmp_path / "output_0.txt"
+ assert written_file.exists()
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/ut/model_loader/netloader/test_netloader_elastic.py b/tests/ut/model_loader/netloader/test_netloader_elastic.py
new file mode 100644
index 00000000..127f1dd6
--- /dev/null
+++ b/tests/ut/model_loader/netloader/test_netloader_elastic.py
@@ -0,0 +1,432 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import io
+import json
+import logging
+import socket
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+import vllm.logger
+
+from vllm_ascend.model_loader.netloader.interaction.elastic import (
+ ElasticClient, ElasticServer)
+
+
+# Simulate server's normal response
+def mock_server_response(data):
+ return json.dumps({
+ "label": "JOIN_ACK",
+ "content": {
+ "name": "mocked_name"
+ }
+ }).encode("utf-8")
+
+
+# Simulate server's error response
+def mock_server_error_response(data):
+ return json.dumps({"label": "JOIN_ACK", "content": None}).encode("utf-8")
+
+
+# Simulated server's abnormal response
+def mock_server_exception_response(data):
+ raise Exception("Mocked server exception")
+
+
+# Test the initialization of ElasticClient
+def test_elastic_client_init():
+ sources = ["127.0.0.1:12345"]
+ device_id = 0
+ model_path = "mocked_model_path"
+ tp = 1
+ pp = 1
+
+ with patch('socket.socket') as mock_socket:
+ mock_socket_instance = MagicMock()
+ mock_socket.return_value = mock_socket_instance
+ mock_socket_instance.recv.return_value = mock_server_response(None)
+
+ mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346)
+ mock_socket_instance.__enter__.return_value = mock_socket_instance
+
+ with ElasticClient(sources, device_id, model_path, tp, pp) as client:
+ assert client.server_addr == "127.0.0.1"
+ assert client.server_port == 12345
+ assert client.ack == ("mocked_name", 12346)
+ mock_socket_instance.close.assert_called_once()
+
+
+# Test the register method of ElasticClient
+def test_elastic_client_register():
+ sources = ["127.0.0.1:12345"]
+ device_id = 0
+ model_path = "mocked_model_path"
+ tp = 1
+ pp = 1
+
+ with patch('socket.socket') as mock_socket:
+ mock_socket_instance = MagicMock()
+ mock_socket.return_value = mock_socket_instance
+ mock_socket_instance.connect.return_value = None
+ mock_socket_instance.recv.return_value = mock_server_response(None)
+
+ mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346)
+ mock_socket_instance.__enter__.return_value = mock_socket_instance
+
+ client = ElasticClient(sources, device_id, model_path, tp, pp)
+ assert client.register(device_id, model_path, tp,
+ pp) == ("mocked_name", 12346)
+
+
+# Test the behavior of the `register` method of ElasticClient when the server returns an error response.
+def test_elastic_client_register_error_response():
+ sources = ["127.0.0.1:12345"]
+ device_id = 0
+ model_path = "mocked_model_path"
+ tp = 1
+ pp = 1
+
+ with patch('socket.socket') as mock_socket:
+ mock_socket_instance = MagicMock()
+ mock_socket.return_value = mock_socket_instance
+ mock_socket_instance.connect.return_value = None
+ mock_socket_instance.recv.return_value = mock_server_error_response(
+ None)
+
+ with ElasticClient(sources, device_id, model_path, tp, pp) as client:
+ with pytest.raises(RuntimeError):
+ client.register(device_id, model_path, tp, pp)
+ mock_socket_instance.close.assert_called_once()
+
+
+# Test the behavior of the `register` method of ElasticClient when an exception is thrown on the server.
+def test_elastic_client_register_exception():
+ sources = ["127.0.0.1:12345"]
+ device_id = 0
+ model_path = "mocked_model_path"
+ tp = 1
+ pp = 1
+
+ with patch('socket.socket') as mock_socket:
+ mock_socket_instance = MagicMock()
+ mock_socket.return_value = mock_socket_instance
+ mock_socket_instance.connect.return_value = None
+ mock_socket_instance.recv.side_effect = mock_server_exception_response
+ mock_socket_instance.__enter__.return_value = mock_socket_instance
+ mock_socket_instance.__exit__.return_value = None
+
+ with ElasticClient(sources, device_id, model_path, tp, pp) as client:
+ with pytest.raises(RuntimeError):
+ client.register(device_id, model_path, tp, pp)
+ mock_socket_instance.close.assert_called_once()
+
+
+class FakeInt8Param:
+
+ def __init__(self, name="param", device="npu", dtype=torch.int8):
+ self.dtype = dtype
+ self.device = torch.device(device)
+
+ @property
+ def data(self):
+ return self # Simulate .data returning self so .cpu() etc. can be chained
+
+ def clone(self):
+ return self
+
+ def detach(self):
+ return self
+
+ def cpu(self):
+ self.device = torch.device("cpu")
+ return self
+
+
+class FakeModel:
+
+ def __init__(self):
+ self.params = {
+ "param1": MagicMock(dtype=torch.float32), # This will be ignored
+ "param2": FakeInt8Param(), # This simulates a real int8 param
+ }
+
+ def named_parameters(self):
+ return self.params.items()
+
+
+@pytest.fixture
+def mock_model():
+ return FakeModel()
+
+
+@pytest.fixture
+def server_config():
+ return {
+ "addr": "127.0.0.1",
+ "port": 8080,
+ "model": MagicMock(),
+ "device_id": 0,
+ "model_path": "/test/model",
+ "tp": 1,
+ "pp": 1,
+ "int8_cache": "dram",
+ 'int8_cache_name': None
+ }
+
+
+# Test server initialization
+def test_server_initialization(server_config, mock_model):
+ server_config["model"] = mock_model
+ with patch("socket.socket") as mock_socket:
+ log_capture_string = io.StringIO()
+ ch = logging.StreamHandler(log_capture_string)
+ ch.setLevel(logging.DEBUG)
+ vllm.logger.logger.addHandler(ch)
+
+ server = ElasticServer(**server_config)
+
+ # Check the socket configuration
+ mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM)
+ mock_socket.return_value.setsockopt.assert_called_with(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ mock_socket.return_value.bind.assert_called_with(("127.0.0.1", 8080))
+ mock_socket.return_value.listen.assert_called_with(256)
+
+ # Check int8 cache
+ assert "param2" in server.original_int8
+ assert server.original_int8[
+ "param2"].device.type == "cpu" # Verifying DRAM Cache
+
+ assert server.addr == server_config['addr']
+ assert server.port == server_config['port']
+ assert server.device_id == server_config['device_id']
+ assert server.model_path == server_config['model_path']
+ assert server.tp == server_config['tp']
+ assert server.pp == server_config['pp']
+
+ # Get captured logs
+ log_output = log_capture_string.getvalue()
+ vllm.logger.logger.removeHandler(ch)
+ log_capture_string.close()
+
+ # Check output
+ assert "Server 127.0.0.1:8080 starts" in log_output
+
+
+# Test the int8 cache option
+@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"),
+ ("no", None),
+ ("invalid", None)])
+def test_int8_cache_handling(server_config, mock_model, cache_option,
+ expected_device, caplog):
+ server_config["int8_cache"] = cache_option
+ server_config["model"] = mock_model
+
+ with patch("socket.socket"):
+ log_capture_string = io.StringIO()
+ ch = logging.StreamHandler(log_capture_string)
+ ch.setLevel(logging.DEBUG)
+ vllm.logger.logger.addHandler(ch)
+
+ server = ElasticServer(**server_config)
+
+ log_output = log_capture_string.getvalue()
+ vllm.logger.logger.removeHandler(ch)
+ log_capture_string.close()
+
+ if cache_option == "invalid":
+ assert "int8_cache should be selected in [HBM, DRAM]" in log_output
+
+ if expected_device is None:
+ assert len(server.original_int8) == 0
+ else:
+ assert server.original_int8[
+ "param2"].device.type == expected_device
+
+
+# Test client processing
+def test_client_handler_valid_join(server_config, mock_model):
+ server_config["model"] = mock_model
+ with patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend"
+ ) as mock_p2p_send:
+
+ # Create a simulated connection
+ mock_conn = MagicMock()
+ mock_addr = ("192.168.1.1", 12345)
+
+ # Configuring Client Data
+ valid_data = {
+ "label": "JOIN",
+ "content": {
+ "device_id": 0,
+ "model_path": "/test/model",
+ "tp": 1,
+ "pp": 1,
+ "port": 9090
+ }
+ }
+ mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8")
+
+ # Start the server
+ server = ElasticServer(**server_config)
+ server.register_handler(mock_conn, mock_addr)
+
+ # Verify response
+ expected_ack = {
+ "label": "JOIN_ACK",
+ "content": {
+ "name": "192.168.1.1:12345"
+ }
+ }
+ mock_conn.send.assert_called_once_with(
+ json.dumps(expected_ack).encode("utf-8"))
+ mock_p2p_send.assert_called_once_with("127.0.0.1", 9090,
+ "192.168.1.1:12345")
+ mock_conn.close.assert_called_once()
+
+
+# Test mismatched JOIN requests
+def test_client_handler_mismatch(server_config):
+ with patch("socket.socket"):
+ server = ElasticServer(**server_config)
+ mock_conn = MagicMock()
+ mock_addr = ("192.168.1.1", 12345)
+
+ # Send mismatched data
+ mismatch_data = {
+ "label": "JOIN",
+ "content": {
+                "device_id": 1,  # Mismatched device ID
+ "model_path": "/wrong/model",
+ "tp": 2,
+ "pp": 2,
+ "port": 9090
+ }
+ }
+ mock_conn.recv.return_value = json.dumps(mismatch_data).encode("utf-8")
+
+ server.register_handler(mock_conn, mock_addr)
+
+ assert isinstance(mismatch_data["content"], dict)
+
+ # Verify response
+ expected_ack = {
+ "label":
+ "JOIN_NACK",
+ "content":
+ f"Received data {(mismatch_data['content']['device_id'], mismatch_data['content']['model_path'], mismatch_data['content']['tp'], mismatch_data['content']['pp'])} does not consist with this server {(server_config['device_id'], server_config['model_path'], server_config['tp'], server_config['pp'])}"
+ }
+ mock_conn.send.assert_called_once_with(
+ json.dumps(expected_ack).encode("utf-8"))
+ mock_conn.close.assert_called_once()
+
+
+# Test Invalid Request
+@pytest.mark.parametrize(
+ "invalid_data,should_send",
+ [
+ (
+ {
+ "label": "WRONG_LABEL"
+ }, True
+ ), # Incorrect label, can be decoded as JSON, but the content is invalid.
+ (
+ {
+ "content": {
+ "missing_fields": True
+ }
+ }, True
+ ), # Missing field, can be decoded as JSON, but the content is invalid.
+ ("plain text", False), # Non-JSON data, json.loads failed
+ (b"invalid_bytes", False) # Invalid byte, decode or json.loads failed
+ ])
+def test_client_handler_invalid_requests(server_config, invalid_data,
+ should_send):
+ with patch("socket.socket"):
+ log_capture_string = io.StringIO()
+ ch = logging.StreamHandler(log_capture_string)
+ ch.setLevel(logging.DEBUG)
+ vllm.logger.logger.addHandler(ch)
+
+ with patch("socket.socket"):
+ server = ElasticServer(**server_config)
+ mock_conn = MagicMock()
+ mock_addr = ("192.168.1.1", 12345)
+
+ if isinstance(invalid_data, (str, bytes)):
+ mock_conn.recv.return_value = invalid_data if isinstance(
+ invalid_data, bytes) else invalid_data.encode()
+ else:
+ mock_conn.recv.return_value = json.dumps(invalid_data).encode(
+ "utf-8")
+
+ server.register_handler(mock_conn, mock_addr)
+
+ if should_send:
+ expected_ack = {
+ "label":
+ "JOIN_NACK",
+ "content":
+ f"Received data does not contain required fields: {invalid_data}"
+ }
+ mock_conn.send.assert_called_once_with(
+ json.dumps(expected_ack).encode("utf-8"))
+ else:
+ mock_conn.send.assert_not_called()
+
+ log_output = log_capture_string.getvalue()
+ vllm.logger.logger.removeHandler(ch)
+ log_capture_string.close()
+
+ # Any warning in the log is acceptable
+ assert "Failed to load" in log_output or "does not contain" in log_output
+ mock_conn.close.assert_called_once()
+
+
+# Test the thread startup.
+def test_server_start(server_config):
+ with patch("socket.socket"), \
+ patch("threading.Thread") as mock_thread:
+
+ handler_thread_instance = mock_thread.return_value
+
+ server = ElasticServer(**server_config)
+ server.start()
+
+ # Assert that the correct target parameter was passed when instantiating the Thread instance.
+ mock_thread.assert_called_once()
+ args, kwargs = mock_thread.call_args
+ assert kwargs['target'] == server.elastic_client_handler
+
+ # Check that the daemon attribute is set to True (the attribute value will be recorded after MagicMock assignment).
+ assert handler_thread_instance.daemon is True
+
+ # Check if the start() method is called.
+ handler_thread_instance.start.assert_called_once()
+
+
+# Test resource clearing
+def test_server_cleanup(server_config):
+ with patch("socket.socket") as mock_socket:
+ server = ElasticServer(**server_config)
+ del server
+ mock_socket.return_value.close.assert_called_once()
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/ut/model_loader/netloader/test_netloader_load.py b/tests/ut/model_loader/netloader/test_netloader_load.py
new file mode 100644
index 00000000..77c3486f
--- /dev/null
+++ b/tests/ut/model_loader/netloader/test_netloader_load.py
@@ -0,0 +1,114 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from vllm_ascend.model_loader.netloader.load import elastic_load
+
+
+@pytest.fixture
+def mock_sources():
+ return [
+ {
+ "device_id": 0,
+ "sources": ["a", "b"]
+ },
+ {
+ "device_id": 1,
+ "sources": ["c"]
+ },
+ ]
+
+
+@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
+@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
+def test_sources_this_device_empty(mock_p2p, mock_client):
+ sources = [{"device_id": 1, "sources": ["c"]}]
+ result = elastic_load("model", 0, "model_path", sources, 1, 1)
+ assert result is None
+ mock_client.assert_not_called()
+ mock_p2p.assert_not_called()
+
+
+@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
+@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
+def test_client_s_none(mock_p2p, mock_client, mock_sources):
+ # Simulate ElasticClient.s as None
+ mock_instance = MagicMock()
+ mock_instance.s = None
+ mock_client.return_value = mock_instance
+ result = elastic_load("model", 0, "model_path", mock_sources, 1, 1)
+ assert result is None
+
+
+@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
+@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
+def test_client_ack_none(mock_p2p, mock_client, mock_sources):
+ # Simulate ElasticClient.ack as None
+ mock_instance = MagicMock()
+ mock_instance.s = True
+ mock_instance.ack = None
+ mock_client.return_value = mock_instance
+ result = elastic_load("model", 0, "model_path", mock_sources, 1, 1)
+ assert result is None
+
+
+@patch("vllm_ascend.model_loader.netloader.load.P2PLoad")
+@patch("vllm_ascend.model_loader.netloader.load.logger")
+def test_model_load_fail(mock_logger, mock_p2p):
+ mock_client = MagicMock()
+ mock_client.s = True
+ mock_client.ack = ["foo", "bar"]
+ mock_client.server_addr = "addr"
+
+ with patch("vllm_ascend.model_loader.netloader.load.ElasticClient",
+ return_value=mock_client):
+ # P2PLoad.load returns None
+ mock_p2p_instance = MagicMock()
+ mock_p2p_instance.load.return_value = None
+ mock_p2p.return_value = mock_p2p_instance
+
+ sources = [{"device_id": 0, "sources": ["whatever"]}]
+ result = elastic_load("model", 0, "model_path", sources, 1, 1)
+ assert result is None
+ mock_logger.error.assert_called_once()
+
+
+@patch("vllm_ascend.model_loader.netloader.load.P2PLoad")
+@patch("vllm_ascend.model_loader.netloader.load.logger")
+def test_model_load_success(mock_logger, mock_p2p):
+ mock_client = MagicMock()
+ mock_client.s = True
+ mock_client.ack = ["foo", "bar"]
+ mock_client.server_addr = "addr"
+
+ with patch("vllm_ascend.model_loader.netloader.load.ElasticClient",
+ return_value=mock_client):
+ expected_model = object()
+ mock_p2p_instance = MagicMock()
+ mock_p2p_instance.load.return_value = expected_model
+ mock_p2p.return_value = mock_p2p_instance
+
+ sources = [{"device_id": 0, "sources": ["whatever"]}]
+ result = elastic_load("model", 0, "model_path", sources, 1, 1)
+ assert result is expected_model
+ mock_logger.info.assert_called_once()
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/tests/ut/model_loader/netloader/test_netloader_utils.py b/tests/ut/model_loader/netloader/test_netloader_utils.py
new file mode 100644
index 00000000..66198449
--- /dev/null
+++ b/tests/ut/model_loader/netloader/test_netloader_utils.py
@@ -0,0 +1,61 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import tempfile
+from unittest.mock import patch
+
+import pytest
+
+from vllm_ascend.model_loader.netloader.utils import (find_free_port,
+ is_valid_path_prefix)
+
+
+def test_find_free_port():
+ port = find_free_port()
+ assert isinstance(port, int)
+ assert port > 0
+
+
+def test_is_valid_path_prefix_empty():
+ assert not is_valid_path_prefix('')
+
+
+def test_is_valid_path_prefix_illegal_characters():
+    assert not is_valid_path_prefix('test<>:"|?*')
+
+
+def test_is_valid_path_prefix_relative_path():
+    assert is_valid_path_prefix('test')
+
+
+def test_is_valid_path_prefix_absolute_path():
+    with tempfile.TemporaryDirectory() as tmpdir:
+        assert is_valid_path_prefix(os.path.join(tmpdir, 'test'))
+
+
+@patch('os.path.exists', return_value=False)
+def test_is_valid_path_prefix_no_directory(mock_exists):
+ assert not is_valid_path_prefix('/nonexistent_dir/test')
+
+
+@patch('os.path.exists', return_value=True)
+def test_is_valid_path_prefix_directory_exists(mock_exists):
+ assert is_valid_path_prefix('/existing_dir/test')
+
+
+if __name__ == "__main__":
+ pytest.main()
diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py
index 74a81537..aa72f74b 100644
--- a/vllm_ascend/__init__.py
+++ b/vllm_ascend/__init__.py
@@ -31,3 +31,8 @@ def register_model():
def register_connector():
from vllm_ascend.distributed import register_connector
register_connector()
+
+
+def register_model_loader():
+ from .model_loader.netloader import register_netloader
+ register_netloader()
\ No newline at end of file
diff --git a/vllm_ascend/model_loader/__init__.py b/vllm_ascend/model_loader/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/model_loader/netloader/__init__.py b/vllm_ascend/model_loader/netloader/__init__.py
new file mode 100644
index 00000000..f3b7c50a
--- /dev/null
+++ b/vllm_ascend/model_loader/netloader/__init__.py
@@ -0,0 +1,20 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+def register_netloader():
+ """Register the NetLoader plugin."""
+ from .netloader import ModelNetLoaderElastic # noqa
diff --git a/vllm_ascend/model_loader/netloader/executor/__init__.py b/vllm_ascend/model_loader/netloader/executor/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/model_loader/netloader/executor/elastic_load.py b/vllm_ascend/model_loader/netloader/executor/elastic_load.py
new file mode 100644
index 00000000..850bfaf9
--- /dev/null
+++ b/vllm_ascend/model_loader/netloader/executor/elastic_load.py
@@ -0,0 +1,170 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+import torch_npu
+from vllm.distributed.utils import (
+ stateless_destroy_torch_distributed_process_group,
+ stateless_init_torch_distributed_process_group)
+from vllm.logger import logger
+
+
+class P2PLoad:
+ """
+ Class for receiving model parameters in a distributed manner using HCCL backend.
+ """
+
+ def __init__(
+ self,
+ world_name: str,
+ source_ip: str,
+ source_port: int,
+ ):
+ """
+ Initializes the P2PLoad instance.
+
+ Parameters:
+ - world_name: The name of the distributed group.
+ - source_ip: The IP address of the source node.
+ - source_port: The port number for the source node.
+ """
+ self.world_name = world_name
+ self.source_ip = source_ip
+ self.source_port = source_port
+
+ def load(self, model):
+ """
+ Loads the model parameters using HCCL backend.
+
+ Parameters:
+ - model: The model whose parameters are to be loaded.
+
+ Returns:
+ - The model if loading is successful, otherwise None.
+ """
+ model_device = next(model.parameters()).device
+ logger.info(
+ f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
+ )
+ receiver_pg = None
+ loaded_model = None
+ try:
+ receiver_pg = stateless_init_torch_distributed_process_group(
+ host=self.world_name.split(":")[0],
+ port=self.source_port,
+ rank=0,
+ world_size=2,
+ backend='hccl',
+ )
+ logger.info(
+ f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
+ )
+
+ logger.info(
+ f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
+ )
+ logger.info(f"Model device: {model_device}")
+
+ trans_stream = torch_npu.npu.Stream()
+ with torch_npu.npu.stream(trans_stream):
+ for name, param in model.named_parameters():
+ if len(param.shape) == 0:
+ continue
+ receiver_pg.recv([param], 1, 0).wait()
+ torch.distributed.barrier(group=receiver_pg,
+ device_ids=[model_device.index])
+
+ torch_npu.npu.synchronize(trans_stream)
+
+ logger.info(
+ f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}"
+ )
+ loaded_model = model
+ except Exception as e:
+            logger.error(f"Failed to recv model: {e}")
+ finally:
+ if receiver_pg:
+ stateless_destroy_torch_distributed_process_group(receiver_pg)
+ return loaded_model
+
+
+class P2PSend:
+ """
+ Class for sending model parameters in a distributed manner using HCCL backend.
+ """
+
+ def __init__(self, listen_ip: str, listen_port: int, comm_name: str):
+ """
+ Initializes the P2PSend instance.
+
+ Parameters:
+ - listen_ip: The IP address to listen on.
+ - listen_port: The port number to listen on.
+ - comm_name: The name of the communication group.
+ """
+ self.listen_ip = listen_ip
+ self.listen_port = listen_port
+ self.comm_name = comm_name
+
+ def send(self, model, int8_params: dict):
+ """
+ Sends the model parameters using HCCL backend.
+
+ Parameters:
+ - model: The model whose parameters are to be sent.
+ - int8_params: Dictionary of parameters that are in int8 format.
+ """
+ model_device = next(model.parameters()).device
+ torch.npu.set_device(model_device)
+ logger.info(
+ f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
+ )
+ sender_pg = None
+ try:
+ sender_pg = stateless_init_torch_distributed_process_group(
+ host=self.comm_name.split(":")[0],
+ port=self.listen_port,
+ rank=1,
+ world_size=2,
+ backend='hccl',
+ )
+ logger.info(
+ f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
+ )
+ logger.info(
+ f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
+ )
+ logger.info(f"Model device: {model_device}")
+
+ trans_stream = torch_npu.npu.Stream()
+ with torch_npu.npu.stream(trans_stream):
+ for name, param in model.named_parameters():
+ if "aclnn_input_scale" in name:
+ continue
+ if name in int8_params:
+ sender_pg.send([int8_params[name].to(model_device)], 0,
+ 0).wait()
+ else:
+ sender_pg.send([param.contiguous()], 0, 0).wait()
+ torch.distributed.barrier(group=sender_pg,
+ device_ids=[model_device.index])
+ torch_npu.npu.synchronize(trans_stream)
+ logger.info(
+ f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}"
+ )
+ finally:
+ if sender_pg:
+ stateless_destroy_torch_distributed_process_group(sender_pg)
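A note on the transfer protocol above: sender and receiver never exchange parameter names or shapes. Correctness relies on both sides walking `named_parameters()` in the same order and skipping the same entries (the receiver skips 0-dim tensors, the sender skips `aclnn_input_scale` entries, and these must coincide). A pure-Python sketch of that ordering contract, with a deque standing in for the HCCL channel and illustrative parameter names:

```python
from collections import deque

# (name, shape) pairs standing in for model.named_parameters(); names are illustrative
params = [("layer.weight", (4, 4)),
          ("layer.aclnn_input_scale", ()),
          ("layer.bias", (4,))]

channel = deque()  # stands in for the HCCL send/recv channel

# Sender side (P2PSend): skip aclnn_input_scale entries, push the rest in order.
for name, shape in params:
    if "aclnn_input_scale" in name:
        continue
    channel.append((name, shape))

# Receiver side (P2PLoad): skip 0-dim tensors, pop in the same order.
received = []
for name, shape in params:
    if len(shape) == 0:
        continue
    sent_name, sent_shape = channel.popleft()
    assert sent_shape == shape, f"ordering contract broken at {name}"
    received.append(name)

print(received)  # ['layer.weight', 'layer.bias']
```

If a model ever has a 0-dim parameter that is not an `aclnn_input_scale` (or vice versa), the two skip rules diverge and every subsequent recv pairs with the wrong tensor.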
diff --git a/vllm_ascend/model_loader/netloader/interaction/__init__.py b/vllm_ascend/model_loader/netloader/interaction/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/model_loader/netloader/interaction/elastic.py b/vllm_ascend/model_loader/netloader/interaction/elastic.py
new file mode 100644
index 00000000..1000bd7c
--- /dev/null
+++ b/vllm_ascend/model_loader/netloader/interaction/elastic.py
@@ -0,0 +1,408 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import re
+import socket
+import threading
+from typing import List, Optional, Tuple
+
+import torch
+from vllm.logger import logger
+
+from ..executor.elastic_load import P2PSend
+from ..utils import find_free_port
+
+
+class ElasticClient:
+ """
+    Class implementing the client side of the Netloader weight-transfer protocol.
+ """
+
+ def __init__(self, sources: list[str], device_id: int, model_path: str,
+ tp: int, pp: int):
+ """
+ Initializes the ElasticClient instance.
+
+ Parameters:
+ - sources: List of source addresses in the format IP:port.
+ - device_id: The ID of the current device.
+ - model_path: The path to the model.
+ - tp: Tensor parallel size.
+ - pp: Pipeline parallel size.
+ """
+ self.sources = sources
+ self.device_id = device_id
+ self.model_path = model_path
+ self.tp = tp
+ self.pp = pp
+
+ self.s: Optional[socket.socket] = None
+ self.ack: Optional[Tuple[str, int]] = None
+ self.server_addr: Optional[str] = None
+ self.server_port: Optional[int] = None
+
+ for source in self.sources:
+ try:
+ ip, port_str = source.split(':')
+ port = int(port_str)
+ except Exception as e:
+ logger.error(f"IP format error: {source}, detail: {e}")
+ continue
+
+ self.server_addr = ip
+ self.server_port = port
+
+            sock = None  # ensure `sock` is defined for the except branch below
+            try:
+                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ logger.info(
+ f"Start connection to server: {self.server_addr}:{self.server_port}"
+ )
+ sock.connect((self.server_addr, self.server_port))
+ logger.info(
+ f"Finish connection to server: {self.server_addr}:{self.server_port}"
+ )
+ sock.settimeout(60)
+
+ self.s = sock
+ self.ack = self.register(device_id, model_path, tp, pp)
+ break
+ except Exception as e:
+                logger.error(f"Failed to connect to {source}, detail: {e}")
+ if sock is not None:
+ try:
+ sock.close()
+ except Exception:
+ pass
+ self.s = None
+ self.ack = None
+ self.server_addr = None
+ self.server_port = None
+
+ def close(self) -> None:
+ """
+ Closes the socket connection.
+ """
+ if self.s is not None:
+ try:
+ self.s.close()
+ except Exception as e:
+ logger.error(f"Error closing socket: {e}")
+ finally:
+ self.s = None
+
+ def __enter__(self) -> "ElasticClient":
+ """
+ Context manager enter method.
+ """
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+ """
+ Context manager exit method.
+ """
+ self.close()
+
+ def __del__(self):
+ """
+ Destructor method to ensure socket is closed.
+ """
+ try:
+ self.close()
+ except Exception:
+ pass
+
+ def send_str(self, data_str: str) -> None:
+ """
+ Sends a string over the socket connection.
+
+ Parameters:
+ - data_str: The string to be sent.
+ """
+ if self.s is None:
+ raise RuntimeError("Socket was not created correctly.")
+ self.s.send(data_str.encode("utf-8"))
+
+ def recv_str(self, buffer_size: int = 1024) -> str:
+ """
+ Receives a string over the socket connection.
+
+ Parameters:
+ - buffer_size: The size of the buffer for receiving data.
+
+ Returns:
+ - The received string.
+ """
+ if self.s is None:
+ raise RuntimeError("Socket was not created correctly.")
+ data_str = self.s.recv(buffer_size).decode("utf-8")
+ return data_str
+
+ def register(self, device_id: int, model_path: str, tp: int,
+ pp: int) -> Tuple[str, int]:
+ """
+ Registers the client with the server.
+
+ Parameters:
+ - device_id: The ID of the current device.
+ - model_path: The path to the model.
+ - tp: Tensor parallel size.
+ - pp: Pipeline parallel size.
+
+ Returns:
+ - A tuple containing the communication name and port.
+ """
+ free_port = find_free_port()
+ data = {
+ "label": "JOIN",
+ "content": {
+ 'device_id': device_id,
+ 'model_path': model_path,
+ 'tp': tp,
+ 'pp': pp,
+ 'port': free_port
+ }
+ }
+
+ try:
+ self.send_str(json.dumps(data))
+ except Exception as e:
+            raise RuntimeError(
+                f"Failed to send data {data} to server, detail: {e}")
+
+ try:
+ ack_str = self.recv_str()
+ except Exception as e:
+            raise RuntimeError(f"Failed to receive data from server, detail: {e}")
+
+ try:
+ ack = json.loads(ack_str)
+ except Exception as e:
+ raise RuntimeError(
+                f"Received data {ack_str} cannot be converted to JSON, detail: {e}"
+ )
+
+ logger.info(f"Receive ack: {ack}")
+
+ if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack
+ and ack["content"] is not None and "name" in ack["content"]):
+ return (ack["content"]["name"], free_port)
+ elif ("label" in ack and ack["label"] == 'JOIN_NACK'
+ and "content" in ack):
+ raise RuntimeError(
+                f"Received nack from server, reason: {ack['content']}")
+ else:
+ raise RuntimeError(
+                f"Received ack {ack} from server does not contain required fields"
+ )
+
+
+class ElasticServer:
+ """
+    Class implementing the server side of the Netloader weight-transfer protocol.
+ """
+
+ def __init__(self, addr: str, port: int, model, device_id: int,
+ model_path: str, tp: int, pp: int, int8_cache: str,
+ int8_cache_name: Optional[List[str]]):
+ """
+ Initializes the ElasticServer instance.
+
+ Parameters:
+ - addr: The IP address to listen on.
+ - port: The port number to listen on.
+ - model: The model to be served.
+ - device_id: The ID of the current device (i.e. global rank).
+ - model_path: The path to the model.
+ - tp: Tensor parallel size.
+ - pp: Pipeline parallel size.
+ - int8_cache: The type of caching for int8 parameters (HBM, DRAM, or no).
+ - int8_cache_name: List of parameter names to be cached.
+ """
+ self.addr = addr
+ self.port = port
+ self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ self.s.bind((self.addr, self.port))
+ self.s.listen(256)
+
+ self.model = model
+ self.device_id = device_id
+ self.model_path = model_path
+ self.tp = tp
+ self.pp = pp
+
+ self.original_int8 = {}
+ int8_pattern = "|".join(
+ map(re.escape,
+ int8_cache_name)) if int8_cache_name is not None else "(?:)"
+ for name, param in self.model.named_parameters():
+ if param.dtype == torch.int8:
+ if int8_cache == 'hbm':
+ if int8_cache_name is None or (
+ int8_cache_name is not None
+ and re.search(int8_pattern, name) is not None):
+ try:
+ self.original_int8[name] = param.data.clone(
+ ).detach()
+ except RuntimeError as e:
+ logger.error(
+ f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}"
+ )
+ self.original_int8[name] = param.data.cpu()
+
+ elif int8_cache == 'dram':
+ if int8_cache_name is None or (
+ int8_cache_name is not None
+ and re.search(int8_pattern, name) is not None):
+ self.original_int8[name] = param.data.cpu()
+ elif int8_cache == 'no':
+ pass
+ else:
+ logger.warning(
+                    f"int8_cache should be one of ['hbm', 'dram', 'no'], but got {int8_cache}, falling back to no cache"
+ )
+
+ logger.info(
+ f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}"
+ )
+
+ def __del__(self):
+ """
+ Destructor method to ensure socket is closed.
+ """
+ self.s.close()
+
+ def start(self):
+ """
+ Starts the server to handle incoming connections.
+ """
+ handler_thread = threading.Thread(target=self.elastic_client_handler)
+ handler_thread.daemon = True
+ handler_thread.start()
+
+ def elastic_client_handler(self):
+ """
+ Handles incoming client connections.
+ """
+ while True:
+ conn, addr = self.s.accept()
+ logger.info("Accept new connection from {}:{}...".format(*addr))
+ self.register_handler(conn, addr)
+
+ def register_handler(self, conn, addr, buffer_size=1024):
+ """
+ Handles the registration of a client.
+
+ Parameters:
+ - conn: The connection socket.
+ - addr: The address of the client.
+ - buffer_size: The size of the buffer for receiving data.
+ """
+ data_str = conn.recv(buffer_size).decode("utf-8")
+ if not data_str:
+ return
+ try:
+ data = json.loads(data_str)
+ except Exception:
+ logger.error(f"Failed to load {data_str} as JSON string")
+ conn.close()
+ return
+
+ def is_valid_data(data):
+ """
+ Validates the received data.
+
+ Parameters:
+ - data: The data to be validated.
+
+ Returns:
+ - True if the data is valid, otherwise False.
+ """
+ if not isinstance(data, dict):
+ return False
+ if data.get("label") != "JOIN":
+ return False
+ content = data.get("content")
+ if not isinstance(content, dict):
+ return False
+ required_keys = ["device_id", "model_path", "tp", "pp", "port"]
+ if not all(k in content for k in required_keys):
+ return False
+ port = content["port"]
+ if not (isinstance(port, int) or
+ (isinstance(port, str) and port.isdigit())):
+ return False
+ return True
+
+ comm_name = None
+ if is_valid_data(data):
+ device_id = int(data["content"]["device_id"])
+ model_path = data["content"]["model_path"]
+ tp = int(data["content"]["tp"])
+ pp = int(data["content"]["pp"])
+
+ if int(self.device_id
+ ) == device_id and self.model_path == model_path and int(
+ self.tp) == tp and int(self.pp) == pp:
+ comm_name = str(addr[0]) + ":" + str(addr[1])
+ ack = {"label": "JOIN_ACK", "content": {"name": comm_name}}
+ else:
+ logger.warning(
+                    f"Received data {(device_id, model_path, tp, pp)} is not consistent with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
+ )
+ ack = {
+ "label":
+ "JOIN_NACK",
+ "content":
+                    f"Received data {(device_id, model_path, tp, pp)} is not consistent with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}"
+ }
+ else:
+ logger.warning(
+ f"Received data does not contain required fields: {data}")
+ ack = {
+ "label":
+ "JOIN_NACK",
+ "content":
+ f"Received data does not contain required fields: {data}"
+ }
+
+ try:
+ ack_str = json.dumps(ack).encode("utf-8")
+ except Exception as e:
+ logger.error(
+ f"Failed to convert {ack} to JSON format, details: {e}")
+ conn.close()
+ return
+
+ try:
+ conn.send(ack_str)
+ except Exception as e:
+ logger.error(f"Failed to send {ack} to {addr}, details: {e}")
+ conn.close()
+ return
+
+ if ack["content"] and isinstance(ack["content"],
+ dict) and 'name' in ack["content"]:
+ try:
+ p2psend = P2PSend(self.addr, data["content"]["port"],
+ ack["content"]["name"])
+ p2psend.send(self.model, self.original_int8)
+ except Exception as e:
+ logger.error(
+                    f"P2PSend failed to send model weights from {self.addr}, details: {e}"
+ )
+ conn.close()
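The registration handshake in `elastic.py` is plain JSON over TCP: the client sends a `JOIN`, the server answers `JOIN_ACK` (carrying the HCCL group name) or `JOIN_NACK` with a reason. A minimal sketch of the message shapes and the validation the server applies (`validate_join` is a hypothetical standalone version of the inline `is_valid_data` helper; addresses and values are illustrative):

```python
import json

def validate_join(data):
    # Hypothetical standalone version of the server's inline is_valid_data helper.
    if not isinstance(data, dict) or data.get("label") != "JOIN":
        return False
    content = data.get("content")
    if not isinstance(content, dict):
        return False
    required = ("device_id", "model_path", "tp", "pp", "port")
    if not all(k in content for k in required):
        return False
    port = content["port"]
    return isinstance(port, int) or (isinstance(port, str) and port.isdigit())

# Client -> server: registration request (values are illustrative)
join = {"label": "JOIN",
        "content": {"device_id": 0, "model_path": "/models/demo",
                    "tp": 1, "pp": 1, "port": 29501}}
assert validate_join(json.loads(json.dumps(join)))

# Server -> client: the ack names the HCCL group after the client's "<ip>:<port>"
ack = {"label": "JOIN_ACK", "content": {"name": "10.0.0.2:53122"}}

# A mismatched model path or parallel layout yields a JOIN_NACK with a reason string.
nack = {"label": "JOIN_NACK", "content": "model_path does not match"}
print(ack["label"], nack["label"])
```

The group name doubles as the rendezvous key for the subsequent `stateless_init_torch_distributed_process_group` call on both sides.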
diff --git a/vllm_ascend/model_loader/netloader/load.py b/vllm_ascend/model_loader/netloader/load.py
new file mode 100644
index 00000000..90000d58
--- /dev/null
+++ b/vllm_ascend/model_loader/netloader/load.py
@@ -0,0 +1,84 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+
+from vllm.logger import logger
+
+from .executor.elastic_load import P2PLoad
+from .interaction.elastic import ElasticClient
+
+
+def elastic_load(
+ model,
+ device_id: int,
+ model_path: str,
+ sources: list,
+ tp: int,
+ pp: int,
+):
+ """
+ Loads a model using elastic loading across multiple devices.
+
+ Parameters:
+ - model: The model instance to be loaded.
+ - device_id: The ID of the current device (i.e. global rank).
+ - model_path: The path to the model file.
+ - sources: A list of source configurations, each containing device_id and sources.
+ - tp: Tensor parallel size, indicating the number of devices for tensor parallelism.
+ - pp: Pipeline parallel size, indicating the number of devices for pipeline parallelism.
+
+ Returns:
+ - The loaded model if successful, otherwise None.
+ """
+
+ # Filter sources for the current device
+ sources_this_device = []
+ for s in sources:
+ if isinstance(
+ s, dict
+ ) and "device_id" in s and s["device_id"] == device_id and isinstance(
+ s["sources"], list):
+ sources_this_device += s["sources"]
+ if len(sources_this_device) == 0:
+ return None
+
+ try:
+ # Initialize the interaction layer with the ElasticClient
+ with ElasticClient(sources_this_device, device_id, model_path, tp,
+ pp) as client_interaction_layer:
+ if client_interaction_layer.s is None or client_interaction_layer.server_addr is None:
+ raise RuntimeError(
+ "Failed to initialize ElasticClient: socket or server_addr is None"
+ )
+ ack = client_interaction_layer.ack
+ if ack is None:
+ raise RuntimeError("ElasticClient.register did not return ack")
+
+ t0 = time.perf_counter()
+ elastic_loader = P2PLoad(ack[0],
+ client_interaction_layer.server_addr,
+ ack[1])
+ model_loaded = elastic_loader.load(model=model)
+ if model_loaded is None:
+ logger.error("Failed to load model")
+ return None
+ logger.info("Finish elastic load (duration: {}s)".format(
+ time.perf_counter() - t0))
+ return model_loaded
+ except Exception as e:
+ logger.error(f"elastic_load error: {e}")
+ return None
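`elastic_load` above first filters `sources` down to the entries for the local rank; each entry is expected to look like `{"device_id": N, "sources": ["ip:port", ...]}`. A standalone sketch of that selection (addresses are illustrative):

```python
def sources_for_device(sources, device_id):
    # Mirrors the filtering at the top of elastic_load: keep only the
    # "sources" lists of entries whose device_id matches the local rank.
    selected = []
    for s in sources:
        if (isinstance(s, dict) and s.get("device_id") == device_id
                and isinstance(s.get("sources"), list)):
            selected += s["sources"]
    return selected

sources = [
    {"device_id": 0, "sources": ["10.0.0.1:6000"]},
    {"device_id": 1, "sources": ["10.0.0.1:6001", "10.0.0.2:6001"]},
    "malformed entry",  # silently ignored, as in elastic_load
]
print(sources_for_device(sources, 1))  # ['10.0.0.1:6001', '10.0.0.2:6001']
print(sources_for_device(sources, 2))  # []
```

An empty result makes `elastic_load` return `None`, which the loader treats as a signal to fall back to `DefaultModelLoader`.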
diff --git a/vllm_ascend/model_loader/netloader/netloader.py b/vllm_ascend/model_loader/netloader/netloader.py
new file mode 100644
index 00000000..9c2d8307
--- /dev/null
+++ b/vllm_ascend/model_loader/netloader/netloader.py
@@ -0,0 +1,324 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import gc
+import json
+import time
+from copy import deepcopy
+from typing import List, Optional, Tuple
+
+import torch
+from torch import nn
+from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.logger import logger
+from vllm.model_executor.model_loader import register_model_loader
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+from vllm.model_executor.model_loader.utils import (
+ initialize_model, process_weights_after_loading, set_default_torch_dtype)
+
+from .interaction.elastic import ElasticServer
+from .load import elastic_load
+from .utils import find_free_port, is_valid_path_prefix
+
+
+@register_model_loader("netloader")
+class ModelNetLoaderElastic(BaseModelLoader):
+ """
+ A model loader that uses elastic loading for loading weights.
+ """
+ source: Optional[List[dict]]
+ model_path: Optional[str]
+ listen_port: Optional[int]
+ int8_cache: str
+ int8_cache_name: Optional[List[str]]
+ output_prefix: Optional[str]
+
+ def __init__(self, load_config: LoadConfig):
+ """
+ Initializes the ModelNetLoaderElastic with configuration.
+
+ Parameters:
+ - load_config: Configuration for loading the model.
+ """
+ super().__init__(load_config)
+
+ config = None
+
+ # Try to read config file at first
+ extra = load_config.model_loader_extra_config
+ if extra and "CONFIG_FILE" in extra:
+ try:
+ logger.info(
+ f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..."
+ )
+ with open(extra["CONFIG_FILE"], 'r') as f:
+ config = json.load(f)
+ except FileNotFoundError:
+ logger.error("CONFIG_FILE not found")
+ except json.JSONDecodeError:
+ logger.error("CONFIG_FILE is not a valid JSON file")
+ except Exception as e:
+ logger.error(
+ f"Unexpected error while reading CONFIG_FILE: {e}")
+
+ if config is None and extra:
+ logger.info("Reading configs in model_loader_extra_config ...")
+ config = extra
+ config = config or {}
+
+ for key, attr, checker, caster, default in [
+ ("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v,
+ None),
+ ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v,
+ None),
+ ("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or
+ (isinstance(v, str) and v.isdigit()), lambda v: int(v), None),
+ ("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v.
+ lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'),
+ ("INT8_CACHE_NAME", "int8_cache_name",
+ lambda v: isinstance(v, list), lambda v: v, None),
+ ("OUTPUT_PREFIX", "output_prefix",
+ lambda v: isinstance(v, str) and is_valid_path_prefix(v),
+ lambda v: v, None),
+ ]:
+ v = config.get(key, default)
+ if not checker(v):
+ v = default
+ else:
+ v = caster(v)
+ setattr(self, attr, v)
+
+ logger.info(
+ "Initializing elastic Netloader with config: "
+            "MODEL=%s, LISTEN_PORT=%s, "
+            "SOURCE=%s, INT8_CACHE=%s, INT8_CACHE_NAME=%s, "
+            "OUTPUT_PREFIX=%s",
+ self.model_path,
+ self.listen_port,
+ self.source,
+ self.int8_cache,
+ self.int8_cache_name,
+ self.output_prefix,
+ )
+
+ def load_model(self, vllm_config: VllmConfig,
+ model_config: ModelConfig) -> nn.Module:
+ """
+ Loads the model using the specified configuration.
+
+ Parameters:
+ - vllm_config: Configuration for the VLLM.
+ - model_config: Configuration for the model.
+
+ Returns:
+ - The loaded model.
+ """
+
+ device_config = vllm_config.device_config
+ parallel_config = vllm_config.parallel_config
+
+ need_process_weights_after_loading = False
+
+ if self.model_path is None:
+ self.model_path = model_config.model
+ logger.info(f"model_path is set to {self.model_path}")
+
+ device_id = torch.distributed.get_rank()
+
+ if (self.source is None or not isinstance(self.source, list)
+ or device_id not in [
+ one_device["device_id"] for one_device in self.source if
+ isinstance(one_device, dict) and "device_id" in one_device
+ ]):
+ logger.warning(
+                "No valid source info for this rank, falling back to DefaultModelLoader")
+ model, need_process_weights_after_loading = self.revert_to_default(
+ model_config, vllm_config, device_config)
+
+ else:
+ target_device = torch.device(device_config.device)
+
+ vllm_config_backup = deepcopy(vllm_config)
+ model_config_backup = deepcopy(model_config)
+
+ with set_default_torch_dtype(model_config.dtype):
+ with target_device:
+ model = initialize_model(vllm_config=vllm_config,
+ model_config=model_config)
+
+ start_elastic_load = time.perf_counter()
+ model = elastic_load(
+ model=model,
+ device_id=device_id,
+ model_path=self.model_path,
+ sources=self.source,
+ tp=parallel_config.tensor_parallel_size,
+ pp=parallel_config.pipeline_parallel_size,
+ )
+ end_elastic_load = time.perf_counter()
+ logger.info(
+ f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}"
+ )
+ need_process_weights_after_loading = True
+
+ if model is None:
+ logger.warning(
+                    "Netloader elastic loading failed, falling back to DefaultModelLoader"
+ )
+
+ vllm_config = vllm_config_backup
+ model_config = model_config_backup
+
+ del model
+ gc.collect()
+ if device_config.device_type == 'npu':
+ logger.info("Empty NPU cache")
+ torch.npu.empty_cache()
+ elif device_config.device_type == 'cuda':
+ logger.info("Empty CUDA cache")
+ torch.cuda.empty_cache()
+
+ model, need_process_weights_after_loading = self.revert_to_default(
+ model_config, vllm_config, device_config)
+
+ start_elastic_server = time.perf_counter()
+ # start elastic server
+        if model is not None and (
+                self.listen_port is None
+                or 1024 <= self.listen_port <= 65535):
+ from vllm.utils import get_ip
+ driver_ip = get_ip()
+
+ if driver_ip == '0.0.0.0':
+ logger.error(
+                    "Driver IP is not set, skipping Netloader server start")
+ else:
+ if self.listen_port is None:
+ self.listen_port = find_free_port()
+ else:
+ self.listen_port += device_id
+
+ logger.info(
+ f"Start elastic Netloader server, rank: {device_id}, listen port: {driver_ip}:{self.listen_port}"
+ )
+
+ if self.output_prefix is not None:
+ try:
+ with open(self.output_prefix + str(device_id) + '.txt',
+ 'w') as file:
+ file.write(f"{driver_ip}:{self.listen_port}")
+ logger.info(
+                            f"Successfully wrote server address to file: {self.output_prefix + str(device_id) + '.txt'}"
+                        )
+                    except FileNotFoundError:
+                        logger.error(
+                            f"File path {self.output_prefix + str(device_id) + '.txt'} does not exist."
+                        )
+                    except PermissionError:
+                        logger.error(
+                            f"No permission to write to file {self.output_prefix + str(device_id) + '.txt'}."
+                        )
+                    except IOError as e:
+                        logger.error(
+                            f"I/O error occurred while writing to file {self.output_prefix + str(device_id) + '.txt'}: {e}"
+                        )
+ except Exception as e:
+ logger.error(f"Unknown error: {e}")
+
+ try:
+ assert isinstance(
+ self.listen_port, int
+ ), f"listen port should be int but get {self.listen_port}"
+
+ elastic_server = ElasticServer(
+ driver_ip, self.listen_port, model, device_id,
+ self.model_path, parallel_config.tensor_parallel_size,
+ parallel_config.pipeline_parallel_size,
+ self.int8_cache, self.int8_cache_name)
+ elastic_server.start()
+ except Exception as e:
+ logger.error(
+ f"Failed to start Netloader server for rank: {device_id}, details: {e}"
+ )
+ else:
+            logger.info("Skipping Netloader server start")
+
+ end_elastic_server = time.perf_counter()
+ logger.info(
+ f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}"
+ )
+
+ if need_process_weights_after_loading:
+ process_weights_after_loading(model, model_config,
+ torch.device(device_config.device))
+
+ if model is None:
+            logger.error("NetLoader elastic model loading failed")
+ return None
+
+ return model.eval()
+
+ def revert_to_default(self, model_config, vllm_config,
+ device_config) -> Tuple[nn.Module, bool]:
+ """
+ Reverts to the default model loading logic when elastic loading fails or is not applicable.
+
+ This method resets the loader's extra config and load format to defaults,
+ then delegates model loading to a DefaultModelLoader.
+ If quantization is enabled, it will load the model and then run the
+ processing of weights (i.e. applying quantization adjustments) before returning.
+
+ Parameters:
+ - model_config: Configuration describing model architecture, quantization, etc.
+ - vllm_config: Configuration for vLLM (device, parallelism, dtype, etc).
+ - device_config: Configuration for the target device (device type, device id, etc).
+
+ Returns:
+ - A tuple (model, need_process_weights_after_loading):
+ * model: The loaded `nn.Module` under default loading logic.
+ * need_process_weights_after_loading: A boolean flag indicating whether
+ weights post-processing (e.g. quantization adjustments) still needs to be applied.
+ """
+ self.load_config.model_loader_extra_config = {}
+ self.load_config.load_format = "auto"
+ default_model_loader = DefaultModelLoader(self.load_config)
+
+ if model_config.quantization is None:
+ model = default_model_loader.load_model(vllm_config=vllm_config,
+ model_config=model_config)
+ need_process_weights_after_loading = False
+ else:
+ logger.warning(
+                "Quantization is set, Netloader uses DefaultModelLoader with process_weights_after_loading"
+ )
+ need_process_weights_after_loading = True
+ target_device = torch.device(device_config.device)
+ with set_default_torch_dtype(model_config.dtype):
+ with target_device:
+ model = initialize_model(vllm_config=vllm_config,
+ model_config=model_config)
+ default_model_loader.load_weights(model, model_config)
+ model = model.eval()
+
+ return model, need_process_weights_after_loading
+
+ def download_model(self, model_config: ModelConfig) -> None:
+ pass
+
+ def load_weights(self, model: nn.Module,
+ model_config: ModelConfig) -> None:
+ pass
diff --git a/vllm_ascend/model_loader/netloader/utils.py b/vllm_ascend/model_loader/netloader/utils.py
new file mode 100644
index 00000000..fba5a58c
--- /dev/null
+++ b/vllm_ascend/model_loader/netloader/utils.py
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import re
+import socket
+
+from vllm.logger import logger
+
+
+def find_free_port():
+ """
+ Finds a free port on the local machine.
+
+ Returns:
+ - A free port number.
+ """
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.bind(('', 0))
+ return s.getsockname()[1]
+
+
+def is_valid_path_prefix(path_prefix):
+ """
+ Checks if the provided path prefix is valid.
+
+ Parameters:
+ - path_prefix: The path prefix to validate.
+
+ Returns:
+ - True if the path prefix is valid, otherwise False.
+ """
+ if not path_prefix:
+ return False
+
+ if re.search(r'[<>:"|?*]', path_prefix):
+ logger.warning(
+ f'The path prefix {path_prefix} contains illegal characters.')
+ return False
+
+ if path_prefix.startswith('/') or path_prefix.startswith('\\'):
+ if not os.path.exists(os.path.dirname(path_prefix)):
+ logger.warning(
+ f'The directory for the path prefix {os.path.dirname(path_prefix)} does not exist.'
+ )
+ return False
+ else:
+ if not os.path.exists(os.path.dirname(os.path.abspath(path_prefix))):
+ logger.warning(
+ f'The directory for the path prefix {os.path.dirname(os.path.abspath(path_prefix))} does not exist.'
+ )
+ return False
+ return True
\ No newline at end of file
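The two helpers above are small enough to exercise standalone. Here is `find_free_port` reproduced verbatim (it needs only the stdlib), with a note on its inherent race:

```python
import socket

def find_free_port():
    # Same approach as utils.find_free_port: bind to port 0 and let the
    # OS assign an ephemeral free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

port = find_free_port()
assert 0 < port < 65536
print(port)

# Note the inherent race: the port is only guaranteed free at return time;
# another process may grab it before the caller binds. Netloader tolerates
# this because the client falls back to the next source address when a
# connection or registration attempt fails.
```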