216 lines
6.1 KiB
Python
216 lines
6.1 KiB
Python
|
|
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
|
||
|
|
|
||
|
|
import json
|
||
|
|
from unittest.mock import MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import torch
|
||
|
|
from torch import nn
|
||
|
|
|
||
|
|
from vllm_ascend.model_loader.netloader.netloader import ModelNetLoaderElastic
|
||
|
|
|
||
|
|
|
||
|
|
class DummyDeviceConfig:
    """Stub device config used as ``DummyVllmConfig.device_config`` in these tests."""

    # Both fields carry the same device string.
    device = device_type = 'cuda'
|
||
|
|
|
||
|
|
|
||
|
|
class DummyParallelConfig:
    """Stub parallel config used as ``DummyVllmConfig.parallel_config`` in these tests."""

    # Tensor and pipeline parallel sizes are both fixed at 1.
    tensor_parallel_size = pipeline_parallel_size = 1
|
||
|
|
|
||
|
|
|
||
|
|
class DummyVllmConfig:
    """Stub top-level config passed to ``load_model`` in these tests.

    The nested configs are class attributes, so they are created once when the
    class body runs and shared by every ``DummyVllmConfig`` instance.
    """

    # No additional configuration is supplied.
    additional_config = None
    parallel_config = DummyParallelConfig()
    device_config = DummyDeviceConfig()
|
||
|
|
|
||
|
|
|
||
|
|
class DummyModelConfig:
    """Stub model config passed to ``load_model`` in these tests."""

    # Default dtype handed to the loader under test.
    dtype = torch.float32
    # Placeholder model identifier.
    model = 'dummy-model'
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.fixture
def default_load_config():
    """Provide a stub load config with the "default" format and no extra loader options."""

    class DummyLoadConfig:
        load_format = "default"
        model_loader_extra_config = None

    stub_config = DummyLoadConfig()
    return stub_config
|
||
|
|
|
||
|
|
|
||
|
|
def make_loader_with_config(extra):
    """Build a ``ModelNetLoaderElastic`` from a stub load config.

    Args:
        extra: Value exposed as ``model_loader_extra_config`` on the stub
            load config handed to the loader.

    Returns:
        The ``ModelNetLoaderElastic`` constructed from that stub config.
    """

    class DummyLoadConfig:
        load_format = "default"
        model_loader_extra_config = extra

    stub_config = DummyLoadConfig()
    return ModelNetLoaderElastic(stub_config)
|
||
|
|
|
||
|
|
|
||
|
|
def test_init_with_extra_config_file(tmp_path, monkeypatch):
    """Check that a loader built from a ``CONFIG_FILE`` entry exposes the
    values stored in that JSON file."""
    # Replace vLLM's logger with a mock and stub the path-prefix check so it
    # always returns True.
    monkeypatch.setattr("vllm.logger.logger", MagicMock())
    monkeypatch.setattr(
        "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
        lambda x: True)

    # Write the loader settings to a JSON file in the pytest temp directory.
    settings = {
        "SOURCE": [{
            "device_id": 0
        }],
        "MODEL": "foo-model",
        "LISTEN_PORT": 5001,
        "INT8_CACHE": "hbm",
        "OUTPUT_PREFIX": str(tmp_path),
    }
    settings_path = tmp_path / "config.json"
    settings_path.write_text(json.dumps(settings))

    loader = make_loader_with_config({"CONFIG_FILE": str(settings_path)})

    # Loader attribute -> value expected from the JSON file.
    expected = {
        "model_path": "foo-model",
        "source": [{"device_id": 0}],
        "listen_port": 5001,
        "int8_cache": "hbm",
        "output_prefix": str(tmp_path),
    }
    for attr_name, wanted in expected.items():
        assert getattr(loader, attr_name) == wanted, attr_name
|
||
|
|
|
||
|
|
|
||
|
|
def test_init_with_extra_config(monkeypatch):
    """Check that a loader built from an inline extra-config dict exposes
    those values as attributes."""
    # Replace vLLM's logger with a mock and stub the path-prefix check so it
    # always returns True.
    monkeypatch.setattr("vllm.logger.logger", MagicMock())
    monkeypatch.setattr(
        "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
        lambda x: True)

    # LISTEN_PORT is given as a string here; the assertion below expects the
    # integer 4000 on the loader.
    loader = make_loader_with_config({
        "SOURCE": [{
            "device_id": 0
        }],
        "MODEL": "foo",
        "LISTEN_PORT": "4000",
        "INT8_CACHE": "dram",
        "OUTPUT_PREFIX": "/tmp/"
    })

    # Loader attribute -> expected value.
    expected = {
        "model_path": "foo",
        "listen_port": 4000,
        "int8_cache": "dram",
        "output_prefix": "/tmp/",
        "source": [{"device_id": 0}],
    }
    for attr_name, wanted in expected.items():
        assert getattr(loader, attr_name) == wanted, attr_name
|
||
|
|
|
||
|
|
|
||
|
|
def test_init_with_invalid_config(monkeypatch):
    """Check the loader attributes produced by a config of ``None`` values, an
    unrecognised ``INT8_CACHE`` string and a rejected path prefix."""
    # Replace vLLM's logger with a mock and stub the path-prefix check so it
    # always returns False.
    monkeypatch.setattr("vllm.logger.logger", MagicMock())
    monkeypatch.setattr(
        "vllm_ascend.model_loader.netloader.utils.is_valid_path_prefix",
        lambda x: False)

    # Every option is None except INT8_CACHE, which is set to "something".
    loader = make_loader_with_config({
        "SOURCE": None,
        "MODEL": None,
        "LISTEN_PORT": None,
        "INT8_CACHE": "something",
        "OUTPUT_PREFIX": None,
    })

    for attr_name in ("model_path", "listen_port"):
        assert getattr(loader, attr_name) is None, attr_name
    # The test expects the int8 cache setting to come back as "no".
    assert loader.int8_cache == "no"
    assert loader.output_prefix is None
|
||
|
|
|
||
|
|
|
||
|
|
@patch("vllm_ascend.model_loader.netloader.netloader.logger")
def test_load_model_elastic_success(mock_logger, monkeypatch, tmp_path):
    """Run ``load_model`` with its collaborators replaced by stubs.

    The test expects the call to return an ``nn.Module`` and expects a file
    named ``output_0.txt`` to exist under ``tmp_path`` afterwards.
    """

    class FakeContext:
        """Context manager that does nothing on entry or exit."""

        def __enter__(self):
            return None

        def __exit__(self, exc_type, exc_value, traceback):
            # Returning None lets any exception propagate.
            return None

    class DummyElasticServer:
        """Stand-in for ``ElasticServer`` that accepts any arguments."""

        def __init__(self, *args, **kwargs):
            pass

        def start(self):
            pass

    # Mock model whose eval() hands back the same mock.
    stub_model = MagicMock(spec=nn.Module)
    stub_model.eval.return_value = stub_model

    netloader = "vllm_ascend.model_loader.netloader.netloader"
    # (dotted target, replacement) pairs, applied in the listed order.
    stand_ins = (
        ("torch.distributed.get_rank", lambda: 0),
        ("torch.device", lambda d: FakeContext()),
        (f"{netloader}.deepcopy", lambda x: x),
        (f"{netloader}.set_default_torch_dtype", lambda dtype: FakeContext()),
        (f"{netloader}.initialize_model", lambda **kwargs: stub_model),
        (f"{netloader}.elastic_load", lambda **kwargs: stub_model),
        (f"{netloader}.process_weights_after_loading", lambda *a, **k: None),
        ("vllm.utils.get_ip", lambda: "127.0.0.1"),
        (f"{netloader}.find_free_port", lambda: 8888),
        (f"{netloader}.ElasticServer", DummyElasticServer),
    )
    for target, replacement in stand_ins:
        monkeypatch.setattr(target, replacement)

    # Point OUTPUT_PREFIX at the pytest temp directory.
    loader = make_loader_with_config({
        "SOURCE": [{
            "device_id": 0
        }],
        "MODEL": "foo",
        "LISTEN_PORT": 5555,
        "OUTPUT_PREFIX": f"{tmp_path}/output_",
        "INT8_CACHE": "no"
    })

    loaded = loader.load_model(DummyVllmConfig(), DummyModelConfig())
    assert isinstance(loaded, nn.Module)

    # The output file for index 0 should now exist under the prefix.
    assert (tmp_path / "output_0.txt").exists()
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; hand it to the interpreter so
    # running this file as a script exits non-zero when tests fail, instead of
    # always exiting 0.
    raise SystemExit(pytest.main())
|