add ut for kv tansfer module (#1531)
### What this PR does / why we need it? test kv data transfer contains connect,pipe,buffer ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? CI passed with new added test. --------- Signed-off-by: lixudong <lixudong@cmss.chinamobile.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: lixudong <lixudong@cmss.chinamobile.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
145
tests/ut/distributed/kv_transfer/test_simple_pipe.py
Normal file
145
tests/ut/distributed/kv_transfer/test_simple_pipe.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
|
||||
class TestSimplePipe(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def _create_mock_config(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
mock_config.kv_port = 5500
|
||||
return mock_config
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_init_success(self, mock_thread, MockLLMDataDist):
|
||||
|
||||
mock_config = self._create_mock_config()
|
||||
|
||||
self.pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
|
||||
self.pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_prepare_data_dist(self, mock_thread, MockLLMDataDist):
|
||||
self.pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=self._create_mock_config(),
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
self.pipe.router_socket.close()
|
||||
|
||||
def test_init_with_invalid_kv_role(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "err_role"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
pipe.router_socket.close()
|
||||
|
||||
def test_init_with_missing_device_ips(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=0,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_create_register_thread_address_is_empty(self, MockThread,
|
||||
MockLLMDataDist):
|
||||
|
||||
mock_config = self._create_mock_config()
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
self.assertIsNotNone(pipe._register_thread)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_create_register_thread_address_is_not_empty(
|
||||
self, MockThread, MockLLMDataDist):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": [""],
|
||||
"decode_device_ips": [""],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
self.assertIsNotNone(pipe._register_thread)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_should_send_tensor_when_valid_input(self, MockSimplePipe,
|
||||
MockLLMDataDist):
|
||||
pipe = MockSimplePipe()
|
||||
tensor = torch.randn(3, 3)
|
||||
tensor_desc = MockLLMDataDist.CacheDesc(
|
||||
num_tensors=1,
|
||||
shape=(3, 3),
|
||||
data_type=MockLLMDataDist.DataType.DT_FLOAT,
|
||||
seq_len_dim_index=1)
|
||||
tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
|
||||
result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
|
||||
self.assertIsNotNone(result)
|
||||
Reference in New Issue
Block a user