[Test] Add ut for files in /multistream (#1947)
### What this PR does / why we need it?
Add unit tests (UTs) for the files in the /multistream folder.
### Does this PR introduce _any_ user-facing change?
No
- vLLM version: v0.9.2
- vLLM main:
b77c7d327f
Signed-off-by: lwq <liwenquan5@huawei.com>
Co-authored-by: lwq <liwenquan5@huawei.com>
This commit is contained in:
32
tests/ut/multistream/test_base.py
Normal file
32
tests/ut/multistream/test_base.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
|
||||
MSEventKey)
|
||||
|
||||
|
||||
class Testbase(TestBase):
    """Unit tests for ``MSEventKey`` and ``MSAttentionMetadataSplitConfig``."""

    def _check_attrs(self, obj, expected):
        """Assert, in insertion order, that ``obj.<name> == value`` for each pair."""
        for attr_name, wanted in expected.items():
            self.assertEqual(getattr(obj, attr_name), wanted)

    def test_ms_event_key(self):
        """Each listed ``MSEventKey`` member exposes the expected ``value``."""
        expected_values = (
            ("ATTN_COM_FINISH", 0),
            ("ATTN_AR_FINISH", 1),
            ("FFN_COM_FINISH", 2),
            ("FFN_AR_FINISH", 3),
            ("MOE_BEFORE_COMM", 4),
            ("MOE_AFTER_COMM", 5),
            ("MOE_SE_COMM_FINISH", 6),
            ("MOE_SE_COMP_FINISH", 7),
            ("MOE_GATE_FINISH", 8),
        )
        for member_name, wanted in expected_values:
            member = getattr(MSEventKey, member_name)
            self.assertEqual(member.value, wanted)

    def test_ms_attention_metadata_split_config_default(self):
        """A config built with no arguments reports 2 / 256 / 64."""
        split_cfg = MSAttentionMetadataSplitConfig()
        self._check_attrs(
            split_cfg,
            {
                "num_micro_batches": 2,
                "min_total_tokens_to_split": 256,
                "min_prefill_tokens_to_split": 64,
            },
        )

    def test_ms_attention_metadata_split_config_custom(self):
        """Keyword arguments passed to the config are readable back unchanged."""
        overrides = {
            "num_micro_batches": 4,
            "min_total_tokens_to_split": 512,
            "min_prefill_tokens_to_split": 128,
        }
        split_cfg = MSAttentionMetadataSplitConfig(**overrides)
        self._check_attrs(split_cfg, overrides)
|
||||
246
tests/ut/multistream/test_metadata.py
Normal file
246
tests/ut/multistream/test_metadata.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
||||
MultiStreamMetadata,
|
||||
MultiStreamStepMetadata,
|
||||
split_micro_batches_tensors)
|
||||
|
||||
|
||||
class TestMetaData(TestBase):
    """Unit tests for the helpers and classes in ``multistream.metadata``."""

    @staticmethod
    def _new_metadata(stream, event_keys, ms_config):
        """Build a ``MultiStreamMetadata`` for layers 1..3 using one shared stream.

        The same ``stream`` object is passed as both the calculate and the
        communicate stream, which is what every test in this class does.
        """
        return MultiStreamMetadata(calculate_stream=stream,
                                   communicate_stream=stream,
                                   start_layer=1,
                                   end_layer=3,
                                   event_keys=event_keys,
                                   multistream_config=ms_config)

    def setUp(self):
        """Create random sample tensors and a mock-backed metadata instance."""
        self.test_tensors_list = [torch.randn(100, 1024) for _ in range(3)]
        self.test_tensors = torch.randn(100, 1024)
        self.test_tensors_dict = {
            name: torch.randn(100, 1024)
            for name in ('query', 'key', 'value')
        }
        self.split_index = 50

        stream = MagicMock(spec=torch.npu.Stream)
        keys = [MagicMock(spec=MSEventKey)]
        ms_config = MagicMock(spec=MultiStreamConfig)
        self.metadata = self._new_metadata(stream, keys, ms_config)

    def test_split_micro_batches_tensors(self):
        """Splitting a list, a single tensor and a dict yields 50 + 50 rows."""
        total_rows = 100
        cut = self.split_index
        keys = ['query', 'key', 'value']

        list_halves = split_micro_batches_tensors(self.test_tensors_list, cut)
        tensor_halves = split_micro_batches_tensors(self.test_tensors, cut)
        dict_halves = split_micro_batches_tensors(self.test_tensors_dict, cut,
                                                  keys)

        # List input: each entry is indexed as a (first, second) pair.
        for idx in range(3):
            first_len = len(list_halves[idx][0])
            self.assertEqual(first_len, cut)
            second_len = len(list_halves[idx][1])
            self.assertEqual(first_len + second_len, total_rows)

        # Single-tensor input: the result itself is the (first, second) pair.
        self.assertEqual(len(tensor_halves[0]), cut)
        self.assertEqual(
            len(tensor_halves[0]) + len(tensor_halves[1]), total_rows)

        # Dict input: the result is a pair of dicts indexed by the given keys.
        for name in keys:
            first_len = len(dict_halves[0][name])
            self.assertEqual(first_len, cut)
            second_len = len(dict_halves[1][name])
            self.assertEqual(first_len + second_len, total_rows)

    def test_default_init_multistream_step_metadata(self):
        """With no arguments, the stream and both events are ``None``."""
        step_meta = MultiStreamStepMetadata()
        for attr_name in ('comm_stream', 'before_comm_event',
                          'after_comm_event'):
            self.assertIsNone(getattr(step_meta, attr_name))

    def test_custom_init_multistream_step_metadata(self):
        """Positional arguments land in stream / before-event / after-event."""
        stream = MagicMock(spec=torch.npu.Stream)
        before_event = MagicMock(spec=torch.npu.Event)
        after_event = MagicMock(spec=torch.npu.Event)

        step_meta = MultiStreamStepMetadata(stream, before_event, after_event)

        self.assertEqual(step_meta.comm_stream, stream)
        self.assertEqual(step_meta.before_comm_event, before_event)
        self.assertEqual(step_meta.after_comm_event, after_event)

    def test_default_init_multistream_config(self):
        """A config built with no arguments reports 256 / 64 / 2 / 0.1."""
        cfg = MultiStreamConfig()
        expected = (
            ('min_total_tokens_to_split', 256),
            ('min_prefill_tokens_to_split', 64),
            ('num_micro_batches', 2),
            ('imbalance_ratio', 0.1),
        )
        for attr_name, wanted in expected:
            self.assertEqual(getattr(cfg, attr_name), wanted)

    def test_custom_init_multistream_config(self):
        """Four positional arguments are readable back from the config."""
        cfg = MultiStreamConfig(512, 128, 1, 0.2)
        expected = (
            ('min_total_tokens_to_split', 512),
            ('min_prefill_tokens_to_split', 128),
            ('num_micro_batches', 1),
            ('imbalance_ratio', 0.2),
        )
        for attr_name, wanted in expected:
            self.assertEqual(getattr(cfg, attr_name), wanted)

    def test_init_multistream_metadata(self):
        """Constructor arguments are stored on the metadata object."""
        stream = MagicMock(spec=torch.npu.Stream)
        keys = [MagicMock()]
        ms_config = MagicMock(spec=MultiStreamConfig)

        built = self._new_metadata(stream, keys, ms_config)

        self.assertEqual(built.calculate_stream, stream)
        self.assertEqual(built.communicate_stream, stream)
        self.assertEqual(built.start_layer, 1)
        self.assertEqual(built.end_layer, 3)
        self.assertEqual(built.ms_config, ms_config)
        self.assertTrue(built.causal_lm)

    def test_build_events(self):
        """``ms_events`` maps layer -> micro batch -> event key -> event."""
        stream = MagicMock(spec=torch.npu.Stream)
        fake_event = MagicMock(spec=torch.npu.Event)
        # Every torch.npu.Event() call made during construction returns
        # the same mock object.
        with patch('torch.npu.Event', return_value=fake_event):
            keys = [MagicMock(spec=MSEventKey)]
            ms_config = MultiStreamConfig(num_micro_batches=2,
                                          min_total_tokens_to_split=256,
                                          min_prefill_tokens_to_split=64)

            built = self._new_metadata(stream, keys, ms_config)

            # Layers 0..2, two micro batches each, one event per key.
            expected_events = {
                layer: {
                    micro_batch: {
                        keys[0]: fake_event
                    }
                    for micro_batch in range(2)
                }
                for layer in range(3)
            }
            self.assertEqual(built.ms_events, expected_events)

    def test_build_ms_split_config(self):
        """``ms_split_config`` mirrors the three split fields of the config."""
        stream = MagicMock(spec=torch.npu.Stream)
        keys = [MagicMock(spec=MSEventKey)]
        ms_config = MagicMock(spec=MultiStreamConfig)
        ms_config.num_micro_batches = 2
        ms_config.min_total_tokens_to_split = 256
        ms_config.min_prefill_tokens_to_split = 64

        built = self._new_metadata(stream, keys, ms_config)

        self.assertIsNotNone(built.ms_split_config)
        for attr_name in ('num_micro_batches', 'min_total_tokens_to_split',
                          'min_prefill_tokens_to_split'):
            self.assertEqual(getattr(built.ms_split_config, attr_name),
                             getattr(ms_config, attr_name))

    def test_try_wait_event(self):
        """``try_wait_event`` calls ``wait`` on the stored event exactly once."""
        stream = MagicMock(spec=torch.npu.Stream)
        fake_event = MagicMock(spec=torch.npu.Event)
        keys = [MagicMock(spec=MSEventKey)]
        ms_config = MagicMock(spec=MultiStreamConfig)
        with patch('torch.npu.Event', return_value=fake_event):
            built = self._new_metadata(stream, keys, ms_config)

            built.try_wait_event(layer_index=1,
                                 micro_batch_index=0,
                                 event_key=keys[0])
            fake_event.wait.assert_called_once()

    def test_try_record_event(self):
        """``try_record_event`` calls ``record`` on the stored event once."""
        stream = MagicMock(spec=torch.npu.Stream)
        fake_event = MagicMock(spec=torch.npu.Event)
        keys = [MagicMock(spec=MSEventKey)]
        ms_config = MagicMock(spec=MultiStreamConfig)
        with patch('torch.npu.Event', return_value=fake_event):
            built = self._new_metadata(stream, keys, ms_config)

            built.try_record_event(layer_index=1,
                                   micro_batch_index=0,
                                   event_key=keys[0])
            fake_event.record.assert_called_once()

    def test_merge_batches_none_input(self):
        """Merging ``None`` returns ``None``."""
        self.assertIsNone(self.metadata.merge_micro_batches(None))

    def test_merge_batches_single_tensor_input(self):
        """A one-element list comes back with that tensor's values intact."""
        merged = self.metadata.merge_micro_batches([torch.tensor([1, 2, 3])])
        self.assertEqual(len(merged), 1)
        self.assertTrue(torch.equal(merged[0], torch.tensor([1, 2, 3])))

    def test_merge_batches_list_of_tensors_input(self):
        """A flat list of tensors compares equal to the input list."""
        flat_inputs = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        merged = self.metadata.merge_micro_batches(flat_inputs)
        self.assertEqual(len(merged), 2)
        self.assertEqual(merged, flat_inputs)

    def test_merge_batches_nested_list_input(self):
        """Each inner list of tensors is concatenated into a single tensor."""
        nested_inputs = [
            [torch.tensor([1, 2]), torch.tensor([3, 4])],
            [torch.tensor([5, 6]), torch.tensor([7, 8])],
        ]
        merged = self.metadata.merge_micro_batches(nested_inputs)
        self.assertEqual(len(merged), 2)
        self.assertTrue(torch.equal(merged[0], torch.tensor([1, 2, 3, 4])))
        self.assertTrue(torch.equal(merged[1], torch.tensor([5, 6, 7, 8])))
|
||||
147
tests/ut/multistream/test_ms_split.py
Normal file
147
tests/ut/multistream/test_ms_split.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
|
||||
model_input_split_v1_mla_attn,
|
||||
split_attn_int_type,
|
||||
split_attn_tensor_type)
|
||||
|
||||
|
||||
class TestMsSplit(TestBase):
    """Unit tests for the split helpers in ``multistream.ms_split``."""

    @staticmethod
    def _prefill_split(query_lens):
        """Run ``compute_split_seq_index`` in PrefillNoCache state, 10 tokens."""
        return compute_split_seq_index(
            query_lens=query_lens,
            attn_state=AscendAttentionState.PrefillNoCache,
            num_tokens=10)

    def _check_tensor_split(self, values, index, left, right):
        """Split ``torch.tensor(values)`` at ``index`` and compare both pieces.

        ``left`` and ``right`` are plain lists; they are wrapped with
        ``torch.tensor`` before being compared via ``torch.equal``.
        """
        pieces = split_attn_tensor_type(torch.tensor(values), index)
        self.assertEqual(len(pieces), 2)
        self.assertTrue(torch.equal(pieces[0], torch.tensor(left)))
        self.assertTrue(torch.equal(pieces[1], torch.tensor(right)))

    def _check_int_split(self, var, index, expected):
        """Assert that ``split_attn_int_type(var, index)`` equals ``expected``."""
        self.assertEqual(split_attn_int_type(var, index), expected)

    # --- compute_split_seq_index -------------------------------------------

    def test_decode_only(self):
        """DecodeOnly with 10 tokens and no query lens yields ``[5, 5]``."""
        outcome = compute_split_seq_index(
            query_lens=None,
            attn_state=AscendAttentionState.DecodeOnly,
            num_tokens=10)
        self.assertEqual(outcome, [5, 5])

    def test_perfect_balance(self):
        """Query lens ``[2, 3, 5]`` yield ``[5, 2]``."""
        self.assertEqual(self._prefill_split([2, 3, 5]), [5, 2])

    def test_imbalance(self):
        """Query lens ``[1, 2, 3, 4]`` yield ``[0, 0]``."""
        self.assertEqual(self._prefill_split([1, 2, 3, 4]), [0, 0])

    def test_query_lens_none(self):
        """Prefill state with ``query_lens=None`` raises ``AssertionError``."""
        with self.assertRaises(AssertionError):
            self._prefill_split(None)

    def test_empty_query_lens(self):
        """An empty query-lens list yields ``[0, 0]``."""
        self.assertEqual(self._prefill_split([]), [0, 0])

    def test_single_query_len(self):
        """A single query of length 10 yields ``[0, 0]``."""
        self.assertEqual(self._prefill_split([10]), [0, 0])

    # --- split_attn_tensor_type --------------------------------------------

    def test_split_attn_tensor_type_middle(self):
        """Index 3 splits five elements into three and two."""
        self._check_tensor_split([1, 2, 3, 4, 5], 3, [1, 2, 3], [4, 5])

    def test_split_attn_tensor_type_start(self):
        """Index 0 puts every element in the second piece."""
        self._check_tensor_split([1, 2, 3, 4, 5], 0, [], [1, 2, 3, 4, 5])

    def test_split_attn_tensor_type_end(self):
        """An index equal to the length puts every element in the first piece."""
        self._check_tensor_split([1, 2, 3, 4, 5], 5, [1, 2, 3, 4, 5], [])

    def test_split_attn_tensor_type_empty_tensor(self):
        """Splitting an empty tensor at 0 yields two empty pieces."""
        self._check_tensor_split([], 0, [], [])

    # --- split_attn_int_type -----------------------------------------------

    def test_split_attn_int_type_index_greater_than_var(self):
        """var=5, index=10 -> ``[5, 0]``."""
        self._check_int_split(5, 10, [5, 0])

    def test_split_attn_int_type_index_equal_to_var(self):
        """var=5, index=5 -> ``[5, 0]``."""
        self._check_int_split(5, 5, [5, 0])

    def test_split_attn_int_type_index_less_than_var(self):
        """var=10, index=5 -> ``[5, 5]``."""
        self._check_int_split(10, 5, [5, 5])

    def test_split_attn_int_type_index_zero(self):
        """var=10, index=0 -> ``[0, 10]``."""
        self._check_int_split(10, 0, [0, 10])

    def test_split_attn_int_type_var_zero(self):
        """var=0, index=5 -> ``[0, 0]``."""
        self._check_int_split(0, 5, [0, 0])

    def test_split_attn_int_type_both_zero(self):
        """var=0, index=0 -> ``[0, 0]``."""
        self._check_int_split(0, 0, [0, 0])

    # --- model_input_split_v1_mla_attn -------------------------------------

    def test_split_v1_mla_attn_input_none(self):
        """``None`` attention metadata with one micro batch yields ``[None]``."""
        prefill_metadata_cls = MagicMock()
        single_batch_cfg = MSAttentionMetadataSplitConfig(num_micro_batches=1)
        outcome = model_input_split_v1_mla_attn(None, prefill_metadata_cls,
                                                single_batch_cfg)
        self.assertEqual(outcome, [None])
|
||||
Reference in New Issue
Block a user