diff --git a/docs/source/user_guide/feature_guide/images/netloader_flowchart.png b/docs/source/user_guide/feature_guide/images/netloader_flowchart.png new file mode 100644 index 00000000..4efe6b2b Binary files /dev/null and b/docs/source/user_guide/feature_guide/images/netloader_flowchart.png differ diff --git a/docs/source/user_guide/feature_guide/images/netloader_timing_diagram.png b/docs/source/user_guide/feature_guide/images/netloader_timing_diagram.png new file mode 100644 index 00000000..0af05812 Binary files /dev/null and b/docs/source/user_guide/feature_guide/images/netloader_timing_diagram.png differ diff --git a/docs/source/user_guide/feature_guide/index.md b/docs/source/user_guide/feature_guide/index.md index 049e496f..61c333b0 100644 --- a/docs/source/user_guide/feature_guide/index.md +++ b/docs/source/user_guide/feature_guide/index.md @@ -11,5 +11,6 @@ sleep_mode structured_output lora eplb_swift_balancer +netloader dynamic_batch ::: diff --git a/docs/source/user_guide/feature_guide/netloader.md b/docs/source/user_guide/feature_guide/netloader.md new file mode 100644 index 00000000..857f3473 --- /dev/null +++ b/docs/source/user_guide/feature_guide/netloader.md @@ -0,0 +1,97 @@ +# Netloader Guide + +This guide provides instructions for using **Netloader** as a weight-loader plugin for acceleration in **vLLM Ascend**. + +--- + +## Overview + +Netloader leverages high-bandwidth peer-to-peer (P2P) transfers between NPU cards to load model weights. It is implemented as a plugin (via the `register_model_loader` API added in vLLM 0.10). The workflow is: + +1. A **server** preloads a model. +2. A new **client** instance requests weight transfer. +3. After validating that the model and partitioning match, the client uses HCCL collective communication (send/recv) to receive weights in the same order as stored in the model. + +The server runs alongside normal inference tasks via sub-threads and via `stateless_init_torch_distributed_process_group` in vLLM. 
The client thus takes over weight initialization without needing to load from storage. + +### Flowchart + +![netloader flowchart](./images/netloader_flowchart.png) + +### Timing Diagram + +![netloader timing diagram](./images/netloader_timing_diagram.png) + +### Application Scenarios + +- **Reduce startup latency**: By reusing already loaded weights and transferring them directly between NPU cards, Netloader cuts down model loading time vs conventional remote/local pull strategies. +- **Relieve network & storage load**: Avoid repeated downloads of weight files from remote repositories, thus reducing pressure on central storage and network traffic. +- **Improve resource utilization & lower cost**: Faster loading allows less reliance on standby compute nodes; resources can be scaled up/down more flexibly. +- **Enhance business continuity & high availability**: In failure recovery, new instances can quickly take over without long downtime, improving system reliability and user experience. + +--- + +## Usage + +To enable Netloader, pass `--load-format=netloader` and provide configuration via `--model-loader-extra-config` (as a JSON string). Below are the supported configuration fields: + +| Field Name | Type | Description | Allowed Values / Notes | +|--------------------|---------|------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------| +| **SOURCE** | List | Weighted data sources. Each item is a map with `device_id` and `sources`, specifying the rank and its endpoints (IP:port).
Example: `{"SOURCE": [{"device_id": 0, "sources": ["10.170.22.152:19374"]}, {"device_id": 1, "sources": ["10.170.22.152:11228"]}]}`
If omitted or empty, fallback to default loader. The SOURCE here is second priority. | A list of objects with keys `device_id: int` and `sources: List[str]` | +| **MODEL** | String | The model name, used to verify consistency between client and server. | Defaults to the `--model` argument if not specified. | +| **LISTEN_PORT** | Integer | Base port for the server listener. | The actual port = `LISTEN_PORT + RANK`. If omitted, a random valid port is chosen. Valid range: 1024–65535. If out of range, that server instance won’t open a listener. | +| **INT8_CACHE** | String | Behavior for handling int8 parameters in quantized models. | One of `["hbm", "dram", "no"]`.
- `hbm`: copy original int8 parameters to high-bandwidth memory (HBM); note this may consume a significant amount of HBM.
- `dram`: copy to DRAM.
- `no`: no special handling (may lead to divergence or unpredictable behavior). Default: `"no"`. | +| **INT8_CACHE_NAME** | List | Names of parameters to which `INT8_CACHE` is applied (i.e. filtering). | Default: `None` (means no filtering—all parameters). | +| **OUTPUT_PREFIX** | String | Prefix for writing per-rank listener address/port files in server mode. | If set, each rank writes to `{OUTPUT_PREFIX}{RANK}.txt` (text), content = `IP:Port`. | +| **CONFIG_FILE** | String | Path to a JSON file specifying the above configuration. | If provided, the SOURCE inside this file has **first priority** (overrides SOURCE in other configs). | + +--- + +## Example Commands & Placeholders + +> Replace parts in `` `<...>` `` before running. + +### Server + +```shell +VLLM_SLEEP_WHEN_IDLE=1 vllm serve `` \ + --tensor-parallel-size 1 \ + --served-model-name `` \ + --enforce-eager \ + --port `` \ + --load-format netloader +``` + +### Client + +```shell +export NETLOADER_CONFIG='{"SOURCE":[{"device_id":0, "sources": ["``:``"]}]}' + +VLLM_SLEEP_WHEN_IDLE=1 ASCEND_RT_VISIBLE_DEVICES=`` \ + vllm serve `` \ + --tensor-parallel-size 1 \ + --served-model-name `` \ + --enforce-eager \ + --port `` \ + --load-format netloader \ + --model-loader-extra-config="${NETLOADER_CONFIG}" +``` + +#### Placeholder Descriptions + +- ``: Path to the model file +- ``: Model name (must match between server & client) +- ``: Base listening port on server +- `` + ``: IP and port of the Netloader server (from server log) +- ``: Client device ID (must differ from server’s) +- ``: Port on which client listens + +After startup, you can test consistency by issuing inference requests with temperature = 0 and comparing outputs. + +--- + +## Note & Caveats + +- If Netloader is used, **each worker process** must bind a listening port. That port may be user-specified or assigned randomly. If user-specified, ensure it is available. +- Netloader requires extra HBM memory to establish HCCL connections (i.e. 
`HCCL_BUFFERSIZE`, default ~200 MB). Users should reserve sufficient capacity (e.g. via `--gpu-memory-utilization`). +- It is recommended to set `VLLM_SLEEP_WHEN_IDLE=1` to mitigate unstable or slow connections/transmissions. Related info: [vLLM Issue #16660](https://github.com/vllm-project/vllm/issues/16660), [vLLM PR #16226](https://github.com/vllm-project/vllm/pull/16226). diff --git a/setup.py b/setup.py index 0795304b..5a823e7a 100644 --- a/setup.py +++ b/setup.py @@ -393,7 +393,8 @@ setup( "vllm.platform_plugins": ["ascend = vllm_ascend:register"], "vllm.general_plugins": [ "ascend_enhanced_model = vllm_ascend:register_model", - "ascend_kv_connector = vllm_ascend:register_connector" + "ascend_kv_connector = vllm_ascend:register_connector", + "ascend_model_loader = vllm_ascend:register_model_loader" ], }, ) diff --git a/tests/ut/model_loader/netloader/test_netloader.py b/tests/ut/model_loader/netloader/test_netloader.py new file mode 100644 index 00000000..6658823a --- /dev/null +++ b/tests/ut/model_loader/netloader/test_netloader.py @@ -0,0 +1,215 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +from unittest.mock import MagicMock, patch + +import pytest +import torch +from torch import nn + +from vllm_ascend.model_loader.netloader.netloader import ModelNetLoaderElastic + + +class DummyDeviceConfig: + device = 'cuda' + device_type = 'cuda' + + +class DummyParallelConfig: + tensor_parallel_size = 1 + pipeline_parallel_size = 1 + + +class DummyVllmConfig: + device_config = DummyDeviceConfig() + parallel_config = DummyParallelConfig() + additional_config = None + + +class DummyModelConfig: + model = 'dummy-model' + dtype = torch.float32 + + +@pytest.fixture +def default_load_config(): + + class DummyLoadConfig: + model_loader_extra_config = None + load_format = "default" + + return DummyLoadConfig() + + +def make_loader_with_config(extra): + + class DummyLoadConfig: + model_loader_extra_config = extra + load_format = "default" + + return ModelNetLoaderElastic(DummyLoadConfig()) + + +def test_init_with_extra_config_file(tmp_path, monkeypatch): + # Generate test JSON file + config_content = { + "SOURCE": [{ + "device_id": 0 + }], + "MODEL": "foo-model", + "LISTEN_PORT": 5001, + "INT8_CACHE": "hbm", + "OUTPUT_PREFIX": str(tmp_path), + } + config_file = tmp_path / "config.json" + config_file.write_text(json.dumps(config_content)) + + dummy_logger = MagicMock() + monkeypatch.setattr("vllm.logger.logger", dummy_logger) + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", + lambda x: True) + + extra = {"CONFIG_FILE": str(config_file)} + loader = make_loader_with_config(extra) + assert loader.model_path == "foo-model" + assert loader.source == [{"device_id": 0}] + assert loader.listen_port == 5001 + assert loader.int8_cache == "hbm" + assert loader.output_prefix == str(tmp_path) + + +def test_init_with_extra_config(monkeypatch): + dummy_logger = MagicMock() + monkeypatch.setattr("vllm.logger.logger", dummy_logger) + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", + lambda 
x: True) + + extra = { + "SOURCE": [{ + "device_id": 0 + }], + "MODEL": "foo", + "LISTEN_PORT": "4000", + "INT8_CACHE": "dram", + "OUTPUT_PREFIX": "/tmp/" + } + loader = make_loader_with_config(extra) + assert loader.model_path == "foo" + assert loader.listen_port == 4000 + assert loader.int8_cache == "dram" + assert loader.output_prefix == "/tmp/" + assert loader.source == [{"device_id": 0}] + + +def test_init_with_invalid_config(monkeypatch): + dummy_logger = MagicMock() + monkeypatch.setattr("vllm.logger.logger", dummy_logger) + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix", + lambda x: False) + # c + extra = { + "SOURCE": None, + "MODEL": None, + "LISTEN_PORT": None, + "INT8_CACHE": "something", + "OUTPUT_PREFIX": None, + } + loader = make_loader_with_config(extra) + assert loader.model_path is None + assert loader.listen_port is None + assert loader.int8_cache == "no" + assert loader.output_prefix is None + + +@patch("vllm_ascend.model_loader.netloader.netloader.logger") +def test_load_model_elastic_success(mock_logger, monkeypatch, tmp_path): + monkeypatch.setattr("torch.distributed.get_rank", lambda: 0) + + class FakeContext: + + def __enter__(self): + pass + + def __exit__(self, a, b, c): + pass + + monkeypatch.setattr("torch.device", lambda d: FakeContext()) + # patch deep copy + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.deepcopy", lambda x: x) + # patch set_default_torch_dtype + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.set_default_torch_dtype", + lambda dtype: FakeContext()) + # patch initialize_model + dummy_model = MagicMock(spec=nn.Module) + dummy_model.eval.return_value = dummy_model + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.initialize_model", + lambda **kwargs: dummy_model) + # patch elastic_load + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.elastic_load", + lambda **kwargs: dummy_model) + # patch 
process_weights_after_loading + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.process_weights_after_loading", + lambda *a, **k: None) + # patch get_ip + monkeypatch.setattr("vllm.utils.get_ip", lambda: "127.0.0.1") + # patch find_free_port + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.find_free_port", + lambda: 8888) + + # patch ElasticServer + class DummyElasticServer: + + def __init__(*a, **k): + pass + + def start(self): + pass + + monkeypatch.setattr( + "vllm_ascend.model_loader.netloader.netloader.ElasticServer", + DummyElasticServer) + # write output_prefix to the temporary directory + extra = { + "SOURCE": [{ + "device_id": 0 + }], + "MODEL": "foo", + "LISTEN_PORT": 5555, + "OUTPUT_PREFIX": str(tmp_path) + "/output_", + "INT8_CACHE": "no" + } + loader = make_loader_with_config(extra) + vllm_config = DummyVllmConfig() + model_config = DummyModelConfig() + result = loader.load_model(vllm_config, model_config) + assert isinstance(result, nn.Module) + # Check file + written_file = tmp_path / "output_0.txt" + assert written_file.exists() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/ut/model_loader/netloader/test_netloader_elastic.py b/tests/ut/model_loader/netloader/test_netloader_elastic.py new file mode 100644 index 00000000..127f1dd6 --- /dev/null +++ b/tests/ut/model_loader/netloader/test_netloader_elastic.py @@ -0,0 +1,432 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import io +import json +import logging +import socket +from unittest.mock import MagicMock, patch + +import pytest +import torch +import vllm.logger + +from vllm_ascend.model_loader.netloader.interaction.elastic import ( + ElasticClient, ElasticServer) + + +# Simulate server's normal response +def mock_server_response(data): + return json.dumps({ + "label": "JOIN_ACK", + "content": { + "name": "mocked_name" + } + }).encode("utf-8") + + +# Simulate server's error response +def mock_server_error_response(data): + return json.dumps({"label": "JOIN_ACK", "content": None}).encode("utf-8") + + +# Simulated server's abnormal response +def mock_server_exception_response(data): + raise Exception("Mocked server exception") + + +# Test the initialization of ElasticClient +def test_elastic_client_init(): + sources = ["127.0.0.1:12345"] + device_id = 0 + model_path = "mocked_model_path" + tp = 1 + pp = 1 + + with patch('socket.socket') as mock_socket: + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.recv.return_value = mock_server_response(None) + + mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346) + mock_socket_instance.__enter__.return_value = mock_socket_instance + + with ElasticClient(sources, device_id, model_path, tp, pp) as client: + assert client.server_addr == "127.0.0.1" + assert client.server_port == 12345 + assert client.ack == ("mocked_name", 12346) + mock_socket_instance.close.assert_called_once() + + +# Test the register method of ElasticClient +def test_elastic_client_register(): + sources = ["127.0.0.1:12345"] + device_id = 0 + model_path = "mocked_model_path" + tp = 1 + pp = 1 + + with patch('socket.socket') as mock_socket: + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.connect.return_value = None + 
mock_socket_instance.recv.return_value = mock_server_response(None) + + mock_socket_instance.getsockname.return_value = ('127.0.0.1', 12346) + mock_socket_instance.__enter__.return_value = mock_socket_instance + + client = ElasticClient(sources, device_id, model_path, tp, pp) + assert client.register(device_id, model_path, tp, + pp) == ("mocked_name", 12346) + + +# Test the behavior of the `register` method of ElasticClient when the server returns an error response. +def test_elastic_client_register_error_response(): + sources = ["127.0.0.1:12345"] + device_id = 0 + model_path = "mocked_model_path" + tp = 1 + pp = 1 + + with patch('socket.socket') as mock_socket: + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.connect.return_value = None + mock_socket_instance.recv.return_value = mock_server_error_response( + None) + + with ElasticClient(sources, device_id, model_path, tp, pp) as client: + with pytest.raises(RuntimeError): + client.register(device_id, model_path, tp, pp) + mock_socket_instance.close.assert_called_once() + + +# Test the behavior of the `register` method of ElasticClient when an exception is thrown on the server. 
+def test_elastic_client_register_exception(): + sources = ["127.0.0.1:12345"] + device_id = 0 + model_path = "mocked_model_path" + tp = 1 + pp = 1 + + with patch('socket.socket') as mock_socket: + mock_socket_instance = MagicMock() + mock_socket.return_value = mock_socket_instance + mock_socket_instance.connect.return_value = None + mock_socket_instance.recv.side_effect = mock_server_exception_response + mock_socket_instance.__enter__.return_value = mock_socket_instance + mock_socket_instance.__exit__.return_value = None + + with ElasticClient(sources, device_id, model_path, tp, pp) as client: + with pytest.raises(RuntimeError): + client.register(device_id, model_path, tp, pp) + mock_socket_instance.close.assert_called_once() + + +class FakeInt8Param: + + def __init__(self, name="param", device="npu", dtype=torch.int8): + self.dtype = dtype + self.device = torch.device(device) + + @property + def data(self): + return self # Simulate .data returning self so .cpu() etc. can be chained + + def clone(self): + return self + + def detach(self): + return self + + def cpu(self): + self.device = torch.device("cpu") + return self + + +class FakeModel: + + def __init__(self): + self.params = { + "param1": MagicMock(dtype=torch.float32), # This will be ignored + "param2": FakeInt8Param(), # This simulates a real int8 param + } + + def named_parameters(self): + return self.params.items() + + +@pytest.fixture +def mock_model(): + return FakeModel() + + +@pytest.fixture +def server_config(): + return { + "addr": "127.0.0.1", + "port": 8080, + "model": MagicMock(), + "device_id": 0, + "model_path": "/test/model", + "tp": 1, + "pp": 1, + "int8_cache": "dram", + 'int8_cache_name': None + } + + +# Test server initialization +def test_server_initialization(server_config, mock_model): + server_config["model"] = mock_model + with patch("socket.socket") as mock_socket: + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + ch.setLevel(logging.DEBUG) + 
vllm.logger.logger.addHandler(ch) + + server = ElasticServer(**server_config) + + # Check the socket configuration + mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM) + mock_socket.return_value.setsockopt.assert_called_with( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + mock_socket.return_value.bind.assert_called_with(("127.0.0.1", 8080)) + mock_socket.return_value.listen.assert_called_with(256) + + # Check int8 cache + assert "param2" in server.original_int8 + assert server.original_int8[ + "param2"].device.type == "cpu" # Verifying DRAM Cache + + assert server.addr == server_config['addr'] + assert server.port == server_config['port'] + assert server.device_id == server_config['device_id'] + assert server.model_path == server_config['model_path'] + assert server.tp == server_config['tp'] + assert server.pp == server_config['pp'] + + # Get captured logs + log_output = log_capture_string.getvalue() + vllm.logger.logger.removeHandler(ch) + log_capture_string.close() + + # Check output + assert "Server 127.0.0.1:8080 starts" in log_output + + +# Test the int8 cache option +@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"), + ("no", None), + ("invalid", None)]) +def test_int8_cache_handling(server_config, mock_model, cache_option, + expected_device, caplog): + server_config["int8_cache"] = cache_option + server_config["model"] = mock_model + + with patch("socket.socket"): + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + ch.setLevel(logging.DEBUG) + vllm.logger.logger.addHandler(ch) + + server = ElasticServer(**server_config) + + log_output = log_capture_string.getvalue() + vllm.logger.logger.removeHandler(ch) + log_capture_string.close() + + if cache_option == "invalid": + assert "int8_cache should be selected in [HBM, DRAM]" in log_output + + if expected_device is None: + assert len(server.original_int8) == 0 + else: + assert server.original_int8[ + "param2"].device.type == 
expected_device + + +# Test client processing +def test_client_handler_valid_join(server_config, mock_model): + server_config["model"] = mock_model + with patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend" + ) as mock_p2p_send: + + # Create a simulated connection + mock_conn = MagicMock() + mock_addr = ("192.168.1.1", 12345) + + # Configuring Client Data + valid_data = { + "label": "JOIN", + "content": { + "device_id": 0, + "model_path": "/test/model", + "tp": 1, + "pp": 1, + "port": 9090 + } + } + mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8") + + # Start the server + server = ElasticServer(**server_config) + server.register_handler(mock_conn, mock_addr) + + # Verify response + expected_ack = { + "label": "JOIN_ACK", + "content": { + "name": "192.168.1.1:12345" + } + } + mock_conn.send.assert_called_once_with( + json.dumps(expected_ack).encode("utf-8")) + mock_p2p_send.assert_called_once_with("127.0.0.1", 9090, + "192.168.1.1:12345") + mock_conn.close.assert_called_once() + + +# Test mismatched JOIN requests +def test_client_handler_mismatch(server_config): + with patch("socket.socket"): + server = ElasticServer(**server_config) + mock_conn = MagicMock() + mock_addr = ("192.168.1.1", 12345) + + # Send mismatched data + mismatch_data = { + "label": "JOIN", + "content": { + "device_id": 1, # 不匹配的ID + "model_path": "/wrong/model", + "tp": 2, + "pp": 2, + "port": 9090 + } + } + mock_conn.recv.return_value = json.dumps(mismatch_data).encode("utf-8") + + server.register_handler(mock_conn, mock_addr) + + assert isinstance(mismatch_data["content"], dict) + + # Verify response + expected_ack = { + "label": + "JOIN_NACK", + "content": + f"Received data {(mismatch_data['content']['device_id'], mismatch_data['content']['model_path'], mismatch_data['content']['tp'], mismatch_data['content']['pp'])} does not consist with this server {(server_config['device_id'], server_config['model_path'], server_config['tp'], server_config['pp'])}" + 
} + mock_conn.send.assert_called_once_with( + json.dumps(expected_ack).encode("utf-8")) + mock_conn.close.assert_called_once() + + +# Test Invalid Request +@pytest.mark.parametrize( + "invalid_data,should_send", + [ + ( + { + "label": "WRONG_LABEL" + }, True + ), # Incorrect label, can be decoded as JSON, but the content is invalid. + ( + { + "content": { + "missing_fields": True + } + }, True + ), # Missing field, can be decoded as JSON, but the content is invalid. + ("plain text", False), # Non-JSON data, json.loads failed + (b"invalid_bytes", False) # Invalid byte, decode or json.loads failed + ]) +def test_client_handler_invalid_requests(server_config, invalid_data, + should_send): + with patch("socket.socket"): + log_capture_string = io.StringIO() + ch = logging.StreamHandler(log_capture_string) + ch.setLevel(logging.DEBUG) + vllm.logger.logger.addHandler(ch) + + with patch("socket.socket"): + server = ElasticServer(**server_config) + mock_conn = MagicMock() + mock_addr = ("192.168.1.1", 12345) + + if isinstance(invalid_data, (str, bytes)): + mock_conn.recv.return_value = invalid_data if isinstance( + invalid_data, bytes) else invalid_data.encode() + else: + mock_conn.recv.return_value = json.dumps(invalid_data).encode( + "utf-8") + + server.register_handler(mock_conn, mock_addr) + + if should_send: + expected_ack = { + "label": + "JOIN_NACK", + "content": + f"Received data does not contain required fields: {invalid_data}" + } + mock_conn.send.assert_called_once_with( + json.dumps(expected_ack).encode("utf-8")) + else: + mock_conn.send.assert_not_called() + + log_output = log_capture_string.getvalue() + vllm.logger.logger.removeHandler(ch) + log_capture_string.close() + + # Any warning in the log is acceptable + assert "Failed to load" in log_output or "does not contain" in log_output + mock_conn.close.assert_called_once() + + +# Test the thread startup. 
+def test_server_start(server_config): + with patch("socket.socket"), \ + patch("threading.Thread") as mock_thread: + + handler_thread_instance = mock_thread.return_value + + server = ElasticServer(**server_config) + server.start() + + # Assert that the correct target parameter was passed when instantiating the Thread instance. + mock_thread.assert_called_once() + args, kwargs = mock_thread.call_args + assert kwargs['target'] == server.elastic_client_handler + + # Check that the daemon attribute is set to True (the attribute value will be recorded after MagicMock assignment). + assert handler_thread_instance.daemon is True + + # Check if the start() method is called. + handler_thread_instance.start.assert_called_once() + + +# Test resource clearing +def test_server_cleanup(server_config): + with patch("socket.socket") as mock_socket: + server = ElasticServer(**server_config) + del server + mock_socket.return_value.close.assert_called_once() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/ut/model_loader/netloader/test_netloader_load.py b/tests/ut/model_loader/netloader/test_netloader_load.py new file mode 100644 index 00000000..77c3486f --- /dev/null +++ b/tests/ut/model_loader/netloader/test_netloader_load.py @@ -0,0 +1,114 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from unittest.mock import MagicMock, patch + +import pytest + +from vllm_ascend.model_loader.netloader.load import elastic_load + + +@pytest.fixture +def mock_sources(): + return [ + { + "device_id": 0, + "sources": ["a", "b"] + }, + { + "device_id": 1, + "sources": ["c"] + }, + ] + + +@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient") +@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad") +def test_sources_this_device_empty(mock_p2p, mock_client): + sources = [{"device_id": 1, "sources": ["c"]}] + result = elastic_load("model", 0, "model_path", sources, 1, 1) + assert result is None + mock_client.assert_not_called() + mock_p2p.assert_not_called() + + +@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient") +@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad") +def test_client_s_none(mock_p2p, mock_client, mock_sources): + # Simulate ElasticClient.s as None + mock_instance = MagicMock() + mock_instance.s = None + mock_client.return_value = mock_instance + result = elastic_load("model", 0, "model_path", mock_sources, 1, 1) + assert result is None + + +@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient") +@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad") +def test_client_ack_none(mock_p2p, mock_client, mock_sources): + # Simulate ElasticClient.ack as None + mock_instance = MagicMock() + mock_instance.s = True + mock_instance.ack = None + mock_client.return_value = mock_instance + result = elastic_load("model", 0, "model_path", mock_sources, 1, 1) + assert result is None + + +@patch("vllm_ascend.model_loader.netloader.load.P2PLoad") +@patch("vllm_ascend.model_loader.netloader.load.logger") +def test_model_load_fail(mock_logger, mock_p2p): + mock_client = MagicMock() + mock_client.s = True + mock_client.ack = ["foo", "bar"] + mock_client.server_addr = "addr" + + with 
patch("vllm_ascend.model_loader.netloader.load.ElasticClient", + return_value=mock_client): + # P2PLoad.load returns None + mock_p2p_instance = MagicMock() + mock_p2p_instance.load.return_value = None + mock_p2p.return_value = mock_p2p_instance + + sources = [{"device_id": 0, "sources": ["whatever"]}] + result = elastic_load("model", 0, "model_path", sources, 1, 1) + assert result is None + mock_logger.error.assert_called_once() + + +@patch("vllm_ascend.model_loader.netloader.load.P2PLoad") +@patch("vllm_ascend.model_loader.netloader.load.logger") +def test_model_load_success(mock_logger, mock_p2p): + mock_client = MagicMock() + mock_client.s = True + mock_client.ack = ["foo", "bar"] + mock_client.server_addr = "addr" + + with patch("vllm_ascend.model_loader.netloader.load.ElasticClient", + return_value=mock_client): + expected_model = object() + mock_p2p_instance = MagicMock() + mock_p2p_instance.load.return_value = expected_model + mock_p2p.return_value = mock_p2p_instance + + sources = [{"device_id": 0, "sources": ["whatever"]}] + result = elastic_load("model", 0, "model_path", sources, 1, 1) + assert result is expected_model + mock_logger.info.assert_called_once() + + +if __name__ == "__main__": + pytest.main() diff --git a/tests/ut/model_loader/netloader/test_netloader_utils.py b/tests/ut/model_loader/netloader/test_netloader_utils.py new file mode 100644 index 00000000..66198449 --- /dev/null +++ b/tests/ut/model_loader/netloader/test_netloader_utils.py @@ -0,0 +1,61 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import tempfile +from unittest.mock import patch + +import pytest + +from vllm_ascend.model_loader.netloader.utils import (find_free_port, + is_valid_path_prefix) + + +def test_find_free_port(): + port = find_free_port() + assert isinstance(port, int) + assert port > 0 + + +def test_is_valid_path_prefix_empty(): + assert not is_valid_path_prefix('') + + +def test_is_valid_path_prefixIllegal_characters(): + assert not is_valid_path_prefix('test<>:"|?*') + + +def test_is_valid_path_prefixRelative_path(): + assert is_valid_path_prefix('test') + + +def test_is_valid_path_prefixAbsolute_path(): + with tempfile.TemporaryDirectory() as tmpdir: + assert is_valid_path_prefix(os.path.join(tmpdir, 'test')) + + +@patch('os.path.exists', return_value=False) +def test_is_valid_path_prefix_no_directory(mock_exists): + assert not is_valid_path_prefix('/nonexistent_dir/test') + + +@patch('os.path.exists', return_value=True) +def test_is_valid_path_prefix_directory_exists(mock_exists): + assert is_valid_path_prefix('/existing_dir/test') + + +if __name__ == "__main__": + pytest.main() diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index 74a81537..aa72f74b 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -31,3 +31,8 @@ def register_model(): def register_connector(): from vllm_ascend.distributed import register_connector register_connector() + + +def register_model_loader(): + from .model_loader.netloader import register_netloader + register_netloader() \ No newline at end of file diff --git 
a/vllm_ascend/model_loader/__init__.py b/vllm_ascend/model_loader/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/model_loader/netloader/__init__.py b/vllm_ascend/model_loader/netloader/__init__.py new file mode 100644 index 00000000..f3b7c50a --- /dev/null +++ b/vllm_ascend/model_loader/netloader/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +def register_netloader(): + """Register the NetLoader plugin.""" + from .netloader import ModelNetLoaderElastic # noqa diff --git a/vllm_ascend/model_loader/netloader/executor/__init__.py b/vllm_ascend/model_loader/netloader/executor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/model_loader/netloader/executor/elastic_load.py b/vllm_ascend/model_loader/netloader/executor/elastic_load.py new file mode 100644 index 00000000..850bfaf9 --- /dev/null +++ b/vllm_ascend/model_loader/netloader/executor/elastic_load.py @@ -0,0 +1,170 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +import torch_npu +from vllm.distributed.utils import ( + stateless_destroy_torch_distributed_process_group, + stateless_init_torch_distributed_process_group) +from vllm.logger import logger + + +class P2PLoad: + """ + Class for receiving model parameters in a distributed manner using HCCL backend. + """ + + def __init__( + self, + world_name: str, + source_ip: str, + source_port: int, + ): + """ + Initializes the P2PLoad instance. + + Parameters: + - world_name: The name of the distributed group. + - source_ip: The IP address of the source node. + - source_port: The port number for the source node. + """ + self.world_name = world_name + self.source_ip = source_ip + self.source_port = source_port + + def load(self, model): + """ + Loads the model parameters using HCCL backend. + + Parameters: + - model: The model whose parameters are to be loaded. + + Returns: + - The model if loading is successful, otherwise None. 
+ """ + model_device = next(model.parameters()).device + logger.info( + f"Start init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" + ) + receiver_pg = None + loaded_model = None + try: + receiver_pg = stateless_init_torch_distributed_process_group( + host=self.world_name.split(":")[0], + port=self.source_port, + rank=0, + world_size=2, + backend='hccl', + ) + logger.info( + f"Finish init_process_group, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" + ) + + logger.info( + f"Start recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" + ) + logger.info(f"Model device: {model_device}") + + trans_stream = torch_npu.npu.Stream() + with torch_npu.npu.stream(trans_stream): + for name, param in model.named_parameters(): + if len(param.shape) == 0: + continue + receiver_pg.recv([param], 1, 0).wait() + torch.distributed.barrier(group=receiver_pg, + device_ids=[model_device.index]) + + torch_npu.npu.synchronize(trans_stream) + + logger.info( + f"Finish recv, name: {self.world_name}, addr: {self.source_ip}:{self.source_port}" + ) + loaded_model = model + except Exception as e: + logger.error("Failed to recv model: {}".format(e)) + finally: + if receiver_pg: + stateless_destroy_torch_distributed_process_group(receiver_pg) + return loaded_model + + +class P2PSend: + """ + Class for sending model parameters in a distributed manner using HCCL backend. + """ + + def __init__(self, listen_ip: str, listen_port: int, comm_name: str): + """ + Initializes the P2PSend instance. + + Parameters: + - listen_ip: The IP address to listen on. + - listen_port: The port number to listen on. + - comm_name: The name of the communication group. + """ + self.listen_ip = listen_ip + self.listen_port = listen_port + self.comm_name = comm_name + + def send(self, model, int8_params: dict): + """ + Sends the model parameters using HCCL backend. + + Parameters: + - model: The model whose parameters are to be sent. 
+ - int8_params: Dictionary of parameters that are in int8 format. + """ + model_device = next(model.parameters()).device + torch.npu.set_device(model_device) + logger.info( + f"Start init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" + ) + sender_pg = None + try: + sender_pg = stateless_init_torch_distributed_process_group( + host=self.comm_name.split(":")[0], + port=self.listen_port, + rank=1, + world_size=2, + backend='hccl', + ) + logger.info( + f"Finish init_process_group, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" + ) + logger.info( + f"Start send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" + ) + logger.info(f"Model device: {model_device}") + + trans_stream = torch_npu.npu.Stream() + with torch_npu.npu.stream(trans_stream): + for name, param in model.named_parameters(): + if "aclnn_input_scale" in name: + continue + if name in int8_params: + sender_pg.send([int8_params[name].to(model_device)], 0, + 0).wait() + else: + sender_pg.send([param.contiguous()], 0, 0).wait() + torch.distributed.barrier(group=sender_pg, + device_ids=[model_device.index]) + torch_npu.npu.synchronize(trans_stream) + logger.info( + f"Finish send, name: {self.comm_name}, addr: {self.listen_ip}:{self.listen_port}" + ) + finally: + if sender_pg: + stateless_destroy_torch_distributed_process_group(sender_pg) diff --git a/vllm_ascend/model_loader/netloader/interaction/__init__.py b/vllm_ascend/model_loader/netloader/interaction/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vllm_ascend/model_loader/netloader/interaction/elastic.py b/vllm_ascend/model_loader/netloader/interaction/elastic.py new file mode 100644 index 00000000..1000bd7c --- /dev/null +++ b/vllm_ascend/model_loader/netloader/interaction/elastic.py @@ -0,0 +1,408 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import re +import socket +import threading +from typing import List, Optional, Tuple + +import torch +from vllm.logger import logger + +from ..executor.elastic_load import P2PSend +from ..utils import find_free_port + + +class ElasticClient: + """ + Class for handling the client-side logic of Netloader of models. + """ + + def __init__(self, sources: list[str], device_id: int, model_path: str, + tp: int, pp: int): + """ + Initializes the ElasticClient instance. + + Parameters: + - sources: List of source addresses in the format IP:port. + - device_id: The ID of the current device. + - model_path: The path to the model. + - tp: Tensor parallel size. + - pp: Pipeline parallel size. 
+ """ + self.sources = sources + self.device_id = device_id + self.model_path = model_path + self.tp = tp + self.pp = pp + + self.s: Optional[socket.socket] = None + self.ack: Optional[Tuple[str, int]] = None + self.server_addr: Optional[str] = None + self.server_port: Optional[int] = None + + for source in self.sources: + try: + ip, port_str = source.split(':') + port = int(port_str) + except Exception as e: + logger.error(f"IP format error: {source}, detail: {e}") + continue + + self.server_addr = ip + self.server_port = port + + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + logger.info( + f"Start connection to server: {self.server_addr}:{self.server_port}" + ) + sock.connect((self.server_addr, self.server_port)) + logger.info( + f"Finish connection to server: {self.server_addr}:{self.server_port}" + ) + sock.settimeout(60) + + self.s = sock + self.ack = self.register(device_id, model_path, tp, pp) + break + except Exception as e: + logger.error(f"Connect to {source} fails, detail: {e}") + if sock is not None: + try: + sock.close() + except Exception: + pass + self.s = None + self.ack = None + self.server_addr = None + self.server_port = None + + def close(self) -> None: + """ + Closes the socket connection. + """ + if self.s is not None: + try: + self.s.close() + except Exception as e: + logger.error(f"Error closing socket: {e}") + finally: + self.s = None + + def __enter__(self) -> "ElasticClient": + """ + Context manager enter method. + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """ + Context manager exit method. + """ + self.close() + + def __del__(self): + """ + Destructor method to ensure socket is closed. + """ + try: + self.close() + except Exception: + pass + + def send_str(self, data_str: str) -> None: + """ + Sends a string over the socket connection. + + Parameters: + - data_str: The string to be sent. 
+ """ + if self.s is None: + raise RuntimeError("Socket was not created correctly.") + self.s.send(data_str.encode("utf-8")) + + def recv_str(self, buffer_size: int = 1024) -> str: + """ + Receives a string over the socket connection. + + Parameters: + - buffer_size: The size of the buffer for receiving data. + + Returns: + - The received string. + """ + if self.s is None: + raise RuntimeError("Socket was not created correctly.") + data_str = self.s.recv(buffer_size).decode("utf-8") + return data_str + + def register(self, device_id: int, model_path: str, tp: int, + pp: int) -> Tuple[str, int]: + """ + Registers the client with the server. + + Parameters: + - device_id: The ID of the current device. + - model_path: The path to the model. + - tp: Tensor parallel size. + - pp: Pipeline parallel size. + + Returns: + - A tuple containing the communication name and port. + """ + free_port = find_free_port() + data = { + "label": "JOIN", + "content": { + 'device_id': device_id, + 'model_path': model_path, + 'tp': tp, + 'pp': pp, + 'port': free_port + } + } + + try: + self.send_str(json.dumps(data)) + except Exception as e: + raise RuntimeError( + f"Send data {data} to server fails, detail: {e}") + + try: + ack_str = self.recv_str() + except Exception as e: + raise RuntimeError(f"Receive data from server fails, detail: {e}") + + try: + ack = json.loads(ack_str) + except Exception as e: + raise RuntimeError( + f"Receive data {ack_str} cannot be converted to JSON format, detail: {e}" + ) + + logger.info(f"Receive ack: {ack}") + + if ("label" in ack and ack["label"] == 'JOIN_ACK' and "content" in ack + and ack["content"] is not None and "name" in ack["content"]): + return (ack["content"]["name"], free_port) + elif ("label" in ack and ack["label"] == 'JOIN_NACK' + and "content" in ack): + raise RuntimeError( + f"Receive nack from server, reason: {ack['content']}") + else: + raise RuntimeError( + f"Receive ack {ack} from server does not contain required fields" + ) + + +class 
ElasticServer: + """ + Class for handling the server-side logic of Netloader of models. + """ + + def __init__(self, addr: str, port: int, model, device_id: int, + model_path: str, tp: int, pp: int, int8_cache: str, + int8_cache_name: Optional[List[str]]): + """ + Initializes the ElasticServer instance. + + Parameters: + - addr: The IP address to listen on. + - port: The port number to listen on. + - model: The model to be served. + - device_id: The ID of the current device (i.e. global rank). + - model_path: The path to the model. + - tp: Tensor parallel size. + - pp: Pipeline parallel size. + - int8_cache: The type of caching for int8 parameters (HBM, DRAM, or no). + - int8_cache_name: List of parameter names to be cached. + """ + self.addr = addr + self.port = port + self.s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self.s.bind((self.addr, self.port)) + self.s.listen(256) + + self.model = model + self.device_id = device_id + self.model_path = model_path + self.tp = tp + self.pp = pp + + self.original_int8 = {} + int8_pattern = "|".join( + map(re.escape, + int8_cache_name)) if int8_cache_name is not None else "(?:)" + for name, param in self.model.named_parameters(): + if param.dtype == torch.int8: + if int8_cache == 'hbm': + if int8_cache_name is None or ( + int8_cache_name is not None + and re.search(int8_pattern, name) is not None): + try: + self.original_int8[name] = param.data.clone( + ).detach() + except RuntimeError as e: + logger.error( + f"Failed to cache int8 tensor {name} to HBM, change to DRAM, due to {e}" + ) + self.original_int8[name] = param.data.cpu() + + elif int8_cache == 'dram': + if int8_cache_name is None or ( + int8_cache_name is not None + and re.search(int8_pattern, name) is not None): + self.original_int8[name] = param.data.cpu() + elif int8_cache == 'no': + pass + else: + logger.warning( + f"int8_cache should be selected in [HBM, DRAM], but got {int8_cache}, change 
to no cache" + ) + + logger.info( + f"Server {self.addr}:{self.port} starts, device id: {self.device_id}, model path: {self.model_path}, tp: {self.tp}, pp: {self.pp}, int8 params {self.original_int8.keys()} are saved to {int8_cache}" + ) + + def __del__(self): + """ + Destructor method to ensure socket is closed. + """ + self.s.close() + + def start(self): + """ + Starts the server to handle incoming connections. + """ + handler_thread = threading.Thread(target=self.elastic_client_handler) + handler_thread.daemon = True + handler_thread.start() + + def elastic_client_handler(self): + """ + Handles incoming client connections. + """ + while True: + conn, addr = self.s.accept() + logger.info("Accept new connection from {}:{}...".format(*addr)) + self.register_handler(conn, addr) + + def register_handler(self, conn, addr, buffer_size=1024): + """ + Handles the registration of a client. + + Parameters: + - conn: The connection socket. + - addr: The address of the client. + - buffer_size: The size of the buffer for receiving data. + """ + data_str = conn.recv(buffer_size).decode("utf-8") + if not data_str: + return + try: + data = json.loads(data_str) + except Exception: + logger.error(f"Failed to load {data_str} as JSON string") + conn.close() + return + + def is_valid_data(data): + """ + Validates the received data. + + Parameters: + - data: The data to be validated. + + Returns: + - True if the data is valid, otherwise False. 
+ """ + if not isinstance(data, dict): + return False + if data.get("label") != "JOIN": + return False + content = data.get("content") + if not isinstance(content, dict): + return False + required_keys = ["device_id", "model_path", "tp", "pp", "port"] + if not all(k in content for k in required_keys): + return False + port = content["port"] + if not (isinstance(port, int) or + (isinstance(port, str) and port.isdigit())): + return False + return True + + comm_name = None + if is_valid_data(data): + device_id = int(data["content"]["device_id"]) + model_path = data["content"]["model_path"] + tp = int(data["content"]["tp"]) + pp = int(data["content"]["pp"]) + + if int(self.device_id + ) == device_id and self.model_path == model_path and int( + self.tp) == tp and int(self.pp) == pp: + comm_name = str(addr[0]) + ":" + str(addr[1]) + ack = {"label": "JOIN_ACK", "content": {"name": comm_name}} + else: + logger.warning( + f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}" + ) + ack = { + "label": + "JOIN_NACK", + "content": + f"Received data {(device_id, model_path, tp, pp)} does not consist with this server {(int(self.device_id), self.model_path, int(self.tp), int(self.pp))}" + } + else: + logger.warning( + f"Received data does not contain required fields: {data}") + ack = { + "label": + "JOIN_NACK", + "content": + f"Received data does not contain required fields: {data}" + } + + try: + ack_str = json.dumps(ack).encode("utf-8") + except Exception as e: + logger.error( + f"Failed to convert {ack} to JSON format, details: {e}") + conn.close() + return + + try: + conn.send(ack_str) + except Exception as e: + logger.error(f"Failed to send {ack} to {addr}, details: {e}") + conn.close() + return + + if ack["content"] and isinstance(ack["content"], + dict) and 'name' in ack["content"]: + try: + p2psend = P2PSend(self.addr, data["content"]["port"], + ack["content"]["name"]) + 
p2psend.send(self.model, self.original_int8) + except Exception as e: + logger.error( + f"P2PSend Failed to send model to {self.addr}, details: {e}" + ) + conn.close() diff --git a/vllm_ascend/model_loader/netloader/load.py b/vllm_ascend/model_loader/netloader/load.py new file mode 100644 index 00000000..90000d58 --- /dev/null +++ b/vllm_ascend/model_loader/netloader/load.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time + +from vllm.logger import logger + +from .executor.elastic_load import P2PLoad +from .interaction.elastic import ElasticClient + + +def elastic_load( + model, + device_id: int, + model_path: str, + sources: list, + tp: int, + pp: int, +): + """ + Loads a model using elastic loading across multiple devices. + + Parameters: + - model: The model instance to be loaded. + - device_id: The ID of the current device (i.e. global rank). + - model_path: The path to the model file. + - sources: A list of source configurations, each containing device_id and sources. + - tp: Tensor parallel size, indicating the number of devices for tensor parallelism. + - pp: Pipeline parallel size, indicating the number of devices for pipeline parallelism. + + Returns: + - The loaded model if successful, otherwise None. 
+ """ + + # Filter sources for the current device + sources_this_device = [] + for s in sources: + if isinstance( + s, dict + ) and "device_id" in s and s["device_id"] == device_id and isinstance( + s["sources"], list): + sources_this_device += s["sources"] + if len(sources_this_device) == 0: + return None + + try: + # Initialize the interaction layer with the ElasticClient + with ElasticClient(sources_this_device, device_id, model_path, tp, + pp) as client_interaction_layer: + if client_interaction_layer.s is None or client_interaction_layer.server_addr is None: + raise RuntimeError( + "Failed to initialize ElasticClient: socket or server_addr is None" + ) + ack = client_interaction_layer.ack + if ack is None: + raise RuntimeError("ElasticClient.register did not return ack") + + t0 = time.perf_counter() + elastic_loader = P2PLoad(ack[0], + client_interaction_layer.server_addr, + ack[1]) + model_loaded = elastic_loader.load(model=model) + if model_loaded is None: + logger.error("Failed to load model") + return None + logger.info("Finish elastic load (duration: {}s)".format( + time.perf_counter() - t0)) + return model_loaded + except Exception as e: + logger.error(f"elastic_load error: {e}") + return None diff --git a/vllm_ascend/model_loader/netloader/netloader.py b/vllm_ascend/model_loader/netloader/netloader.py new file mode 100644 index 00000000..9c2d8307 --- /dev/null +++ b/vllm_ascend/model_loader/netloader/netloader.py @@ -0,0 +1,324 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import gc +import json +import time +from copy import deepcopy +from typing import List, Optional, Tuple + +import torch +from torch import nn +from vllm.config import LoadConfig, ModelConfig, VllmConfig +from vllm.logger import logger +from vllm.model_executor.model_loader import register_model_loader +from vllm.model_executor.model_loader.base_loader import BaseModelLoader +from vllm.model_executor.model_loader.default_loader import DefaultModelLoader +from vllm.model_executor.model_loader.utils import ( + initialize_model, process_weights_after_loading, set_default_torch_dtype) + +from .interaction.elastic import ElasticServer +from .load import elastic_load +from .utils import find_free_port, is_valid_path_prefix + + +@register_model_loader("netloader") +class ModelNetLoaderElastic(BaseModelLoader): + """ + A model loader that uses elastic loading for loading weights. + """ + source: Optional[List[dict]] + model_path: Optional[str] + listen_port: Optional[int] + int8_cache: str + int8_cache_name: Optional[List[str]] + output_prefix: Optional[str] + + def __init__(self, load_config: LoadConfig): + """ + Initializes the ModelNetLoaderElastic with configuration. + + Parameters: + - load_config: Configuration for loading the model. + """ + super().__init__(load_config) + + config = None + + # Try to read config file at first + extra = load_config.model_loader_extra_config + if extra and "CONFIG_FILE" in extra: + try: + logger.info( + f"Reading configs in file {load_config.model_loader_extra_config['CONFIG_FILE']} ..." 
+ ) + with open(extra["CONFIG_FILE"], 'r') as f: + config = json.load(f) + except FileNotFoundError: + logger.error("CONFIG_FILE not found") + except json.JSONDecodeError: + logger.error("CONFIG_FILE is not a valid JSON file") + except Exception as e: + logger.error( + f"Unexpected error while reading CONFIG_FILE: {e}") + + if config is None and extra: + logger.info("Reading configs in model_loader_extra_config ...") + config = extra + config = config or {} + + for key, attr, checker, caster, default in [ + ("SOURCE", "source", lambda v: isinstance(v, list), lambda v: v, + None), + ("MODEL", "model_path", lambda v: isinstance(v, str), lambda v: v, + None), + ("LISTEN_PORT", "listen_port", lambda v: isinstance(v, int) or + (isinstance(v, str) and v.isdigit()), lambda v: int(v), None), + ("INT8_CACHE", "int8_cache", lambda v: isinstance(v, str) and v. + lower() in ['hbm', 'dram', 'no'], lambda v: v.lower(), 'no'), + ("INT8_CACHE_NAME", "int8_cache_name", + lambda v: isinstance(v, list), lambda v: v, None), + ("OUTPUT_PREFIX", "output_prefix", + lambda v: isinstance(v, str) and is_valid_path_prefix(v), + lambda v: v, None), + ]: + v = config.get(key, default) + if not checker(v): + v = default + else: + v = caster(v) + setattr(self, attr, v) + + logger.info( + "Initializing elastic Netloader with config: " + "MODEL=%s, LISTEN_PORT=%s," + "SOURCE=%s, INT8_CACHE=%s, INT8_CACHE_NAME=%s," + "OUTPUT_PREFIX=%s)", + self.model_path, + self.listen_port, + self.source, + self.int8_cache, + self.int8_cache_name, + self.output_prefix, + ) + + def load_model(self, vllm_config: VllmConfig, + model_config: ModelConfig) -> nn.Module: + """ + Loads the model using the specified configuration. + + Parameters: + - vllm_config: Configuration for the VLLM. + - model_config: Configuration for the model. + + Returns: + - The loaded model. 
+ """ + + device_config = vllm_config.device_config + parallel_config = vllm_config.parallel_config + + need_process_weights_after_loading = False + + if self.model_path is None: + self.model_path = model_config.model + logger.info(f"model_path is set to {self.model_path}") + + device_id = torch.distributed.get_rank() + + if (self.source is None or not isinstance(self.source, list) + or device_id not in [ + one_device["device_id"] for one_device in self.source if + isinstance(one_device, dict) and "device_id" in one_device + ]): + logger.warning( + "Did not get valid source info, use DefaultModelLoader") + model, need_process_weights_after_loading = self.revert_to_default( + model_config, vllm_config, device_config) + + else: + target_device = torch.device(device_config.device) + + vllm_config_backup = deepcopy(vllm_config) + model_config_backup = deepcopy(model_config) + + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config, + model_config=model_config) + + start_elastic_load = time.perf_counter() + model = elastic_load( + model=model, + device_id=device_id, + model_path=self.model_path, + sources=self.source, + tp=parallel_config.tensor_parallel_size, + pp=parallel_config.pipeline_parallel_size, + ) + end_elastic_load = time.perf_counter() + logger.info( + f"Elastic load time: {end_elastic_load - start_elastic_load}, rank: {device_id}" + ) + need_process_weights_after_loading = True + + if model is None: + logger.warning( + "Netloader elastic loading fails, use load format DefaultModelLoader" + ) + + vllm_config = vllm_config_backup + model_config = model_config_backup + + del model + gc.collect() + if device_config.device_type == 'npu': + logger.info("Empty NPU cache") + torch.npu.empty_cache() + elif device_config.device_type == 'cuda': + logger.info("Empty CUDA cache") + torch.cuda.empty_cache() + + model, need_process_weights_after_loading = self.revert_to_default( + model_config, 
vllm_config, device_config) + + start_elastic_server = time.perf_counter() + # start elastic server + if model is not None and ( + (self.listen_port and self.listen_port in range(1024, 65535)) or + (self.listen_port is None)): + from vllm.utils import get_ip + driver_ip = get_ip() + + if driver_ip == '0.0.0.0': + logger.error( + "Driver IP is not set, skip to start Netloader server") + else: + if self.listen_port is None: + self.listen_port = find_free_port() + else: + self.listen_port += device_id + + logger.info( + f"Start elastic Netloader server, rank: {device_id}, listen port: {driver_ip}:{self.listen_port}" + ) + + if self.output_prefix is not None: + try: + with open(self.output_prefix + str(device_id) + '.txt', + 'w') as file: + file.write(f"{driver_ip}:{self.listen_port}") + logger.info( + f"Successfully wrote server address to file: {self.output_prefix + str(device_id)}" + ) + except FileNotFoundError: + logger.error( + f"File path {self.output_prefix + str(device_id)} does not exist." + ) + except PermissionError: + logger.error( + f"No permission to write to file {self.output_prefix + str(device_id)}." 
+ ) + except IOError as e: + logger.error( + f"I/O error occurred while writing to file {self.output_prefix + str(device_id)}: {e}" + ) + except Exception as e: + logger.error(f"Unknown error: {e}") + + try: + assert isinstance( + self.listen_port, int + ), f"listen port should be int but get {self.listen_port}" + + elastic_server = ElasticServer( + driver_ip, self.listen_port, model, device_id, + self.model_path, parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size, + self.int8_cache, self.int8_cache_name) + elastic_server.start() + except Exception as e: + logger.error( + f"Failed to start Netloader server for rank: {device_id}, details: {e}" + ) + else: + logger.info("Skip to start Netloader server") + + end_elastic_server = time.perf_counter() + logger.info( + f"Elastic server start time: {end_elastic_server - start_elastic_server}, rank: {device_id}" + ) + + if need_process_weights_after_loading: + process_weights_after_loading(model, model_config, + torch.device(device_config.device)) + + if model is None: + logger.error("NetLoader elastic loads model fails") + return None + + return model.eval() + + def revert_to_default(self, model_config, vllm_config, + device_config) -> Tuple[nn.Module, bool]: + """ + Reverts to the default model loading logic when elastic loading fails or is not applicable. + + This method resets the loader's extra config and load format to defaults, + then delegates model loading to a DefaultModelLoader. + If quantization is enabled, it will load the model and then run the + processing of weights (i.e. applying quantization adjustments) before returning. + + Parameters: + - model_config: Configuration describing model architecture, quantization, etc. + - vllm_config: Configuration for vLLM (device, parallelism, dtype, etc). + - device_config: Configuration for the target device (device type, device id, etc). 
+ + Returns: + - A tuple (model, need_process_weights_after_loading): + * model: The loaded `nn.Module` under default loading logic. + * need_process_weights_after_loading: A boolean flag indicating whether + weights post-processing (e.g. quantization adjustments) still needs to be applied. + """ + self.load_config.model_loader_extra_config = {} + self.load_config.load_format = "auto" + default_model_loader = DefaultModelLoader(self.load_config) + + if model_config.quantization is None: + model = default_model_loader.load_model(vllm_config=vllm_config, + model_config=model_config) + need_process_weights_after_loading = False + else: + logger.warning( + "Quantization is set, netloader use DefaultModelLoader with process_weights_after_loading " + ) + need_process_weights_after_loading = True + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model = initialize_model(vllm_config=vllm_config, + model_config=model_config) + default_model_loader.load_weights(model, model_config) + model = model.eval() + + return model, need_process_weights_after_loading + + def download_model(self, model_config: ModelConfig) -> None: + pass + + def load_weights(self, model: nn.Module, + model_config: ModelConfig) -> None: + pass diff --git a/vllm_ascend/model_loader/netloader/utils.py b/vllm_ascend/model_loader/netloader/utils.py new file mode 100644 index 00000000..fba5a58c --- /dev/null +++ b/vllm_ascend/model_loader/netloader/utils.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import re
import socket

from vllm.logger import logger


def find_free_port():
    """
    Finds a free TCP port on the local machine.

    Returns:
    - A port number that was free at the time of the check.

    Note: the socket is closed before returning, so another process may
    grab the port between this call and its eventual use (inherent
    time-of-check/time-of-use race).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))  # port 0 asks the OS to pick any free port
        return s.getsockname()[1]


def is_valid_path_prefix(path_prefix):
    """
    Checks if the provided path prefix is valid.

    A prefix is valid when it is non-empty, contains none of the characters
    that are illegal in file names, and its parent directory already exists.

    Parameters:
    - path_prefix: The path prefix to validate.

    Returns:
    - True if the path prefix is valid, otherwise False.
    """
    if not path_prefix:
        return False

    if re.search(r'[<>:"|?*]', path_prefix):
        logger.warning(
            f'The path prefix {path_prefix} contains illegal characters.')
        return False

    # Rooted prefixes ('/' or '\') are checked as-is; relative prefixes are
    # resolved against the current working directory first. Either way the
    # parent directory must already exist for the prefix to be usable.
    if path_prefix.startswith('/') or path_prefix.startswith('\\'):
        parent = os.path.dirname(path_prefix)
    else:
        parent = os.path.dirname(os.path.abspath(path_prefix))
    if not os.path.exists(parent):
        logger.warning(
            f'The directory for the path prefix {parent} does not exist.')
        return False
    return True