v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

View File

@@ -0,0 +1,32 @@
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
MSEventKey)
class Testbase(TestBase):
    """Tests for the multistream base enum and split-config dataclass."""

    def test_ms_event_key(self):
        """Each MSEventKey member carries its fixed integer value."""
        expected_values = {
            MSEventKey.ATTN_COM_FINISH: 0,
            MSEventKey.ATTN_AR_FINISH: 1,
            MSEventKey.FFN_COM_FINISH: 2,
            MSEventKey.FFN_AR_FINISH: 3,
            MSEventKey.MOE_BEFORE_COMM: 4,
            MSEventKey.MOE_AFTER_COMM: 5,
            MSEventKey.MOE_SE_COMM_FINISH: 6,
            MSEventKey.MOE_SE_COMP_FINISH: 7,
            MSEventKey.MOE_GATE_FINISH: 8,
        }
        for key, value in expected_values.items():
            self.assertEqual(key.value, value)

    def test_ms_attention_metadata_split_config_default(self):
        """Default construction yields the documented default thresholds."""
        config = MSAttentionMetadataSplitConfig()
        self.assertEqual(config.num_micro_batches, 2)
        self.assertEqual(config.min_total_tokens_to_split, 256)
        self.assertEqual(config.min_prefill_tokens_to_split, 64)

    def test_ms_attention_metadata_split_config_custom(self):
        """Keyword arguments override every default field."""
        custom = MSAttentionMetadataSplitConfig(
            num_micro_batches=4,
            min_total_tokens_to_split=512,
            min_prefill_tokens_to_split=128)
        self.assertEqual(custom.num_micro_batches, 4)
        self.assertEqual(custom.min_total_tokens_to_split, 512)
        self.assertEqual(custom.min_prefill_tokens_to_split, 128)

View File

@@ -0,0 +1,47 @@
import pytest
from pytest_mock import MockFixture
from tests.ut.base import PytestBase
from vllm_ascend.multistream.decorator import set_multistream_support
class Context:
    """Minimal stand-in for a forward-context object.

    Only exposes the single attribute the decorator under test reads.
    """

    def __init__(self, attn_metadata=None):
        # Metadata the decorator may replace with a micro-batch slice.
        self.attn_metadata = attn_metadata
class TestDecorator(PytestBase):
    """Tests for the set_multistream_support decorator."""

    @pytest.mark.parametrize(
        'layer_context, microbatch_context, expected_metadata', [
            ((-1, None, None), -1, {
                "original": True
            }),
            ((-1, None, None), 0, {
                "original": True
            }),
            ((0, None, None), -1, {
                "original": True
            }),
            ((0, None, [{
                "new": True
            }]), 0, {
                "new": True
            }),
        ])
    def test_decorator(self, mocker: MockFixture, layer_context,
                       microbatch_context, expected_metadata):
        """Metadata is swapped only when both layer and micro-batch
        contexts are active; otherwise the original passes through."""

        def make_context():
            return Context(attn_metadata={"original": True})

        mocker.patch(
            'vllm_ascend.multistream.decorator.get_multistream_layer_context',
            return_value=layer_context)
        mocker.patch(
            'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
            return_value=microbatch_context)
        wrapped = set_multistream_support()(make_context)
        result = wrapped()
        assert result.attn_metadata == expected_metadata

View File

@@ -0,0 +1,198 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import PytestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
MultiStreamPreTransformerLayer)
from vllm_ascend.multistream.metadata import MultiStreamMetadata
# === fixture: mock tensor input ===
@pytest.fixture
def input_tensors():
    """Provide two random (2, 128) float tensors as layer inputs."""
    return [torch.randn(2, 128) for _ in range(2)]
# === mock get_forward_context ===
class DummyContext:
    """Bare-bones forward-context stub exposing only attn_metadata."""

    def __init__(self, attn_metadata):
        # Whatever the mocked get_forward_context should hand to the layer.
        self.attn_metadata = attn_metadata
class TestMultiStreamPreTransformerLayer(PytestBase):
    """Tests for MultiStreamPreTransformerLayer.forward context handling.

    NOTE: @patch decorators apply bottom-up, so the innermost patch
    (set_multistream_layer_context) binds to the first mock argument.
    """

    # === test when multistream_metadata is None ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx,
                                             input_tensors):
        """Without metadata the inputs pass through and context is reset."""
        mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta")
        layer = MultiStreamPreTransformerLayer(multistream_metadata=None)
        attn_out, input_out = layer.forward(input_tensors)
        assert attn_out == "dummy_meta"
        assert input_out == input_tensors
        # Layer context reset to the inactive sentinel (-1, None, None).
        mock_set_ctx.assert_called_once_with(-1, None, None)
    # === test when attn_metadata is None ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx,
                                      input_tensors):
        """A missing attn_metadata disables splitting even with metadata."""
        mock_get_ctx.return_value = DummyContext(attn_metadata=None)
        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        layer = MultiStreamPreTransformerLayer(
            multistream_metadata=dummy_metadata)
        attn_out, input_out = layer.forward(input_tensors)
        assert attn_out is None
        assert input_out == input_tensors
        mock_set_ctx.assert_called_once_with(-1, None, None)
    # === test when do_ms=False (no split needed) ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
        """When split_micro_batch declines to split, inputs pass through."""
        dummy_attn = "original_attn"
        mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        # (do_ms, attn_metadata, inputs, extra) with do_ms=False.
        dummy_metadata.split_micro_batch.return_value = (False, "same_attn",
                                                         input_tensors, None)
        layer = MultiStreamPreTransformerLayer(
            multistream_metadata=dummy_metadata)
        attn_out, input_out = layer.forward(input_tensors)
        assert attn_out == "same_attn"
        assert input_out == input_tensors
        mock_set_ctx.assert_called_once_with(-1, None, None)
    # === test when do_ms=True (split occurred) ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
        """When a split occurs, the layer context records start_layer,
        the metadata object, and the per-micro-batch attn metadata."""
        dummy_attn = "original_attn"
        mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
        split_inputs = [[t[:1], t[1:]] for t in input_tensors]
        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        dummy_metadata.start_layer = 2
        dummy_metadata.split_micro_batch.return_value = (True,
                                                         ["attn1", "attn2"],
                                                         split_inputs, None)
        layer = MultiStreamPreTransformerLayer(
            multistream_metadata=dummy_metadata)
        attn_out, input_out = layer.forward(input_tensors)
        assert attn_out == ["attn1", "attn2"]
        assert input_out == split_inputs
        mock_set_ctx.assert_called_once_with(2, dummy_metadata,
                                             ["attn1", "attn2"])
class TestMultiStreamPostTransformerLayer(PytestBase):
    """Tests for MultiStreamPostTransformerLayer.forward merge behavior."""

    def test_post_forward_metadata_none(self, input_tensors):
        """Inputs pass through when metadata (or its ms_config) is absent."""
        layer = MultiStreamPostTransformerLayer(multistream_metadata=None)
        output = layer.forward(input_tensors)
        assert output == input_tensors
        # Same pass-through when metadata exists but ms_config is None.
        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        dummy_metadata.ms_config = None
        layer = MultiStreamPostTransformerLayer(
            multistream_metadata=dummy_metadata)
        output = layer.forward(input_tensors)
        assert output == input_tensors
    @patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
    @patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
    def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx,
                                      input_tensors):
        """Default wait targets the last layer / last micro-batch event,
        then merges the micro-batches and resets the layer context."""
        # Build a real instance only to serve as a spec so attribute
        # chains like ms_config.num_micro_batches are mockable.
        A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
            calculate_stream=MagicMock(),
            communicate_stream=MagicMock(),
            start_layer=0,
            end_layer=1,
            event_keys=[],
            multistream_config=None,
        )
        dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
        dummy_metadata.ms_config.num_micro_batches = 4
        dummy_metadata.end_layer = 10
        mock_get_ctx.return_value = (
            5,  # layer_index
            dummy_metadata,  # ms_metadata
            "dummy_attn_metadata"  # ms_attn_metadata
        )
        dummy_metadata.merge_micro_batches.return_value = "merged_result"
        layer = MultiStreamPostTransformerLayer(
            multistream_metadata=dummy_metadata)
        output = layer.forward(input_tensors)
        # check wait_event
        dummy_metadata.try_wait_event.assert_called_once_with(
            9,  # end_layer - 1
            3,  # num_micro_batches - 1
            MSEventKey.FFN_AR_FINISH)
        mock_reset_ctx.assert_called_once()
        assert output == "merged_result"
    @patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
    @patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
    def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx,
                                                 mock_get_ctx, input_tensors):
        """An explicit wait_layer_index overrides the end_layer default."""
        A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
            calculate_stream=MagicMock(),
            communicate_stream=MagicMock(),
            start_layer=0,
            end_layer=1,
            event_keys=[],
            multistream_config=None,
        )
        dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
        dummy_metadata.ms_config.num_micro_batches = 4
        dummy_metadata.end_layer = 10
        mock_get_ctx.return_value = (
            3,  # layer_index
            dummy_metadata,
            "dummy_attn_metadata")
        dummy_metadata.merge_micro_batches.return_value = "merged_result"
        layer = MultiStreamPostTransformerLayer(
            multistream_metadata=dummy_metadata)
        output = layer.forward(input_tensors, wait_layer_index=7)
        # Waits on the caller-specified layer, not end_layer - 1.
        dummy_metadata.try_wait_event.assert_called_once_with(
            7, 3, MSEventKey.FFN_AR_FINISH)
        mock_reset_ctx.assert_called_once()
        assert output == "merged_result"

View File

@@ -0,0 +1,246 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamMetadata,
MultiStreamStepMetadata,
split_micro_batches_tensors)
class TestMetaData(TestBase):
    """Tests for multistream metadata helpers and MultiStreamMetadata."""

    def setUp(self):
        """Build shared fixtures: tensors to split and a mocked metadata."""
        self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
        self.test_tensors = torch.randn(100, 1024)
        self.test_tensors_dict = {
            'query': torch.randn(100, 1024),
            'key': torch.randn(100, 1024),
            'value': torch.randn(100, 1024)
        }
        self.split_index = 50
        # npu streams/events are mocked; no device is needed for these tests.
        mock_stream = MagicMock(spec=torch.npu.Stream)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        self.metadata = MultiStreamMetadata(
            calculate_stream=mock_stream,
            communicate_stream=mock_stream,
            start_layer=1,
            end_layer=3,
            event_keys=event_keys,
            multistream_config=multistream_config)
    def test_split_micro_batches_tensors(self):
        """Splitting works on a list of tensors, a bare tensor, and a dict,
        producing two halves that together cover all 100 rows."""
        test_tensors_list_res = split_micro_batches_tensors(
            self.test_tensors_list, self.split_index)
        test_tensors_res = split_micro_batches_tensors(self.test_tensors,
                                                       self.split_index)
        keys = ['query', 'key', 'value']
        test_tensors_dict_res = split_micro_batches_tensors(
            self.test_tensors_dict, self.split_index, keys)
        for i in range(3):
            self.assertEqual(len(test_tensors_list_res[i][0]),
                             self.split_index)
            self.assertEqual(
                len(test_tensors_list_res[i][0]) +
                len(test_tensors_list_res[i][1]), 100)
        self.assertEqual(len(test_tensors_res[0]), self.split_index)
        self.assertEqual(
            len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
        for key in keys:
            self.assertEqual(len(test_tensors_dict_res[0][key]),
                             self.split_index)
            self.assertEqual(
                len(test_tensors_dict_res[0][key]) +
                len(test_tensors_dict_res[1][key]), 100)
    def test_default_init_multistream_step_metadata(self):
        """Default step metadata has no stream and no events."""
        metadata = MultiStreamStepMetadata()
        self.assertIsNone(metadata.comm_stream)
        self.assertIsNone(metadata.before_comm_event)
        self.assertIsNone(metadata.after_comm_event)
    def test_custom_init_multistream_step_metadata(self):
        """Positional args map to stream, before-event, after-event."""
        mockStream = MagicMock(spec=torch.npu.Stream)
        mockEvent1 = MagicMock(spec=torch.npu.Event)
        mockEvent2 = MagicMock(spec=torch.npu.Event)
        metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
        self.assertEqual(metadata.comm_stream, mockStream)
        self.assertEqual(metadata.before_comm_event, mockEvent1)
        self.assertEqual(metadata.after_comm_event, mockEvent2)
    def test_default_init_multistream_config(self):
        """Default config carries the documented threshold values."""
        config = MultiStreamConfig()
        self.assertEqual(config.min_total_tokens_to_split, 256)
        self.assertEqual(config.min_prefill_tokens_to_split, 64)
        self.assertEqual(config.num_micro_batches, 2)
        self.assertEqual(config.imbalance_ratio, 0.1)
    def test_custom_init_multistream_config(self):
        """Positional args map to the four config fields in order."""
        config = MultiStreamConfig(512, 128, 1, 0.2)
        self.assertEqual(config.min_total_tokens_to_split, 512)
        self.assertEqual(config.min_prefill_tokens_to_split, 128)
        self.assertEqual(config.num_micro_batches, 1)
        self.assertEqual(config.imbalance_ratio, 0.2)
    def test_init_multistream_metadata(self):
        """Constructor stores streams, layer range, and config; causal_lm
        defaults to True."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        event_keys = [MagicMock()]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        metadata = MultiStreamMetadata(calculate_stream=mock_stream,
                                       communicate_stream=mock_stream,
                                       start_layer=1,
                                       end_layer=3,
                                       event_keys=event_keys,
                                       multistream_config=multistream_config)
        self.assertEqual(metadata.calculate_stream, mock_stream)
        self.assertEqual(metadata.communicate_stream, mock_stream)
        self.assertEqual(metadata.start_layer, 1)
        self.assertEqual(metadata.end_layer, 3)
        self.assertEqual(metadata.ms_config, multistream_config)
        self.assertTrue(metadata.causal_lm)
    def test_build_events(self):
        """Events are built per layer (start..end) x micro-batch x key."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event = MagicMock(spec=torch.npu.Event)
        # Patch torch.npu.Event so every created event is the same mock.
        with patch('torch.npu.Event', return_value=mock_event):
            event_keys = [MagicMock(spec=MSEventKey)]
            multistream_config = MultiStreamConfig(
                num_micro_batches=2,
                min_total_tokens_to_split=256,
                min_prefill_tokens_to_split=64)
            metadata = MultiStreamMetadata(
                calculate_stream=mock_stream,
                communicate_stream=mock_stream,
                start_layer=1,
                end_layer=3,
                event_keys=event_keys,
                multistream_config=multistream_config)
            # Layers 0..2 each get two micro-batches mapping key -> event.
            expected_events = {
                0: {
                    0: {
                        event_keys[0]: mock_event
                    },
                    1: {
                        event_keys[0]: mock_event
                    }
                },
                1: {
                    0: {
                        event_keys[0]: mock_event
                    },
                    1: {
                        event_keys[0]: mock_event
                    }
                },
                2: {
                    0: {
                        event_keys[0]: mock_event
                    },
                    1: {
                        event_keys[0]: mock_event
                    }
                }
            }
            self.assertEqual(metadata.ms_events, expected_events)
    def test_build_ms_split_config(self):
        """The split config mirrors the thresholds from ms_config."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        multistream_config.num_micro_batches = 2
        multistream_config.min_total_tokens_to_split = 256
        multistream_config.min_prefill_tokens_to_split = 64
        metadata = MultiStreamMetadata(calculate_stream=mock_stream,
                                       communicate_stream=mock_stream,
                                       start_layer=1,
                                       end_layer=3,
                                       event_keys=event_keys,
                                       multistream_config=multistream_config)
        self.assertIsNotNone(metadata.ms_split_config)
        self.assertEqual(metadata.ms_split_config.num_micro_batches,
                         multistream_config.num_micro_batches)
        self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
                         multistream_config.min_total_tokens_to_split)
        self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
                         multistream_config.min_prefill_tokens_to_split)
    def test_try_wait_event(self):
        """try_wait_event delegates to the underlying event's wait()."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event = MagicMock(spec=torch.npu.Event)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        with patch('torch.npu.Event', return_value=mock_event):
            metadata = MultiStreamMetadata(
                calculate_stream=mock_stream,
                communicate_stream=mock_stream,
                start_layer=1,
                end_layer=3,
                event_keys=event_keys,
                multistream_config=multistream_config)
            metadata.try_wait_event(layer_index=1,
                                    micro_batch_index=0,
                                    event_key=event_keys[0])
            mock_event.wait.assert_called_once()
    def test_try_record_event(self):
        """try_record_event delegates to the underlying event's record()."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event = MagicMock(spec=torch.npu.Event)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        with patch('torch.npu.Event', return_value=mock_event):
            metadata = MultiStreamMetadata(
                calculate_stream=mock_stream,
                communicate_stream=mock_stream,
                start_layer=1,
                end_layer=3,
                event_keys=event_keys,
                multistream_config=multistream_config)
            metadata.try_record_event(layer_index=1,
                                      micro_batch_index=0,
                                      event_key=event_keys[0])
            mock_event.record.assert_called_once()
    def test_merge_batches_none_input(self):
        """Merging None yields None."""
        input_tensors = None
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertIsNone(result)
    def test_merge_batches_single_tensor_input(self):
        """A flat list of tensors is returned unchanged (single element)."""
        input_tensors = [torch.tensor([1, 2, 3])]
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertEqual(len(result), 1)
        self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
    def test_merge_batches_list_of_tensors_input(self):
        """A flat list of tensors is returned unchanged (two elements)."""
        input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertEqual(len(result), 2)
        self.assertEqual(result, input_tensors)
    def test_merge_batches_nested_list_input(self):
        """Nested lists (micro-batches) are concatenated per outer entry."""
        input_tensors = [[torch.tensor([1, 2]),
                          torch.tensor([3, 4])],
                         [torch.tensor([5, 6]),
                          torch.tensor([7, 8])]]
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertEqual(len(result), 2)
        self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
        self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))

View File

@@ -0,0 +1,147 @@
from unittest.mock import MagicMock
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
model_input_split_v1_mla_attn,
split_attn_int_type,
split_attn_tensor_type)
class TestMsSplit(TestBase):
def test_decode_only(self):
result = compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.DecodeOnly,
num_tokens=10)
self.assertEqual(result, [5, 5])
def test_perfect_balance(self):
query_lens = [2, 3, 5]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [5, 2])
def test_imbalance(self):
query_lens = [1, 2, 3, 4]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_query_lens_none(self):
with self.assertRaises(AssertionError):
compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
def test_empty_query_lens(self):
query_lens: list[int] = []
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_single_query_len(self):
query_lens = [10]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_split_attn_tensor_type_middle(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 3
expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_start(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 0
expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_end(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 5
expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_empty_tensor(self):
input_tensor = torch.tensor([])
index = 0
expected_result = [torch.tensor([]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_int_type_index_greater_than_var(self):
var = 5
index = 10
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_equal_to_var(self):
var = 5
index = 5
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_less_than_var(self):
var = 10
index = 5
expected_result = [5, 5]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_zero(self):
var = 10
index = 0
expected_result = [0, 10]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_var_zero(self):
var = 0
index = 5
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_both_zero(self):
var = 0
index = 0
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_v1_mla_attn_input_none(self):
attn_metadata = None
ascendMLAPrefillMetadata = MagicMock()
ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
result = model_input_split_v1_mla_attn(attn_metadata,
ascendMLAPrefillMetadata,
ms_split_config)
self.assertEqual(result, [None])