add ut for kv transfer module (#1531)

### What this PR does / why we need it?
Test the KV data transfer module, covering the connector, pipe, and buffer components.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
CI passed with new added test.

---------

Signed-off-by: lixudong <lixudong@cmss.chinamobile.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: lixudong <lixudong@cmss.chinamobile.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
Agonixiaoxiao
2025-07-02 16:14:52 +08:00
committed by GitHub
parent aa5fa07478
commit 7fc1a98489
3 changed files with 362 additions and 0 deletions

View File

@@ -0,0 +1,71 @@
import unittest
import zlib
from unittest.mock import MagicMock
import torch
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
int32_hash)
class MockSimplePipe:
    """Minimal stand-in for SimplePipe: a cluster id plus mocked
    transport methods whose calls the tests can inspect."""

    def __init__(self):
        self.cluster_id = 0
        # Each transport operation is a MagicMock so tests can assert on
        # how the buffer drives the pipe.
        for op in ("send_tensor", "recv_tensor", "deallocate_buffer"):
            setattr(self, op, MagicMock())
class TestSimpleBuffer(unittest.TestCase):
    """Unit tests for SimpleBuffer using a mocked pipe transport."""

    def setUp(self):
        self.pipe = MockSimplePipe()
        self.buffer = SimpleBuffer(self.pipe)

    def _configure_buffer(self):
        # Shared KV-cache geometry, hoisted here because test_insert and
        # test_drop_select previously duplicated these assignments.
        self.buffer.num_layers = 2
        self.buffer.num_heads = 4
        self.buffer.head_size = 5
        self.buffer.hidden_size = 6
        self.buffer.dtype = torch.float32

    def test_int32_hash(self):
        # int32_hash should agree with zlib.adler32 on the encoded bytes.
        self.assertEqual(int32_hash("test"), zlib.adler32(b"test"))

    def test_insert(self):
        input_tokens = torch.tensor([1, 2, 3])
        roi = torch.tensor([1, 0, 1])
        key = torch.randn(2, 3, 4, 5)
        value = torch.randn(2, 3, 4, 5)
        hidden = torch.randn(3, 6)
        self._configure_buffer()
        self.buffer.insert(input_tokens, roi, key, value, hidden, "req1")
        # insert must push the data through the pipe.
        self.pipe.send_tensor.assert_called()

    def test_drop_select(self):
        input_tokens = torch.tensor([1, 2, 3])
        roi = None
        self._configure_buffer()
        # Three recv_tensor results — presumably keys, values, and hidden
        # states in that order (TODO confirm against SimpleBuffer).
        self.pipe.recv_tensor.side_effect = [
            (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
            (MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
            (MagicMock(), torch.randn(1, 3, 6))
        ]
        result = self.buffer.drop_select(input_tokens, roi, "req1")
        self.assertEqual(len(result), 4)
        self.assertIsInstance(result[0], torch.Tensor)
        self.assertIsInstance(result[1], torch.Tensor)
        self.assertIsInstance(result[2], torch.Tensor)
        self.assertIsNone(result[3])
        # First tensor comes back as (num_layers, tokens, heads, head_size).
        self.assertEqual(result[0].shape, (2, 3, 4, 5))

    def test_close(self):
        # Smoke test: close() must not raise on a fresh buffer.
        self.buffer.close()

View File

@@ -0,0 +1,146 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm.config import VllmConfig
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
from vllm_ascend.distributed.kv_transfer.simple_connector import \
SimpleConnector
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
class TestSimpleConnector(unittest.TestCase):
    """Unit tests for SimpleConnector with its pipe and buffer mocked."""

    def setUp(self):
        self.mock_pipe = MagicMock(spec=SimplePipe)
        self.mock_buffer = MagicMock(spec=SimpleBuffer)
        # Any SimpleBuffer constructed inside the module under test is
        # replaced with self.mock_buffer for the duration of each test.
        patcher = patch(
            'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer')
        self.addCleanup(patcher.stop)
        self.MockSimpleBuffer = patcher.start()
        self.MockSimpleBuffer.return_value = self.mock_buffer

    def _create_mock_config(self, kv_role):
        """Build the mock config handed to SimpleConnector.

        NOTE(review): the role-dependent attributes below are set on
        ``self.mock_config`` (spec'd as VllmConfig), but this method
        returns the plain ``mock_config`` — so ``kv_role`` never reaches
        the returned object, which auto-mocks every attribute anyway.
        Confirm whether ``self.mock_config`` was meant to be returned.
        """
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": ["127.0.0.1"],
            "decode_device_ips": ["127.0.0.1"],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        mock_config.kv_port = 5500
        self.mock_config = MagicMock(spec=VllmConfig)
        self.mock_config.kv_transfer_config.is_kv_producer = True
        self.mock_config.model_config.hf_config.hidden_size = 128
        self.mock_config.model_config.hf_config.num_attention_heads = 8
        self.mock_config.model_config.hf_config.num_key_value_heads = 8
        self.mock_config.model_config.hf_config.qk_rope_head_dim = 16
        self.mock_config.model_config.hf_config.kv_lora_rank = 16
        self.mock_config.model_config.is_deepseek_mla = True
        # Mock parallel_config.
        self.mock_config.parallel_config = MagicMock()
        self.mock_config.parallel_config.tensor_parallel_size = 1
        self.mock_config.parallel_config.get_num_layers.return_value = 4
        if kv_role == "kv_producer":
            self.mock_config.kv_transfer_config.kv_role = "kv_producer"
        else:
            self.mock_config.kv_transfer_config.kv_role = "kv_consumer"
        return mock_config

    # @patch decorators are applied bottom-up, so mocks are injected in
    # reverse decorator order: LLMDataDist first, then SimpleBuffer, then
    # SimplePipe. Parameter names below follow that injection order (the
    # original names were reversed and bound to the wrong mocks).
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_select_init(self, MockLLMDataDist, mock_buffer, mock_pipe):
        """Producer-role init must create a data pipe and a buffer."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_producer"))
        assert connector.producer_data_pipe is not None
        assert connector.producer_buffer is not None
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None

    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_select_select(self, MockLLMDataDist, mock_buffer, mock_pipe):
        """select() should run against consumer-side pipe and buffer."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_consumer"))
        connector.consumer_data_pipe = mock_pipe
        connector.consumer_buffer = mock_buffer
        assert connector.consumer_data_pipe is not None
        assert connector.consumer_buffer is not None
        input_tokens = torch.tensor([1, 2, 3])
        roi = torch.tensor([True, True, True])
        req_id = "test_req"
        connector.select(input_tokens, roi, req_id)

    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_insert(self, MockLLMDataDist, mock_buffer, mock_pipe):
        """insert() must forward its arguments to the producer buffer."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_producer"))
        connector.producer_buffer = mock_buffer
        input_tokens = torch.randint(0, 1000, (5, ))
        roi = torch.ones_like(input_tokens, dtype=torch.bool)
        keys = torch.randn(3, 5, 1, 96)
        values = torch.randn(3, 5, 1, 96)
        hidden = torch.randn(5, 768)
        req_id = "test_req"
        connector.insert(input_tokens, roi, keys, values, hidden, req_id)
        mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys,
                                                   values, hidden, req_id)

    @patch.object(SimpleConnector, 'insert')
    @patch('torch.distributed.get_rank', return_value=0)
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
    @patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
    @patch('llm_datadist.LLMDataDist')
    def test_send_kv_caches_and_hidden_states(self, MockLLMDataDist,
                                              mock_buffer, mock_pipe,
                                              mock_rank, mock_insert):
        """Smoke test: sending KV caches and hidden states must not raise."""
        connector = SimpleConnector(
            rank=0,
            local_rank=0,
            config=self._create_mock_config("kv_producer"))
        mock_model_executable = MagicMock()
        mock_model_executable.model.start_layer = 0
        mock_model_executable.model.end_layer = 3
        mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata)
        mock_model_input.input_tokens = torch.randint(0, 1000, (10, ))
        # Two sequences of five tokens each.
        mock_model_input.attn_metadata.seq_lens = [5, 5]
        mock_model_input.attn_metadata.slot_mapping = torch.randint(
            0, 100, (10, ))
        mock_model_input.attn_metadata.num_prefill_tokens = 10
        mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]}
        kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)]
        hidden_states = torch.randn(10, 768)
        connector.send_kv_caches_and_hidden_states(mock_model_executable,
                                                   mock_model_input, kv_caches,
                                                   hidden_states)

View File

@@ -0,0 +1,145 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
class TestSimplePipe(unittest.TestCase):
    """Unit tests for SimplePipe construction and config validation."""

    @classmethod
    def _create_mock_config(cls):
        # NOTE: classmethod, so the first parameter is named cls (the
        # original misnamed it self).
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": ["127.0.0.1"],
            "decode_device_ips": ["127.0.0.1"],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        mock_config.kv_port = 5500
        return mock_config

    # @patch decorators are applied bottom-up, so mocks arrive in reverse
    # decorator order: LLMDataDist first, then threading.Thread. The
    # parameter names below follow that order (the originals were
    # reversed and bound to the wrong mocks).
    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_init_success(self, MockLLMDataDist, mock_thread):
        """A producer pipe should construct without error."""
        mock_config = self._create_mock_config()
        self.pipe = SimplePipe(rank=5,
                               local_rank=0,
                               kv_transfer_config=mock_config,
                               hostname="127.0.0.1",
                               port_offset=0)
        # Release the ZMQ socket so later tests can rebind the port.
        self.pipe.router_socket.close()

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_prepare_data_dist(self, MockLLMDataDist, mock_thread):
        self.pipe = SimplePipe(rank=5,
                               local_rank=0,
                               kv_transfer_config=self._create_mock_config(),
                               hostname="127.0.0.1",
                               port_offset=0)
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None
        self.pipe.router_socket.close()

    def test_init_with_invalid_kv_role(self):
        """An unrecognized kv_role must raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            mock_config = MagicMock()
            mock_config.kv_role = "err_role"
            mock_config.kv_connector_extra_config = {
                "prefill_device_ips": ["127.0.0.1"],
                "decode_device_ips": ["127.0.0.1"],
                "llmdatadist_comm_port": 26000,
                "http_port": 8000,
                "proxy_ip": "127.0.0.1",
                "proxy_port": "8000",
                "port": 5500
            }
            pipe = SimplePipe(rank=5,
                              local_rank=0,
                              kv_transfer_config=mock_config,
                              hostname="127.0.0.1",
                              port_offset=0)
            pipe.router_socket.close()

    def test_init_with_missing_device_ips(self):
        """Missing prefill/decode device IPs must raise ValueError."""
        with self.assertRaises(ValueError):
            mock_config = MagicMock()
            mock_config.kv_role = "kv_producer"
            mock_config.kv_connector_extra_config = {
                "llmdatadist_comm_port": 26000,
                "http_port": 8000,
                "proxy_ip": "127.0.0.1",
                "proxy_port": "8000",
                "port": 5500
            }
            pipe = SimplePipe(rank=0,
                              local_rank=0,
                              kv_transfer_config=mock_config,
                              hostname="127.0.0.1",
                              port_offset=0)
            pipe.router_socket.close()

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_create_register_thread_address_is_empty(self, MockLLMDataDist,
                                                     MockThread):
        mock_config = self._create_mock_config()
        pipe = SimplePipe(rank=5,
                          local_rank=0,
                          kv_transfer_config=mock_config,
                          hostname="127.0.0.1",
                          port_offset=0)
        self.assertIsNotNone(pipe._register_thread)
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None
        pipe.router_socket.close()

    @patch('threading.Thread')
    @patch('llm_datadist.LLMDataDist')
    def test_create_register_thread_address_is_not_empty(
            self, MockLLMDataDist, MockThread):
        # Same as above but with empty-string device IPs.
        mock_config = MagicMock()
        mock_config.kv_role = "kv_producer"
        mock_config.kv_connector_extra_config = {
            "prefill_device_ips": [""],
            "decode_device_ips": [""],
            "llmdatadist_comm_port": 26000,
            "http_port": 8000,
            "proxy_ip": "127.0.0.1",
            "proxy_port": "8000",
            "port": 5500
        }
        pipe = SimplePipe(rank=5,
                          local_rank=0,
                          kv_transfer_config=mock_config,
                          hostname="127.0.0.1",
                          port_offset=0)
        self.assertIsNotNone(pipe._register_thread)
        mock_data_dist = MockLLMDataDist.return_value
        mock_data_dist.init.return_value = None
        pipe.router_socket.close()

    @patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
    @patch('llm_datadist.LLMDataDist')
    def test_should_send_tensor_when_valid_input(self, MockLLMDataDist,
                                                 MockSimplePipe):
        # NOTE(review): SimplePipe itself is patched here, so this only
        # exercises the mock call path, not real send logic.
        pipe = MockSimplePipe()
        tensor = torch.randn(3, 3)
        tensor_desc = MockLLMDataDist.CacheDesc(
            num_tensors=1,
            shape=(3, 3),
            data_type=MockLLMDataDist.DataType.DT_FLOAT,
            seq_len_dim_index=1)
        tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
        result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
        self.assertIsNotNone(result)