### What this PR does / why we need it?
Adds unit tests for KV data transfer, covering the connector, pipe, and buffer.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with the newly added tests.

---------

Signed-off-by: lixudong <lixudong@cmss.chinamobile.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: lixudong <lixudong@cmss.chinamobile.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
146 lines
5.5 KiB
Python
146 lines
5.5 KiB
Python
import unittest
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import torch
|
|
|
|
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
|
|
|
|
|
class TestSimplePipe(unittest.TestCase):
|
|
|
|
@classmethod
|
|
def _create_mock_config(self):
|
|
mock_config = MagicMock()
|
|
mock_config.kv_role = "kv_producer"
|
|
mock_config.kv_connector_extra_config = {
|
|
"prefill_device_ips": ["127.0.0.1"],
|
|
"decode_device_ips": ["127.0.0.1"],
|
|
"llmdatadist_comm_port": 26000,
|
|
"http_port": 8000,
|
|
"proxy_ip": "127.0.0.1",
|
|
"proxy_port": "8000",
|
|
"port": 5500
|
|
}
|
|
mock_config.kv_port = 5500
|
|
return mock_config
|
|
|
|
@patch('threading.Thread')
|
|
@patch('llm_datadist.LLMDataDist')
|
|
def test_init_success(self, mock_thread, MockLLMDataDist):
|
|
|
|
mock_config = self._create_mock_config()
|
|
|
|
self.pipe = SimplePipe(rank=5,
|
|
local_rank=0,
|
|
kv_transfer_config=mock_config,
|
|
hostname="127.0.0.1",
|
|
port_offset=0)
|
|
|
|
self.pipe.router_socket.close()
|
|
|
|
@patch('threading.Thread')
|
|
@patch('llm_datadist.LLMDataDist')
|
|
def test_prepare_data_dist(self, mock_thread, MockLLMDataDist):
|
|
self.pipe = SimplePipe(rank=5,
|
|
local_rank=0,
|
|
kv_transfer_config=self._create_mock_config(),
|
|
hostname="127.0.0.1",
|
|
port_offset=0)
|
|
mock_data_dist = MockLLMDataDist.return_value
|
|
mock_data_dist.init.return_value = None
|
|
self.pipe.router_socket.close()
|
|
|
|
def test_init_with_invalid_kv_role(self):
|
|
with self.assertRaises(NotImplementedError):
|
|
mock_config = MagicMock()
|
|
mock_config.kv_role = "err_role"
|
|
mock_config.kv_connector_extra_config = {
|
|
"prefill_device_ips": ["127.0.0.1"],
|
|
"decode_device_ips": ["127.0.0.1"],
|
|
"llmdatadist_comm_port": 26000,
|
|
"http_port": 8000,
|
|
"proxy_ip": "127.0.0.1",
|
|
"proxy_port": "8000",
|
|
"port": 5500
|
|
}
|
|
pipe = SimplePipe(rank=5,
|
|
local_rank=0,
|
|
kv_transfer_config=mock_config,
|
|
hostname="127.0.0.1",
|
|
port_offset=0)
|
|
pipe.router_socket.close()
|
|
|
|
def test_init_with_missing_device_ips(self):
|
|
with self.assertRaises(ValueError):
|
|
mock_config = MagicMock()
|
|
mock_config.kv_role = "kv_producer"
|
|
mock_config.kv_connector_extra_config = {
|
|
"llmdatadist_comm_port": 26000,
|
|
"http_port": 8000,
|
|
"proxy_ip": "127.0.0.1",
|
|
"proxy_port": "8000",
|
|
"port": 5500
|
|
}
|
|
pipe = SimplePipe(rank=0,
|
|
local_rank=0,
|
|
kv_transfer_config=mock_config,
|
|
hostname="127.0.0.1",
|
|
port_offset=0)
|
|
pipe.router_socket.close()
|
|
|
|
@patch('threading.Thread')
|
|
@patch('llm_datadist.LLMDataDist')
|
|
def test_create_register_thread_address_is_empty(self, MockThread,
|
|
MockLLMDataDist):
|
|
|
|
mock_config = self._create_mock_config()
|
|
pipe = SimplePipe(rank=5,
|
|
local_rank=0,
|
|
kv_transfer_config=mock_config,
|
|
hostname="127.0.0.1",
|
|
port_offset=0)
|
|
self.assertIsNotNone(pipe._register_thread)
|
|
mock_data_dist = MockLLMDataDist.return_value
|
|
mock_data_dist.init.return_value = None
|
|
pipe.router_socket.close()
|
|
|
|
@patch('threading.Thread')
|
|
@patch('llm_datadist.LLMDataDist')
|
|
def test_create_register_thread_address_is_not_empty(
|
|
self, MockThread, MockLLMDataDist):
|
|
mock_config = MagicMock()
|
|
mock_config.kv_role = "kv_producer"
|
|
mock_config.kv_connector_extra_config = {
|
|
"prefill_device_ips": [""],
|
|
"decode_device_ips": [""],
|
|
"llmdatadist_comm_port": 26000,
|
|
"http_port": 8000,
|
|
"proxy_ip": "127.0.0.1",
|
|
"proxy_port": "8000",
|
|
"port": 5500
|
|
}
|
|
pipe = SimplePipe(rank=5,
|
|
local_rank=0,
|
|
kv_transfer_config=mock_config,
|
|
hostname="127.0.0.1",
|
|
port_offset=0)
|
|
self.assertIsNotNone(pipe._register_thread)
|
|
mock_data_dist = MockLLMDataDist.return_value
|
|
mock_data_dist.init.return_value = None
|
|
pipe.router_socket.close()
|
|
|
|
@patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
|
|
@patch('llm_datadist.LLMDataDist')
|
|
def test_should_send_tensor_when_valid_input(self, MockSimplePipe,
|
|
MockLLMDataDist):
|
|
pipe = MockSimplePipe()
|
|
tensor = torch.randn(3, 3)
|
|
tensor_desc = MockLLMDataDist.CacheDesc(
|
|
num_tensors=1,
|
|
shape=(3, 3),
|
|
data_type=MockLLMDataDist.DataType.DT_FLOAT,
|
|
seq_len_dim_index=1)
|
|
tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
|
|
result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
|
|
self.assertIsNotNone(result)
|