add ut for kv tansfer module (#1531)
### What this PR does / why we need it? test kv data transfer contains connect,pipe,buffer ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? CI passed with new added test. --------- Signed-off-by: lixudong <lixudong@cmss.chinamobile.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: lixudong <lixudong@cmss.chinamobile.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
71
tests/ut/distributed/kv_transfer/test_simple_buffer.py
Normal file
71
tests/ut/distributed/kv_transfer/test_simple_buffer.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import unittest
|
||||
import zlib
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import (SimpleBuffer,
|
||||
int32_hash)
|
||||
|
||||
|
||||
class MockSimplePipe:
|
||||
|
||||
def __init__(self):
|
||||
self.cluster_id = 0
|
||||
self.send_tensor = MagicMock()
|
||||
self.recv_tensor = MagicMock()
|
||||
self.deallocate_buffer = MagicMock()
|
||||
|
||||
|
||||
class TestSimpleBuffer(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.pipe = MockSimplePipe()
|
||||
self.buffer = SimpleBuffer(self.pipe)
|
||||
|
||||
def test_int32_hash(self):
|
||||
self.assertEqual(int32_hash("test"), zlib.adler32(b"test"))
|
||||
|
||||
def test_insert(self):
|
||||
input_tokens = torch.tensor([1, 2, 3])
|
||||
roi = torch.tensor([1, 0, 1])
|
||||
key = torch.randn(2, 3, 4, 5)
|
||||
value = torch.randn(2, 3, 4, 5)
|
||||
hidden = torch.randn(3, 6)
|
||||
|
||||
self.buffer.num_layers = 2
|
||||
self.buffer.num_heads = 4
|
||||
self.buffer.head_size = 5
|
||||
self.buffer.hidden_size = 6
|
||||
self.buffer.dtype = torch.float32
|
||||
|
||||
self.buffer.insert(input_tokens, roi, key, value, hidden, "req1")
|
||||
|
||||
self.pipe.send_tensor.assert_called()
|
||||
|
||||
def test_drop_select(self):
|
||||
input_tokens = torch.tensor([1, 2, 3])
|
||||
roi = None
|
||||
|
||||
self.buffer.num_layers = 2
|
||||
self.buffer.num_heads = 4
|
||||
self.buffer.head_size = 5
|
||||
self.buffer.hidden_size = 6
|
||||
self.buffer.dtype = torch.float32
|
||||
|
||||
self.pipe.recv_tensor.side_effect = [
|
||||
(MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
|
||||
(MagicMock(), torch.randn(1, 2, 3 * 4 * 5)),
|
||||
(MagicMock(), torch.randn(1, 3, 6))
|
||||
]
|
||||
|
||||
result = self.buffer.drop_select(input_tokens, roi, "req1")
|
||||
self.assertEqual(len(result), 4)
|
||||
self.assertIsInstance(result[0], torch.Tensor)
|
||||
self.assertIsInstance(result[1], torch.Tensor)
|
||||
self.assertIsInstance(result[2], torch.Tensor)
|
||||
self.assertIsNone(result[3])
|
||||
self.assertEqual(result[0].shape, (2, 3, 4, 5))
|
||||
|
||||
def test_close(self):
|
||||
self.buffer.close()
|
||||
146
tests/ut/distributed/kv_transfer/test_simple_connector.py
Normal file
146
tests/ut/distributed/kv_transfer/test_simple_connector.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_buffer import SimpleBuffer
|
||||
from vllm_ascend.distributed.kv_transfer.simple_connector import \
|
||||
SimpleConnector
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
|
||||
class TestSimpleConnector(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.mock_pipe = MagicMock(spec=SimplePipe)
|
||||
self.mock_buffer = MagicMock(spec=SimpleBuffer)
|
||||
|
||||
patcher = patch(
|
||||
'vllm_ascend.distributed.kv_transfer.simple_buffer.SimpleBuffer')
|
||||
self.addCleanup(patcher.stop)
|
||||
self.MockSimpleBuffer = patcher.start()
|
||||
self.MockSimpleBuffer.return_value = self.mock_buffer
|
||||
|
||||
def _create_mock_config(self, kv_role):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
mock_config.kv_port = 5500
|
||||
self.mock_config = MagicMock(spec=VllmConfig)
|
||||
self.mock_config.kv_transfer_config.is_kv_producer = True
|
||||
self.mock_config.model_config.hf_config.hidden_size = 128
|
||||
self.mock_config.model_config.hf_config.num_attention_heads = 8
|
||||
self.mock_config.model_config.hf_config.num_key_value_heads = 8
|
||||
self.mock_config.model_config.hf_config.qk_rope_head_dim = 16
|
||||
self.mock_config.model_config.hf_config.kv_lora_rank = 16
|
||||
self.mock_config.model_config.is_deepseek_mla = True
|
||||
# 模拟 parallel_config
|
||||
self.mock_config.parallel_config = MagicMock()
|
||||
self.mock_config.parallel_config.tensor_parallel_size = 1
|
||||
self.mock_config.parallel_config.get_num_layers.return_value = 4
|
||||
|
||||
if kv_role == "kv_producer":
|
||||
self.mock_config.kv_transfer_config.kv_role = "kv_producer"
|
||||
else:
|
||||
self.mock_config.kv_transfer_config.kv_role = "kv_consumer"
|
||||
return mock_config
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_select_init(self, mock_pipe, mock_buffer, MockLLMDataDist):
|
||||
"""Test select method when buffer retrieval succeeds."""
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_producer"))
|
||||
assert connector.producer_data_pipe is not None
|
||||
assert connector.producer_buffer is not None
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_select_select(self, mock_pipe, mock_buffer, MockLLMDataDist):
|
||||
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_consumer"))
|
||||
connector.consumer_data_pipe = mock_pipe
|
||||
connector.consumer_buffer = mock_buffer
|
||||
assert connector.consumer_data_pipe is not None
|
||||
assert connector.consumer_buffer is not None
|
||||
input_tokens = torch.tensor([1, 2, 3])
|
||||
roi = torch.tensor([True, True, True])
|
||||
req_id = "test_req"
|
||||
connector.select(input_tokens, roi, req_id)
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_insert(self, mock_pipe, mock_buffer, MockLLMDataDist):
|
||||
"""Test insert operation"""
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_producer"))
|
||||
|
||||
connector.producer_buffer = mock_buffer
|
||||
|
||||
input_tokens = torch.randint(0, 1000, (5, ))
|
||||
roi = torch.ones_like(input_tokens, dtype=torch.bool)
|
||||
keys = torch.randn(3, 5, 1, 96)
|
||||
values = torch.randn(3, 5, 1, 96)
|
||||
hidden = torch.randn(5, 768)
|
||||
req_id = "test_req"
|
||||
|
||||
connector.insert(input_tokens, roi, keys, values, hidden, req_id)
|
||||
|
||||
mock_buffer.insert.assert_called_once_with(input_tokens, roi, keys,
|
||||
values, hidden, req_id)
|
||||
|
||||
@patch.object(SimpleConnector, 'insert')
|
||||
@patch('torch.distributed.get_rank', return_value=0)
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimplePipe')
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_connector.SimpleBuffer')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_send_kv_caches_and_hidden_states(self, mock_pipe, mock_buffer,
|
||||
MockLLMDataDist, mock_insert,
|
||||
mock_rank):
|
||||
"""Test sending KV caches and hidden states"""
|
||||
connector = SimpleConnector(
|
||||
rank=0,
|
||||
local_rank=0,
|
||||
config=self._create_mock_config("kv_producer"))
|
||||
|
||||
mock_model_executable = MagicMock()
|
||||
mock_model_executable.model.start_layer = 0
|
||||
mock_model_executable.model.end_layer = 3
|
||||
|
||||
mock_model_input = MagicMock(spec=ModelInputForGPUWithSamplingMetadata)
|
||||
mock_model_input.input_tokens = torch.randint(0, 1000, (10, ))
|
||||
mock_model_input.attn_metadata.seq_lens = [5, 5]
|
||||
mock_model_input.attn_metadata.slot_mapping = torch.randint(
|
||||
0, 100, (10, ))
|
||||
mock_model_input.attn_metadata.num_prefill_tokens = 10
|
||||
mock_model_input.request_ids_to_seq_ids = {"req1": [0], "req2": [1]}
|
||||
|
||||
kv_caches = [torch.randn(2, 100, 1, 96) for _ in range(3)]
|
||||
|
||||
hidden_states = torch.randn(10, 768)
|
||||
|
||||
connector.send_kv_caches_and_hidden_states(mock_model_executable,
|
||||
mock_model_input, kv_caches,
|
||||
hidden_states)
|
||||
145
tests/ut/distributed/kv_transfer/test_simple_pipe.py
Normal file
145
tests/ut/distributed/kv_transfer/test_simple_pipe.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.distributed.kv_transfer.simple_pipe import SimplePipe
|
||||
|
||||
|
||||
class TestSimplePipe(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def _create_mock_config(self):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
mock_config.kv_port = 5500
|
||||
return mock_config
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_init_success(self, mock_thread, MockLLMDataDist):
|
||||
|
||||
mock_config = self._create_mock_config()
|
||||
|
||||
self.pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
|
||||
self.pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_prepare_data_dist(self, mock_thread, MockLLMDataDist):
|
||||
self.pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=self._create_mock_config(),
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
self.pipe.router_socket.close()
|
||||
|
||||
def test_init_with_invalid_kv_role(self):
|
||||
with self.assertRaises(NotImplementedError):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "err_role"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": ["127.0.0.1"],
|
||||
"decode_device_ips": ["127.0.0.1"],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
pipe.router_socket.close()
|
||||
|
||||
def test_init_with_missing_device_ips(self):
|
||||
with self.assertRaises(ValueError):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=0,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_create_register_thread_address_is_empty(self, MockThread,
|
||||
MockLLMDataDist):
|
||||
|
||||
mock_config = self._create_mock_config()
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
self.assertIsNotNone(pipe._register_thread)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('threading.Thread')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_create_register_thread_address_is_not_empty(
|
||||
self, MockThread, MockLLMDataDist):
|
||||
mock_config = MagicMock()
|
||||
mock_config.kv_role = "kv_producer"
|
||||
mock_config.kv_connector_extra_config = {
|
||||
"prefill_device_ips": [""],
|
||||
"decode_device_ips": [""],
|
||||
"llmdatadist_comm_port": 26000,
|
||||
"http_port": 8000,
|
||||
"proxy_ip": "127.0.0.1",
|
||||
"proxy_port": "8000",
|
||||
"port": 5500
|
||||
}
|
||||
pipe = SimplePipe(rank=5,
|
||||
local_rank=0,
|
||||
kv_transfer_config=mock_config,
|
||||
hostname="127.0.0.1",
|
||||
port_offset=0)
|
||||
self.assertIsNotNone(pipe._register_thread)
|
||||
mock_data_dist = MockLLMDataDist.return_value
|
||||
mock_data_dist.init.return_value = None
|
||||
pipe.router_socket.close()
|
||||
|
||||
@patch('vllm_ascend.distributed.kv_transfer.simple_pipe.SimplePipe')
|
||||
@patch('llm_datadist.LLMDataDist')
|
||||
def test_should_send_tensor_when_valid_input(self, MockSimplePipe,
|
||||
MockLLMDataDist):
|
||||
pipe = MockSimplePipe()
|
||||
tensor = torch.randn(3, 3)
|
||||
tensor_desc = MockLLMDataDist.CacheDesc(
|
||||
num_tensors=1,
|
||||
shape=(3, 3),
|
||||
data_type=MockLLMDataDist.DataType.DT_FLOAT,
|
||||
seq_len_dim_index=1)
|
||||
tensor_key = MockLLMDataDist.CacheKey(1, 0, 1)
|
||||
result = pipe.send_tensor(tensor, tensor_desc, tensor_key)
|
||||
self.assertIsNotNone(result)
|
||||
Reference in New Issue
Block a user