Files
xc-llm-ascend/tests/ut/model_loader/netloader/test_netloader_elastic.py
zzzzwwjj 71f729a661 Revert "moe_gating_top_k" (#5512)
Reverts vllm-project/vllm-ascend#5271

It breaks e2e test

- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
2025-12-30 15:05:47 +08:00

433 lines
14 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import io
import json
import logging
import socket
from unittest.mock import MagicMock, patch
import pytest
import torch
import vllm.logger
from vllm_ascend.model_loader.netloader.interaction.elastic import (
ElasticClient, ElasticServer)
# Simulate server's normal response
def mock_server_response(data):
    """Return the UTF-8 encoded JSON of a successful ``JOIN_ACK`` reply.

    ``data`` is accepted but never read; the reply is always the same.
    """
    ack = {"label": "JOIN_ACK", "content": {"name": "mocked_name"}}
    encoded = json.dumps(ack).encode("utf-8")
    return encoded
# Simulate server's error response
def mock_server_error_response(data):
    """Return the UTF-8 encoded JSON of a ``JOIN_ACK`` whose content is null.

    ``data`` is accepted but never read.
    """
    reply = {"label": "JOIN_ACK", "content": None}
    return json.dumps(reply).encode("utf-8")
# Simulated server's abnormal response
def mock_server_exception_response(data):
    """Always raise ``Exception("Mocked server exception")``.

    Used below as a ``recv.side_effect``; ``data`` is never read.
    """
    failure = Exception("Mocked server exception")
    raise failure
# Test the initialization of ElasticClient
def test_elastic_client_init():
    """Check the attributes ElasticClient exposes after a successful JOIN_ACK.

    The socket is fully mocked: ``recv`` yields the canned ACK and
    ``getsockname`` reports port 12346. After the client context exits,
    the socket's ``close`` must have been called exactly once.
    """
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp, pp = 1, 1

    # Build the fake socket first, then hand it to the patcher.
    fake_sock = MagicMock()
    fake_sock.recv.return_value = mock_server_response(None)
    fake_sock.getsockname.return_value = ('127.0.0.1', 12346)
    fake_sock.__enter__.return_value = fake_sock

    with patch('socket.socket', return_value=fake_sock):
        with ElasticClient(sources, device_id, model_path, tp, pp) as client:
            assert client.server_addr == "127.0.0.1"
            assert client.server_port == 12345
            assert client.ack == ("mocked_name", 12346)

        fake_sock.close.assert_called_once()
# Test the register method of ElasticClient
def test_elastic_client_register():
    """``register`` returns ``("mocked_name", 12346)`` against the mocked socket."""
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp, pp = 1, 1

    fake_sock = MagicMock()
    fake_sock.connect.return_value = None
    fake_sock.recv.return_value = mock_server_response(None)
    fake_sock.getsockname.return_value = ('127.0.0.1', 12346)
    fake_sock.__enter__.return_value = fake_sock

    with patch('socket.socket', return_value=fake_sock):
        client = ElasticClient(sources, device_id, model_path, tp, pp)
        result = client.register(device_id, model_path, tp, pp)
        assert result == ("mocked_name", 12346)
# Test the behavior of the `register` method of ElasticClient when the server returns an error response.
def test_elastic_client_register_error_response():
    """``register`` must raise RuntimeError when the ACK carries null content.

    Afterwards the mocked socket's ``close`` must have been called once.
    """
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp, pp = 1, 1

    fake_sock = MagicMock()
    fake_sock.connect.return_value = None
    fake_sock.recv.return_value = mock_server_error_response(None)

    with patch('socket.socket', return_value=fake_sock):
        # pytest.raises is the inner manager, so it swallows the expected
        # error before the client context exits.
        with ElasticClient(sources, device_id, model_path, tp,
                           pp) as client, pytest.raises(RuntimeError):
            client.register(device_id, model_path, tp, pp)

        fake_sock.close.assert_called_once()
# Test the behavior of the `register` method of ElasticClient when an exception is thrown on the server.
def test_elastic_client_register_exception():
    """``register`` must raise RuntimeError when ``recv`` itself raises.

    ``recv`` is wired to a side effect that throws a plain ``Exception``;
    the test expects it to surface as RuntimeError and the socket's
    ``close`` to have been called exactly once.
    """
    sources = ["127.0.0.1:12345"]
    device_id, model_path = 0, "mocked_model_path"
    tp, pp = 1, 1

    fake_sock = MagicMock()
    fake_sock.connect.return_value = None
    fake_sock.recv.side_effect = mock_server_exception_response
    fake_sock.__enter__.return_value = fake_sock
    fake_sock.__exit__.return_value = None

    with patch('socket.socket', return_value=fake_sock):
        with ElasticClient(sources, device_id, model_path, tp,
                           pp) as client, pytest.raises(RuntimeError):
            client.register(device_id, model_path, tp, pp)

        fake_sock.close.assert_called_once()
class FakeInt8Param:
    """Minimal stand-in for an int8 parameter.

    It carries only ``dtype`` and ``device``. ``data``, ``clone()``,
    ``detach()`` and ``cpu()`` all hand back the same object so calls can
    be chained; ``cpu()`` additionally rewrites ``device`` to CPU.
    """

    def __init__(self, name="param", device="npu", dtype=torch.int8):
        # ``name`` is accepted but not stored.
        self.device = torch.device(device)
        self.dtype = dtype

    def cpu(self):
        """Mark this object as living on the CPU and return it."""
        self.device = torch.device("cpu")
        return self

    def detach(self):
        """Return this same object."""
        return self

    def clone(self):
        """Return this same object (no copy is made)."""
        return self

    @property
    def data(self):
        """Return this same object, so ``.data.cpu()`` and similar chains work."""
        return self
class FakeModel:
    """Tiny model double exposing two named parameters."""

    def __init__(self):
        # "param1" is a float32 mock that the int8 cache is expected to skip;
        # "param2" simulates a real int8 parameter.
        float_param = MagicMock(dtype=torch.float32)
        int8_param = FakeInt8Param()
        self.params = {"param1": float_param, "param2": int8_param}

    def named_parameters(self):
        """Return the ``(name, parameter)`` items view of ``self.params``."""
        return self.params.items()
@pytest.fixture
def mock_model():
    """Provide a fresh ``FakeModel`` for each test that requests it."""
    model = FakeModel()
    return model
@pytest.fixture
def server_config():
    """Provide the keyword arguments the tests pass to ``ElasticServer``.

    Tests mutate the returned dict (for example to swap in a model or
    change ``int8_cache``) before unpacking it into the constructor.
    """
    config = {
        "addr": "127.0.0.1",
        "port": 8080,
        "model": MagicMock(),
        "device_id": 0,
        "model_path": "/test/model",
        "tp": 1,
        "pp": 1,
        "int8_cache": "dram",
        "int8_cache_name": None,
    }
    return config
# Test server initialization
def test_server_initialization(server_config, mock_model):
    """Check socket setup, int8 caching, stored attributes and the start log.

    A temporary handler on ``vllm.logger.logger`` captures log output while
    the server is constructed. The handler is removed in a ``finally`` block
    so that a constructor error cannot leave it attached for later tests.
    """
    server_config["model"] = mock_model

    log_capture_string = io.StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.DEBUG)

    with patch("socket.socket") as mock_socket:
        vllm.logger.logger.addHandler(ch)
        try:
            server = ElasticServer(**server_config)
            # Read the captured text before the stream is closed.
            log_output = log_capture_string.getvalue()
        finally:
            vllm.logger.logger.removeHandler(ch)
            log_capture_string.close()

        # Check the socket configuration
        mock_socket.assert_called_with(socket.AF_INET, socket.SOCK_STREAM)
        listening_sock = mock_socket.return_value
        listening_sock.setsockopt.assert_called_with(socket.SOL_SOCKET,
                                                     socket.SO_REUSEADDR, 1)
        listening_sock.bind.assert_called_with(("127.0.0.1", 8080))
        listening_sock.listen.assert_called_with(256)

        # Check int8 cache
        assert "param2" in server.original_int8
        assert server.original_int8[
            "param2"].device.type == "cpu"  # Verifying DRAM Cache

        # Constructor arguments must be stored unchanged.
        assert server.addr == server_config['addr']
        assert server.port == server_config['port']
        assert server.device_id == server_config['device_id']
        assert server.model_path == server_config['model_path']
        assert server.tp == server_config['tp']
        assert server.pp == server_config['pp']

        # Check output
        assert "Server 127.0.0.1:8080 starts" in log_output
# Test the int8 cache option
@pytest.mark.parametrize("cache_option,expected_device", [("dram", "cpu"),
                                                          ("no", None),
                                                          ("invalid", None)])
def test_int8_cache_handling(server_config, mock_model, cache_option,
                             expected_device, caplog):
    """Check which int8 parameters are cached for each ``int8_cache`` value.

    "dram" must leave ``param2`` cached on the CPU; "no" and "invalid" must
    leave the cache empty, and "invalid" must also log a message. The
    temporary log handler is removed in a ``finally`` block so a constructor
    failure cannot leak it. ``caplog`` is requested but not read here.
    """
    server_config["int8_cache"] = cache_option
    server_config["model"] = mock_model

    log_capture_string = io.StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.DEBUG)

    with patch("socket.socket"):
        vllm.logger.logger.addHandler(ch)
        try:
            server = ElasticServer(**server_config)
            log_output = log_capture_string.getvalue()
        finally:
            vllm.logger.logger.removeHandler(ch)
            log_capture_string.close()

    if cache_option == "invalid":
        assert "int8_cache should be selected in [HBM, DRAM]" in log_output

    if expected_device is None:
        assert len(server.original_int8) == 0
    else:
        assert server.original_int8[
            "param2"].device.type == expected_device
# Test client processing
def test_client_handler_valid_join(server_config, mock_model):
    """A matching JOIN request gets a JOIN_ACK and triggers one ``P2PSend``.

    ``socket.socket`` is patched, as in the other server tests, so building
    the server does not touch a real network socket; only ``mock_conn``
    takes part in the exchange.
    """
    server_config["model"] = mock_model
    with patch("socket.socket"), \
         patch("vllm_ascend.model_loader.netloader.interaction.elastic.P2PSend"
               ) as mock_p2p_send:
        # Create a simulated connection
        mock_conn = MagicMock()
        mock_addr = ("192.168.1.1", 12345)

        # Configuring Client Data
        valid_data = {
            "label": "JOIN",
            "content": {
                "device_id": 0,
                "model_path": "/test/model",
                "tp": 1,
                "pp": 1,
                "port": 9090
            }
        }
        mock_conn.recv.return_value = json.dumps(valid_data).encode("utf-8")

        # Start the server
        server = ElasticServer(**server_config)
        server.register_handler(mock_conn, mock_addr)

        # Verify response
        expected_ack = {
            "label": "JOIN_ACK",
            "content": {
                "name": "192.168.1.1:12345"
            }
        }
        mock_conn.send.assert_called_once_with(
            json.dumps(expected_ack).encode("utf-8"))
        mock_p2p_send.assert_called_once_with("127.0.0.1", 9090,
                                              "192.168.1.1:12345")
        mock_conn.close.assert_called_once()
# Test mismatched JOIN requests
def test_client_handler_mismatch(server_config):
    """A JOIN whose fields differ from the server's gets a JOIN_NACK.

    The expected NACK text embeds the received and the server-side
    ``(device_id, model_path, tp, pp)`` tuples.
    """
    with patch("socket.socket"):
        server = ElasticServer(**server_config)
        mock_conn = MagicMock()
        mock_addr = ("192.168.1.1", 12345)

        # Send mismatched data
        mismatch_data = {
            "label": "JOIN",
            "content": {
                "device_id": 1,  # mismatched ID
                "model_path": "/wrong/model",
                "tp": 2,
                "pp": 2,
                "port": 9090
            }
        }
        mock_conn.recv.return_value = json.dumps(mismatch_data).encode("utf-8")

        server.register_handler(mock_conn, mock_addr)

        assert isinstance(mismatch_data["content"], dict)
        content = mismatch_data["content"]
        compared_keys = ("device_id", "model_path", "tp", "pp")
        received = tuple(content[key] for key in compared_keys)
        expected = tuple(server_config[key] for key in compared_keys)

        # Verify response
        expected_ack = {
            "label":
            "JOIN_NACK",
            "content":
            f"Received data {received} does not consist with this server {expected}"
        }
        mock_conn.send.assert_called_once_with(
            json.dumps(expected_ack).encode("utf-8"))
        mock_conn.close.assert_called_once()
# Test Invalid Request
@pytest.mark.parametrize(
    "invalid_data,should_send",
    [
        (
            {
                "label": "WRONG_LABEL"
            }, True
        ),  # Incorrect label, can be decoded as JSON, but the content is invalid.
        (
            {
                "content": {
                    "missing_fields": True
                }
            }, True
        ),  # Missing field, can be decoded as JSON, but the content is invalid.
        ("plain text", False),  # Non-JSON data, json.loads failed
        (b"invalid_bytes", False)  # Invalid byte, decode or json.loads failed
    ])
def test_client_handler_invalid_requests(server_config, invalid_data,
                                         should_send):
    """Malformed requests are rejected and the connection is closed.

    JSON-decodable but invalid payloads must be answered with a JOIN_NACK;
    payloads that are not JSON must get no reply at all. In every case a
    message is logged and ``close`` is called once on the connection.

    ``socket.socket`` is patched once, and the temporary log handler is
    removed in a ``finally`` block so a failure cannot leak it.
    """
    log_capture_string = io.StringIO()
    ch = logging.StreamHandler(log_capture_string)
    ch.setLevel(logging.DEBUG)

    # Turn the parametrized value into the raw bytes ``recv`` will return.
    if isinstance(invalid_data, bytes):
        raw_request = invalid_data
    elif isinstance(invalid_data, str):
        raw_request = invalid_data.encode()
    else:
        raw_request = json.dumps(invalid_data).encode("utf-8")

    mock_conn = MagicMock()
    mock_conn.recv.return_value = raw_request
    mock_addr = ("192.168.1.1", 12345)

    with patch("socket.socket"):
        vllm.logger.logger.addHandler(ch)
        try:
            server = ElasticServer(**server_config)
            server.register_handler(mock_conn, mock_addr)
            log_output = log_capture_string.getvalue()
        finally:
            vllm.logger.logger.removeHandler(ch)
            log_capture_string.close()

        if should_send:
            expected_ack = {
                "label":
                "JOIN_NACK",
                "content":
                f"Received data does not contain required fields: {invalid_data}"
            }
            mock_conn.send.assert_called_once_with(
                json.dumps(expected_ack).encode("utf-8"))
        else:
            mock_conn.send.assert_not_called()

        # Any warning in the log is acceptable
        assert "Failed to load" in log_output or "does not contain" in log_output
        mock_conn.close.assert_called_once()
# Test the thread startup.
def test_server_start(server_config):
    """``start`` creates one daemon thread targeting the client handler."""
    with patch("socket.socket"):
        with patch("threading.Thread") as mock_thread:
            handler_thread_instance = mock_thread.return_value

            server = ElasticServer(**server_config)
            server.start()

            # Exactly one Thread was built, and its target is the handler.
            mock_thread.assert_called_once()
            _, thread_kwargs = mock_thread.call_args
            assert thread_kwargs['target'] == server.elastic_client_handler

            # The MagicMock records the assignment made to ``daemon``.
            assert handler_thread_instance.daemon is True

            # The thread must actually have been started.
            handler_thread_instance.start.assert_called_once()
# Test resource clearing
def test_server_cleanup(server_config):
    """Deleting the server must close its (mocked) socket exactly once.

    NOTE(review): this relies on ``del`` dropping the last reference so
    that cleanup runs immediately -- confirm nothing else keeps the server
    alive.
    """
    with patch("socket.socket") as mock_socket:
        listening_sock = mock_socket.return_value
        server = ElasticServer(**server_config)
        del server
        listening_sock.close.assert_called_once()
# Allow running this module directly as a script: hand control to pytest.
if __name__ == "__main__":
    pytest.main()