Files
xc-llm-ascend/tests/ut/multistream/test_metadata.py
SunnyLee151064 ab7d5aca5d [Test] Add ut for files in /multistream (#1947)
### What this PR does / why we need it?
Add some uts for files in folder /multistream

### Does this PR introduce _any_ user-facing change?
No

- vLLM version: v0.9.2
- vLLM main:
b77c7d327f

Signed-off-by: lwq <liwenquan5@huawei.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
2025-07-24 10:42:49 +08:00

247 lines
10 KiB
Python

from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamMetadata,
MultiStreamStepMetadata,
split_micro_batches_tensors)
class TestMetaData(TestBase):
def setUp(self):
self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
self.test_tensors = torch.randn(100, 1024)
self.test_tensors_dict = {
'query': torch.randn(100, 1024),
'key': torch.randn(100, 1024),
'value': torch.randn(100, 1024)
}
self.split_index = 50
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
self.metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
def test_split_micro_batches_tensors(self):
test_tensors_list_res = split_micro_batches_tensors(
self.test_tensors_list, self.split_index)
test_tensors_res = split_micro_batches_tensors(self.test_tensors,
self.split_index)
keys = ['query', 'key', 'value']
test_tensors_dict_res = split_micro_batches_tensors(
self.test_tensors_dict, self.split_index, keys)
for i in range(3):
self.assertEqual(len(test_tensors_list_res[i][0]),
self.split_index)
self.assertEqual(
len(test_tensors_list_res[i][0]) +
len(test_tensors_list_res[i][1]), 100)
self.assertEqual(len(test_tensors_res[0]), self.split_index)
self.assertEqual(
len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
for key in keys:
self.assertEqual(len(test_tensors_dict_res[0][key]),
self.split_index)
self.assertEqual(
len(test_tensors_dict_res[0][key]) +
len(test_tensors_dict_res[1][key]), 100)
def test_default_init_multistream_step_metadata(self):
metadata = MultiStreamStepMetadata()
self.assertIsNone(metadata.comm_stream)
self.assertIsNone(metadata.before_comm_event)
self.assertIsNone(metadata.after_comm_event)
def test_custom_init_multistream_step_metadata(self):
mockStream = MagicMock(spec=torch.npu.Stream)
mockEvent1 = MagicMock(spec=torch.npu.Event)
mockEvent2 = MagicMock(spec=torch.npu.Event)
metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
self.assertEqual(metadata.comm_stream, mockStream)
self.assertEqual(metadata.before_comm_event, mockEvent1)
self.assertEqual(metadata.after_comm_event, mockEvent2)
def test_default_init_multistream_config(self):
config = MultiStreamConfig()
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.imbalance_ratio, 0.1)
def test_custom_init_multistream_config(self):
config = MultiStreamConfig(512, 128, 1, 0.2)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)
self.assertEqual(config.num_micro_batches, 1)
self.assertEqual(config.imbalance_ratio, 0.2)
def test_init_multistream_metadata(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock()]
multistream_config = MagicMock(spec=MultiStreamConfig)
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertEqual(metadata.calculate_stream, mock_stream)
self.assertEqual(metadata.communicate_stream, mock_stream)
self.assertEqual(metadata.start_layer, 1)
self.assertEqual(metadata.end_layer, 3)
self.assertEqual(metadata.ms_config, multistream_config)
self.assertTrue(metadata.causal_lm)
def test_build_events(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
with patch('torch.npu.Event', return_value=mock_event):
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MultiStreamConfig(
num_micro_batches=2,
min_total_tokens_to_split=256,
min_prefill_tokens_to_split=64)
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
expected_events = {
0: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
1: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
2: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
}
}
self.assertEqual(metadata.ms_events, expected_events)
def test_build_ms_split_config(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
multistream_config.num_micro_batches = 2
multistream_config.min_total_tokens_to_split = 256
multistream_config.min_prefill_tokens_to_split = 64
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertIsNotNone(metadata.ms_split_config)
self.assertEqual(metadata.ms_split_config.num_micro_batches,
multistream_config.num_micro_batches)
self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
multistream_config.min_total_tokens_to_split)
self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
multistream_config.min_prefill_tokens_to_split)
def test_try_wait_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_wait_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.wait.assert_called_once()
def test_try_record_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_record_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.record.assert_called_once()
def test_merge_batches_none_input(self):
input_tensors = None
result = self.metadata.merge_micro_batches(input_tensors)
self.assertIsNone(result)
def test_merge_batches_single_tensor_input(self):
input_tensors = [torch.tensor([1, 2, 3])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 1)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
def test_merge_batches_list_of_tensors_input(self):
input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertEqual(result, input_tensors)
def test_merge_batches_nested_list_input(self):
input_tensors = [[torch.tensor([1, 2]),
torch.tensor([3, 4])],
[torch.tensor([5, 6]),
torch.tensor([7, 8])]]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))