Files
xc-llm-ascend/tests/ut/distributed/kv_transfer/test_simple_pipe.py
Agonixiaoxiao 7fc1a98489 add ut for kv transfer module (#1531)
### What this PR does / why we need it?
Test the KV data transfer components: connect, pipe, and buffer.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with the newly added tests.

---------

Signed-off-by: lixudong <lixudong@cmss.chinamobile.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: lixudong <lixudong@cmss.chinamobile.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
2025-07-02 16:14:52 +08:00

146 lines
5.5 KiB
Python

import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
class TestSimplePipe(unittest.TestCase):
    """Unit tests that construct ``SimplePipe`` against mocked dependencies.

    Tests that build a pipe patch ``threading.Thread`` and
    ``llm_datadist.LLMDataDist``.  Stacked ``@patch`` decorators are applied
    bottom-up, so the mock created by the decorator closest to the function
    is passed as the first argument after ``self``.
    """

    @staticmethod
    def _extra_config(device_ips=None):
        """Build a ``kv_connector_extra_config`` dict for the tests.

        Args:
            device_ips: Value used for both ``prefill_device_ips`` and
                ``decode_device_ips``.  When ``None`` the two keys are
                omitted entirely.

        Returns:
            A fresh dict; each device-IP entry gets its own list copy.
        """
        extra = {}
        if device_ips is not None:
            extra["prefill_device_ips"] = list(device_ips)
            extra["decode_device_ips"] = list(device_ips)
        extra.update({
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500,
        })
        return extra

    @classmethod
    def _create_mock_config(cls):
        """Return a mock transfer config with ``kv_role == "kv_producer"``.

        The config carries loopback device IPs and ``kv_port`` 5500.
        """
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = cls._extra_config(
            ["127.0.0.1"])
        mock_config.kv_port = 5500
        return mock_config

    def _make_pipe(self, config, rank=5):
        """Construct a ``SimplePipe`` and schedule its socket for closing.

        The cleanup is registered only once construction has succeeded, and
        it runs even if a later assertion in the test fails.
        """
        pipe = SimplePipe(rank=rank,
                          local_rank=0,
                          kv_transfer_config=config,
                          hostname="127.0.0.1",
                          port_offset=0)
        self.addCleanup(pipe.router_socket.close)
        return pipe

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_init_success(self, mock_llm_data_dist, mock_thread):
        """A pipe can be constructed from the default producer config."""
        self.pipe = self._make_pipe(self._create_mock_config())

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_prepare_data_dist(self, mock_llm_data_dist, mock_thread):
        """Construction succeeds while ``LLMDataDist`` is patched.

        NOTE(review): no assertion is made on the ``LLMDataDist`` mock; add
        one once the expected calls are confirmed against ``SimplePipe``.
        """
        self.pipe = self._make_pipe(self._create_mock_config())

    def test_init_with_invalid_kv_role(self):
        """An unknown ``kv_role`` makes the constructor raise."""
        mock_config = MagicMock()
        mock_config.kv_role = "err_role"
        mock_config.kv_connector_extra_config = self._extra_config(
            ["127.0.0.1"])
        # Only the constructor call is expected to raise, so the context
        # manager wraps nothing else.
        with self.assertRaises(NotImplementedError):
            self._make_pipe(mock_config)

    def test_init_with_missing_device_ips(self):
        """Omitting both device-IP keys makes the constructor raise."""
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = self._extra_config()
        with self.assertRaises(ValueError):
            self._make_pipe(mock_config, rank=0)

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_create_register_thread_address_is_empty(self, MockLLMDataDist,
                                                     MockThread):
        """The default producer config yields a non-None register thread."""
        pipe = self._make_pipe(self._create_mock_config())
        self.assertIsNotNone(pipe._register_thread)

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_create_register_thread_address_is_not_empty(
            self, MockLLMDataDist, MockThread):
        """Empty-string device IPs still yield a non-None register thread."""
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = self._extra_config([""])
        pipe = self._make_pipe(mock_config)
        self.assertIsNotNone(pipe._register_thread)

    @patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
    @patch('llm_datadist.LLMDataDist')
    def test_should_send_tensor_when_valid_input(self, MockLLMDataDist,
                                                 MockSimplePipe):
        """``send_tensor`` on a patched pipe is invoked and returns a value.

        NOTE(review): both the pipe and the descriptor/key objects are
        mocks, so this does not exercise the real ``send_tensor`` logic.
        """
        pipe = MockSimplePipe()
        tensor = torch.randn(3, 3)
        tensor_desc = MockLLMDataDist.CacheDesc(
            num_tensors=1,
            shape=(3, 3),
            data_type=MockLLMDataDist.DataType.DT_FLOAT,
            seq_len_dim_index=1)
        tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
        result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
        pipe.send_tensor.assert_called_once()
        self.assertIsNotNone(result)