[Misc] Add a model loader that utilizes HCCL for weight loading (#2888)

### What this PR does / why we need it?

This PR introduces a new model loader called Netloader, which leverages
high-bandwidth P2P direct transfer between NPU cards to achieve weight
loading. Netloader is implemented as a plugin through the newly added
'register_model_loader' function in vLLM 0.10. It facilitates the
process of weight loading by sending weights from a pre-loaded model
(server) to an empty model of a newly started instance (client). The
server operates concurrently with normal inference tasks through
sub-threads and the 'stateless_init_torch_distributed_process_group' in
vLLM. The client initiates a transfer request after verifying that the
model and partitioning method are the same as the server's, and uses
HCCL's collective communication (send/recv) to load the weights in the
order they are stored in the model.

Application Scenarios:
1. Significantly reduces inference instance startup time: by reusing the
weights of already-loaded instances and performing high-speed transfers
directly between computing cards, this method reduces model loading
latency compared to traditional remote/local pull methods.
2. Reduces network and storage pressure: it avoids the need to repeatedly
download weight files from remote repositories, reducing the impact on
centralized storage and network traffic, thereby enhancing overall
system stability and service quality.
3. Improves resource utilization and reduces costs: accelerating the
loading process reduces reliance on redundant computing pools, allowing
computing resources to be elastically scaled and reclaimed as needed.
4. Enhances business continuity and high availability: in fault recovery
scenarios, new instances can quickly take over existing services,
avoiding prolonged business interruptions and improving the system's
high availability and user experience.

### Does this PR introduce _any_ user-facing change?

Netloader is activated through the existing --load-format=netloader and
--model-loader-extra-config options. As before, the value of
--model-loader-extra-config must be supplied as a JSON string.

Afterwards, you can check whether the outputs for the same sentence are
consistent when the temperature is set to 0.

Signed-off-by: destinysky <kangrui10@126.com>

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: destinysky <kangrui10@126.com>
This commit is contained in:
Rui Kang
2025-10-23 15:56:07 +08:00
committed by GitHub
parent 807686dec9
commit 427b17e2da
19 changed files with 1999 additions and 1 deletions

View File

@@ -0,0 +1,215 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import json
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch import nn
from vllm_ascend.model_loader.netloader.netloader import ModelNetLoaderElastic
class DummyDeviceConfig:
    """Placeholder device configuration used by the loader tests."""

    # Both attributes share the same placeholder device string.
    device = device_type = 'cuda'
class DummyParallelConfig:
    """Placeholder parallel configuration with both sizes set to one."""

    tensor_parallel_size, pipeline_parallel_size = 1, 1
class DummyVllmConfig:
    """Placeholder top-level config bundling the dummy sub-configs."""

    # No additional configuration is supplied in these tests.
    additional_config = None
    parallel_config = DummyParallelConfig()
    device_config = DummyDeviceConfig()
class DummyModelConfig:
    """Placeholder model configuration: a model name and a float32 dtype."""

    dtype = torch.float32
    model = 'dummy-model'
@pytest.fixture
def default_load_config():
    """Provide a load config object with no extra loader configuration."""

    class DummyLoadConfig:
        load_format = "default"
        model_loader_extra_config = None

    config = DummyLoadConfig()
    return config
def make_loader_with_config(extra):
    """Build a ``ModelNetLoaderElastic`` from a throwaway load config.

    ``extra`` is exposed to the loader as ``model_loader_extra_config``;
    ``load_format`` is always ``"default"``.
    """

    class DummyLoadConfig:
        load_format = "default"
        model_loader_extra_config = extra

    load_config = DummyLoadConfig()
    return ModelNetLoaderElastic(load_config)
def test_init_with_extra_config_file(tmp_path, monkeypatch):
    """Loader attributes are populated from the JSON named by ``CONFIG_FILE``."""
    # Replace the vLLM logger and accept every output prefix as valid.
    monkeypatch.setattr("vllm.logger.logger", MagicMock())
    monkeypatch.setattr(
        "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
        lambda x: True)

    # Persist the settings as JSON inside the per-test temporary directory.
    prefix = str(tmp_path)
    settings = {
        "SOURCE": [{
            "device_id": 0
        }],
        "MODEL": "foo-model",
        "LISTEN_PORT": 5001,
        "INT8_CACHE": "hbm",
        "OUTPUT_PREFIX": prefix,
    }
    settings_path = tmp_path / "config.json"
    settings_path.write_text(json.dumps(settings))

    loader = make_loader_with_config({"CONFIG_FILE": str(settings_path)})

    # Every setting from the file should surface as a loader attribute.
    assert loader.model_path == "foo-model"
    assert loader.source == [{"device_id": 0}]
    assert loader.listen_port == 5001
    assert loader.int8_cache == "hbm"
    assert loader.output_prefix == prefix
def test_init_with_extra_config(monkeypatch):
    """Loader attributes are populated from an inline extra-config dict."""
    # Replace the vLLM logger and accept every output prefix as valid.
    monkeypatch.setattr("vllm.logger.logger", MagicMock())
    monkeypatch.setattr(
        "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
        lambda x: True)

    # LISTEN_PORT is deliberately given as a string; the assertion below
    # expects the loader to expose it as an integer.
    inline_settings = {
        "SOURCE": [{
            "device_id": 0
        }],
        "MODEL": "foo",
        "LISTEN_PORT": "4000",
        "INT8_CACHE": "dram",
        "OUTPUT_PREFIX": "/tmp/"
    }
    loader = make_loader_with_config(inline_settings)

    assert loader.model_path == "foo"
    assert loader.listen_port == 4000
    assert loader.int8_cache == "dram"
    assert loader.output_prefix == "/tmp/"
    assert loader.source == [{"device_id": 0}]
def test_init_with_invalid_config(monkeypatch):
    """Invalid settings leave the loader with ``None`` values and no cache."""
    # Replace the vLLM logger and reject every output prefix.
    monkeypatch.setattr("vllm.logger.logger", MagicMock())
    monkeypatch.setattr(
        "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
        lambda x: False)

    # Every field is missing except INT8_CACHE, which has an unknown value.
    bad_settings = {
        "SOURCE": None,
        "MODEL": None,
        "LISTEN_PORT": None,
        "INT8_CACHE": "something",
        "OUTPUT_PREFIX": None,
    }
    loader = make_loader_with_config(bad_settings)

    assert loader.model_path is None
    assert loader.listen_port is None
    # The unknown cache option is expected to be reported as "no".
    assert loader.int8_cache == "no"
    assert loader.output_prefix is None
@patch("vllm_ascend.model_loader.netloader.netloader.logger")
def test_load_model_elastic_success(mock_logger, monkeypatch, tmp_path):
    """Run ``load_model`` with every collaborator replaced by a stub.

    Expects an ``nn.Module`` result and an ``output_0.txt`` file created
    under the configured output prefix.
    """

    class FakeContext:
        """No-op context manager standing in for the torch context helpers."""

        def __enter__(self):
            return None

        def __exit__(self, a, b, c):
            return None

    class DummyElasticServer:
        """Server stub that accepts any constructor arguments."""

        def __init__(self, *a, **k):
            pass

        def start(self):
            pass

    # The model stub returns itself from eval().
    dummy_model = MagicMock(spec=nn.Module)
    dummy_model.eval.return_value = dummy_model

    # (target, replacement) pairs, applied in the listed order.
    stubs = [
        ("torch.distributed.get_rank", lambda: 0),
        ("torch.device", lambda d: FakeContext()),
        ("vllm_ascend.model_loader.netloader.netloader.deepcopy",
         lambda x: x),
        ("vllm_ascend.model_loader.netloader.netloader.set_default_torch_dtype",
         lambda dtype: FakeContext()),
        ("vllm_ascend.model_loader.netloader.netloader.initialize_model",
         lambda **kwargs: dummy_model),
        ("vllm_ascend.model_loader.netloader.netloader.elastic_load",
         lambda **kwargs: dummy_model),
        ("vllm_ascend.model_loader.netloader.netloader.process_weights_after_loading",
         lambda *a, **k: None),
        ("vllm.utils.get_ip", lambda: "127.0.0.1"),
        ("vllm_ascend.model_loader.netloader.netloader.find_free_port",
         lambda: 8888),
        ("vllm_ascend.model_loader.netloader.netloader.ElasticServer",
         DummyElasticServer),
    ]
    for target, replacement in stubs:
        monkeypatch.setattr(target, replacement)

    # Point the output prefix at the temporary directory.
    settings = {
        "SOURCE": [{
            "device_id": 0
        }],
        "MODEL": "foo",
        "LISTEN_PORT": 5555,
        "OUTPUT_PREFIX": str(tmp_path) + "/output_",
        "INT8_CACHE": "no"
    }
    loader = make_loader_with_config(settings)

    result = loader.load_model(DummyVllmConfig(), DummyModelConfig())
    assert isinstance(result, nn.Module)

    # The loader is expected to have written a file for rank 0.
    assert (tmp_path / "output_0.txt").exists()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main()

View File

@@ -0,0 +1,432 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import json
import logging
import socket
from unittest.mock import MagicMock, patch
import pytest
import torch
import vllm.logger
from vllm_ascend.model_loader.netloader.interaction.elastic import (
ElasticClient, ElasticServer)
# Simulate server's normal response
def mock_server_response(data):
    """Return a UTF-8 encoded JOIN_ACK message; ``data`` is ignored."""
    ack = {"label": "JOIN_ACK", "content": {"name": "mocked_name"}}
    serialized = json.dumps(ack)
    return serialized.encode("utf-8")
# Simulate server's error response
def mock_server_error_response(data):
    """Return a UTF-8 encoded JOIN_ACK with no content; ``data`` is ignored."""
    ack = {"label": "JOIN_ACK", "content": None}
    serialized = json.dumps(ack)
    return serialized.encode("utf-8")
# Simulated server's abnormal response
def mock_server_exception_response(data):
    """Always raise ``Exception``; ``data`` is ignored."""
    failure = Exception("Mocked server exception")
    raise failure
# Test the initialization of ElasticClient
def test_elastic_client_init():
    """A client built against a mocked socket records address, port and ack."""
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp = pp = 1

    with patch('socket.socket') as mock_socket:
        # Fake socket that answers with a successful JOIN_ACK.
        fake_sock = MagicMock()
        mock_socket.return_value = fake_sock
        fake_sock.recv.return_value = mock_server_response(None)
        fake_sock.getsockname.return_value = ('127.0.0.1', 12346)
        fake_sock.__enter__.return_value = fake_sock

        with ElasticClient(sources, device_id, model_path, tp, pp) as client:
            assert client.server_addr == "127.0.0.1"
            assert client.server_port == 12345
            assert client.ack == ("mocked_name", 12346)

        # Leaving the client context should close the socket exactly once.
        fake_sock.close.assert_called_once()
# Test the register method of ElasticClient
def test_elastic_client_register():
    """``register`` returns the acknowledged name and the local port."""
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp = pp = 1

    with patch('socket.socket') as mock_socket:
        # Fake socket that connects and answers with a successful JOIN_ACK.
        fake_sock = MagicMock()
        mock_socket.return_value = fake_sock
        fake_sock.connect.return_value = None
        fake_sock.recv.return_value = mock_server_response(None)
        fake_sock.getsockname.return_value = ('127.0.0.1', 12346)
        fake_sock.__enter__.return_value = fake_sock

        client = ElasticClient(sources, device_id, model_path, tp, pp)
        ack = client.register(device_id, model_path, tp, pp)
        assert ack == ("mocked_name", 12346)
# Test the behavior of the `register` method of ElasticClient when the server returns an error response.
def test_elastic_client_register_error_response():
    """``register`` raises ``RuntimeError`` when the ack carries no content."""
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp = pp = 1

    with patch('socket.socket') as mock_socket:
        # Fake socket that answers with a content-less JOIN_ACK.
        fake_sock = MagicMock()
        mock_socket.return_value = fake_sock
        fake_sock.connect.return_value = None
        fake_sock.recv.return_value = mock_server_error_response(None)

        with ElasticClient(sources, device_id, model_path, tp, pp) as client:
            with pytest.raises(RuntimeError):
                client.register(device_id, model_path, tp, pp)

        fake_sock.close.assert_called_once()
# Test the behavior of the `register` method of ElasticClient when an exception is thrown on the server.
def test_elastic_client_register_exception():
    """``register`` raises ``RuntimeError`` when receiving data fails."""
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp = pp = 1

    with patch('socket.socket') as mock_socket:
        # Fake socket whose recv() raises instead of returning data.
        fake_sock = MagicMock()
        mock_socket.return_value = fake_sock
        fake_sock.connect.return_value = None
        fake_sock.recv.side_effect = mock_server_exception_response
        fake_sock.__enter__.return_value = fake_sock
        fake_sock.__exit__.return_value = None

        with ElasticClient(sources, device_id, model_path, tp, pp) as client:
            with pytest.raises(RuntimeError):
                client.register(device_id, model_path, tp, pp)

        fake_sock.close.assert_called_once()
class FakeInt8Param:
    """Lightweight parameter stand-in exposing ``dtype`` and ``device``.

    ``data``, ``clone`` and ``detach`` all return the same object so calls
    can be chained; ``cpu`` additionally switches ``device`` to the CPU.
    """

    def __init__(self, name="param", device="npu", dtype=torch.int8):
        # ``name`` is accepted but not stored.
        self.device = torch.device(device)
        self.dtype = dtype

    @property
    def data(self):
        """Return this object itself so ``.data`` can be chained."""
        return self

    def detach(self):
        """Return this object unchanged."""
        return self

    def clone(self):
        """Return this object unchanged (no copy is made)."""
        return self

    def cpu(self):
        """Mark this object as living on the CPU and return it."""
        self.device = torch.device("cpu")
        return self
class FakeModel:
    """Model stand-in exposing one float32 and one int8-like parameter."""

    def __init__(self):
        # param1 is float32 and is expected to be skipped by the int8 cache;
        # param2 simulates a real int8 parameter.
        float_param = MagicMock(dtype=torch.float32)
        int8_param = FakeInt8Param()
        self.params = {"param1": float_param, "param2": int8_param}

    def named_parameters(self):
        """Return ``(name, parameter)`` pairs from the parameter dict."""
        return self.params.items()
@pytest.fixture
def mock_model():
    """Provide a fresh ``FakeModel`` instance."""
    model = FakeModel()
    return model
@pytest.fixture
def server_config():
    """Provide keyword arguments for constructing an ``ElasticServer``."""
    config = {
        "addr": "127.0.0.1",
        "port": 8080,
        # Placeholder model; individual tests may replace it.
        "model": MagicMock(),
        "device_id": 0,
        "model_path": "/test/model",
        "tp": 1,
        "pp": 1,
        "int8_cache": "dram",
        'int8_cache_name': None
    }
    return config
# Test server initialization
def test_server_initialization(server_config, mock_model):
    """Verify socket setup, int8 caching and attributes of a new server.

    The debug log handler is detached in a ``finally`` clause so that a
    failing assertion cannot leave it attached to the shared vLLM logger.
    """
    server_config["model"] = mock_model

    log_capture_string = io.StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.DEBUG)

    with patch("socket.socket") as mock_socket:
        vllm.logger.logger.addHandler(ch)
        try:
            server = ElasticServer(**server_config)

            # Check the socket configuration
            mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM)
            sock = mock_socket.return_value
            sock.setsockopt.assert_called_with(socket.SOL_SOCKET,
                                               socket.SO_REUSEADDR, 1)
            sock.bind.assert_called_with(("127.0.0.1", 8080))
            sock.listen.assert_called_with(256)

            # Check int8 cache: with the "dram" option the cached copy
            # should report a CPU device.
            assert "param2" in server.original_int8
            assert server.original_int8["param2"].device.type == "cpu"

            # Constructor arguments should be mirrored as attributes.
            assert server.addr == server_config['addr']
            assert server.port == server_config['port']
            assert server.device_id == server_config['device_id']
            assert server.model_path == server_config['model_path']
            assert server.tp == server_config['tp']
            assert server.pp == server_config['pp']

            # Read the captured logs before the buffer is closed.
            log_output = log_capture_string.getvalue()
        finally:
            vllm.logger.logger.removeHandler(ch)
            log_capture_string.close()

        # Check output
        assert "Server 127.0.0.1:8080 starts" in log_output
# Test the int8 cache option
@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"),
                                                          ("no", None),
                                                          ("invalid", None)])
def test_int8_cache_handling(server_config, mock_model, cache_option,
                             expected_device, caplog):
    """Check which int8 parameters get cached for each cache option.

    The debug log handler is detached in a ``finally`` clause so that a
    failure while constructing the server cannot leave it attached to the
    shared vLLM logger.
    """
    server_config["int8_cache"] = cache_option
    server_config["model"] = mock_model

    log_capture_string = io.StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.DEBUG)

    with patch("socket.socket"):
        vllm.logger.logger.addHandler(ch)
        try:
            server = ElasticServer(**server_config)
            # Read the captured logs before the buffer is closed.
            log_output = log_capture_string.getvalue()
        finally:
            vllm.logger.logger.removeHandler(ch)
            log_capture_string.close()

    # An unknown option should be reported in the log.
    if cache_option == "invalid":
        assert "int8_cache should be selected in [HBM, DRAM]" in log_output

    if expected_device is None:
        # Nothing is cached for "no" or for an unknown option.
        assert len(server.original_int8) == 0
    else:
        assert server.original_int8[
            "param2"].device.type == expected_device
# Test client processing
def test_client_handler_valid_join(server_config, mock_model):
    """A matching JOIN request is acknowledged and triggers a P2P send.

    ``socket.socket`` is patched so that constructing the server does not
    bind a real socket on the configured address and port, which would make
    the test depend on that port being free.
    """
    server_config["model"] = mock_model

    # Simulated client connection carrying a JOIN request whose fields
    # match the server configuration.
    mock_conn = MagicMock()
    mock_addr = ("192.168.1.1", 12345)
    valid_data = {
        "label": "JOIN",
        "content": {
            "device_id": 0,
            "model_path": "/test/model",
            "tp": 1,
            "pp": 1,
            "port": 9090
        }
    }
    mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8")

    with patch("socket.socket"), \
            patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend"
                  ) as mock_p2p_send:
        # Start the server and hand it the simulated connection.
        server = ElasticServer(**server_config)
        server.register_handler(mock_conn, mock_addr)

        # Verify response
        expected_ack = {
            "label": "JOIN_ACK",
            "content": {
                "name": "192.168.1.1:12345"
            }
        }
        mock_conn.send.assert_called_once_with(
            json.dumps(expected_ack).encode("utf-8"))
        mock_p2p_send.assert_called_once_with("127.0.0.1", 9090,
                                              "192.168.1.1:12345")
        mock_conn.close.assert_called_once()
# Test mismatched JOIN requests
def test_client_handler_mismatch(server_config):
    """A JOIN request that differs from the server config gets a JOIN_NACK."""
    with patch("socket.socket"):
        server = ElasticServer(**server_config)

        client_conn = MagicMock()
        client_addr = ("192.168.1.1", 12345)

        # Every identifying field differs from the server configuration.
        mismatch_data = {
            "label": "JOIN",
            "content": {
                "device_id": 1,  # mismatched ID
                "model_path": "/wrong/model",
                "tp": 2,
                "pp": 2,
                "port": 9090
            }
        }
        client_conn.recv.return_value = json.dumps(mismatch_data).encode(
            "utf-8")

        server.register_handler(client_conn, client_addr)

        assert isinstance(mismatch_data["content"], dict)

        # Verify response
        expected_ack = {
            "label":
            "JOIN_NACK",
            "content":
            f"Received data {(mismatch_data['content']['device_id'], mismatch_data['content']['model_path'], mismatch_data['content']['tp'], mismatch_data['content']['pp'])} does not consist with this server {(server_config['device_id'], server_config['model_path'], server_config['tp'], server_config['pp'])}"
        }
        encoded_ack = json.dumps(expected_ack).encode("utf-8")
        client_conn.send.assert_called_once_with(encoded_ack)
        client_conn.close.assert_called_once()
# Test Invalid Request
@pytest.mark.parametrize(
    "invalid_data,should_send",
    [
        (
            {
                "label": "WRONG_LABEL"
            }, True
        ),  # Incorrect label, can be decoded as JSON, but the content is invalid.
        (
            {
                "content": {
                    "missing_fields": True
                }
            }, True
        ),  # Missing field, can be decoded as JSON, but the content is invalid.
        ("plain text", False),  # Non-JSON data, json.loads failed
        (b"invalid_bytes", False)  # Invalid byte, decode or json.loads failed
    ])
def test_client_handler_invalid_requests(server_config, invalid_data,
                                         should_send):
    """Malformed requests are rejected, logged, and the connection closed.

    JSON-decodable but incomplete requests should receive a JOIN_NACK;
    undecodable payloads should receive nothing. The debug log handler is
    detached in a ``finally`` clause so that a failing assertion cannot
    leave it attached to the shared vLLM logger.
    """
    # Encode the request the way it would arrive on the wire.
    if isinstance(invalid_data, bytes):
        raw_request = invalid_data
    elif isinstance(invalid_data, str):
        raw_request = invalid_data.encode()
    else:
        raw_request = json.dumps(invalid_data).encode("utf-8")

    log_capture_string = io.StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.DEBUG)

    with patch("socket.socket"):
        vllm.logger.logger.addHandler(ch)
        try:
            server = ElasticServer(**server_config)

            mock_conn = MagicMock()
            mock_addr = ("192.168.1.1", 12345)
            mock_conn.recv.return_value = raw_request

            server.register_handler(mock_conn, mock_addr)

            if should_send:
                expected_ack = {
                    "label":
                    "JOIN_NACK",
                    "content":
                    f"Received data does not contain required fields: {invalid_data}"
                }
                mock_conn.send.assert_called_once_with(
                    json.dumps(expected_ack).encode("utf-8"))
            else:
                mock_conn.send.assert_not_called()

            # Read the captured logs before the buffer is closed.
            log_output = log_capture_string.getvalue()
        finally:
            vllm.logger.logger.removeHandler(ch)
            log_capture_string.close()

        # Any warning in the log is acceptable
        assert "Failed to load" in log_output or "does not contain" in log_output

        mock_conn.close.assert_called_once()
# Test the thread startup.
def test_server_start(server_config):
    """``start`` creates one daemon thread for the client handler and runs it."""
    with patch("socket.socket"), \
            patch("threading.Thread") as mock_thread:
        thread_stub = mock_thread.return_value

        server = ElasticServer(**server_config)
        server.start()

        # Exactly one Thread should be created, targeting the handler.
        mock_thread.assert_called_once()
        _, thread_kwargs = mock_thread.call_args
        assert thread_kwargs['target'] == server.elastic_client_handler

        # MagicMock records attribute assignment, so the daemon flag set by
        # the server can be read back here.
        assert thread_stub.daemon is True

        # The thread must actually have been started.
        thread_stub.start.assert_called_once()
# Test resource clearing
def test_server_cleanup(server_config):
    """Dropping the last reference to the server closes its socket once."""
    with patch("socket.socket") as mock_socket:
        server = ElasticServer(**server_config)
        sock = mock_socket.return_value

        # NOTE(review): presumably the close happens in the server's
        # finalizer when the reference is dropped; confirm against
        # ElasticServer.
        del server

        sock.close.assert_called_once()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main()

View File

@@ -0,0 +1,114 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from unittest.mock import MagicMock, patch
import pytest
from vllm_ascend.model_loader.netloader.load import elastic_load
@pytest.fixture
def mock_sources():
    """Provide source entries for two devices (ids 0 and 1)."""
    device0 = {"device_id": 0, "sources": ["a", "b"]}
    device1 = {"device_id": 1, "sources": ["c"]}
    return [device0, device1]
@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
def test_sources_this_device_empty(mock_p2p, mock_client):
    """``elastic_load`` returns None when no source entry targets the device."""
    # NOTE(review): these patches target the defining modules, whereas other
    # tests in this file patch "...netloader.load.ElasticClient" and
    # "...netloader.load.P2PLoad"; confirm these mocks are the objects that
    # elastic_load actually uses.
    other_device_only = [{"device_id": 1, "sources": ["c"]}]

    loaded = elastic_load("model", 0, "model_path", other_device_only, 1, 1)

    assert loaded is None
    for collaborator in (mock_client, mock_p2p):
        collaborator.assert_not_called()
@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
def test_client_s_none(mock_p2p, mock_client, mock_sources):
    """``elastic_load`` returns None when the client has no socket (``s``)."""
    # NOTE(review): these patches target the defining modules, whereas other
    # tests in this file patch "...netloader.load.ElasticClient"; confirm
    # this mock is the client elastic_load actually constructs.
    client_without_socket = MagicMock()
    client_without_socket.s = None
    mock_client.return_value = client_without_socket

    loaded = elastic_load("model", 0, "model_path", mock_sources, 1, 1)

    assert loaded is None
@patch("vllm_ascend.model_loader.netloader.interaction.elastic.ElasticClient")
@patch("vllm_ascend.model_loader.netloader.executor.elastic_load.P2PLoad")
def test_client_ack_none(mock_p2p, mock_client, mock_sources):
    """``elastic_load`` returns None when the client received no ack."""
    # NOTE(review): these patches target the defining modules, whereas other
    # tests in this file patch "...netloader.load.ElasticClient"; confirm
    # this mock is the client elastic_load actually constructs.
    client_without_ack = MagicMock()
    client_without_ack.s = True
    client_without_ack.ack = None
    mock_client.return_value = client_without_ack

    loaded = elastic_load("model", 0, "model_path", mock_sources, 1, 1)

    assert loaded is None
@patch("vllm_ascend.model_loader.netloader.load.P2PLoad")
@patch("vllm_ascend.model_loader.netloader.load.logger")
def test_model_load_fail(mock_logger, mock_p2p):
    """A ``P2PLoad.load`` result of None yields None and one error log."""
    # Client stub that looks connected and acknowledged.
    connected_client = MagicMock()
    connected_client.s = True
    connected_client.ack = ["foo", "bar"]
    connected_client.server_addr = "addr"

    with patch("vllm_ascend.model_loader.netloader.load.ElasticClient",
               return_value=connected_client):
        # P2PLoad.load returns None
        failing_loader = MagicMock()
        failing_loader.load.return_value = None
        mock_p2p.return_value = failing_loader

        sources = [{"device_id": 0, "sources": ["whatever"]}]
        loaded = elastic_load("model", 0, "model_path", sources, 1, 1)

        assert loaded is None
        mock_logger.error.assert_called_once()
@patch("vllm_ascend.model_loader.netloader.load.P2PLoad")
@patch("vllm_ascend.model_loader.netloader.load.logger")
def test_model_load_success(mock_logger, mock_p2p):
    """The object returned by ``P2PLoad.load`` is passed through and logged."""
    # Client stub that looks connected and acknowledged.
    connected_client = MagicMock()
    connected_client.s = True
    connected_client.ack = ["foo", "bar"]
    connected_client.server_addr = "addr"

    with patch("vllm_ascend.model_loader.netloader.load.ElasticClient",
               return_value=connected_client):
        # Unique sentinel so identity can be asserted on the result.
        expected_model = object()
        working_loader = MagicMock()
        working_loader.load.return_value = expected_model
        mock_p2p.return_value = working_loader

        sources = [{"device_id": 0, "sources": ["whatever"]}]
        loaded = elastic_load("model", 0, "model_path", sources, 1, 1)

        assert loaded is expected_model
        mock_logger.info.assert_called_once()
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main()

View File

@@ -0,0 +1,61 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import tempfile
from unittest.mock import patch
import pytest
from vllm_ascend.model_loader.netloader.utils import (find_free_port,
is_valid_path_prefix)
def test_find_free_port():
    """``find_free_port`` returns a positive integer."""
    candidate = find_free_port()
    assert isinstance(candidate, int)
    assert candidate > 0
def test_is_valid_path_prefix_empty():
    """An empty prefix is rejected."""
    verdict = is_valid_path_prefix('')
    assert not verdict
def test_is_valid_path_prefixIllegal_characters():
    """A prefix containing the characters ``<>:"|?*`` is rejected."""
    verdict = is_valid_path_prefix('test<>:"|?*')
    assert not verdict
def test_is_valid_path_prefixRelative_path():
    """A bare relative prefix is accepted."""
    verdict = is_valid_path_prefix('test')
    assert verdict
def test_is_valid_path_prefixAbsolute_path():
    """An absolute prefix inside an existing temporary directory is accepted."""
    with tempfile.TemporaryDirectory() as tmpdir:
        prefix = os.path.join(tmpdir, 'test')
        assert is_valid_path_prefix(prefix)
@patch('os.path.exists', return_value=False)
def test_is_valid_path_prefix_no_directory(mock_exists):
    """A prefix is rejected when ``os.path.exists`` reports False."""
    verdict = is_valid_path_prefix('/nonexistent_dir/test')
    assert not verdict
@patch('os.path.exists', return_value=True)
def test_is_valid_path_prefix_directory_exists(mock_exists):
    """A prefix is accepted when ``os.path.exists`` reports True."""
    verdict = is_valid_path_prefix('/existing_dir/test')
    assert verdict
# Allow running this test module directly as a script.
if __name__ == "__main__":
    pytest.main()