v0.10.1rc1

This commit is contained in:
2025-09-09 09:40:35 +08:00
parent d6f6ef41fe
commit 9149384e03
432 changed files with 84698 additions and 1 deletions

0
tests/ut/__init__.py Normal file
View File

View File

@@ -0,0 +1,133 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
class TestAttentionMaskBuilder(TestBase):
def test_init_attention_mask_builder(self):
# generate attention_mask_builder with float16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.float16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# generate attention_mask_builder with bfloat16
attention_mask_builder = AttentionMaskBuilder(max_seq_len=2048,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.dtype,
torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(1, dtype=torch.bfloat16))
def test_get_mask_scale_factor(self):
# supported data types
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.float16), 1)
self.assertEqual(
AttentionMaskBuilder.get_mask_scale_factor(torch.bfloat16), -10000)
# mask_scale_factor now only supports data types: torch.float16 and torch.bfloat16
# Otherwise raise ValueError
with self.assertRaises(ValueError):
AttentionMaskBuilder.get_mask_scale_factor(torch.int8)
def test_get_attn_mask(self):
# if the len is less than max_seq_len, the attn_mask_cache will not be updated
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=512, dtype=torch.float16, device=torch.device("cpu"))
self.assertEqual(attn_mask.shape, (512, 512))
self.assertEqual(attn_mask[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(1024, 1024))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
# if the len is greater than max_seq_len, the attn_mask_cache will be updated
attn_mask = attention_mask_builder.get_attn_mask(
max_seq_len=2048, dtype=torch.float16, device=torch.device("cpu"))
self.assertEqual(attn_mask.shape, (2048, 2048))
self.assertEqual(attn_mask[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
self.assertEqual(attention_mask_builder._seq_len_cached, 2048)
self.assertEqual(attention_mask_builder.attn_mask_cache.shape,
(2048, 2048))
self.assertEqual(attention_mask_builder.attn_mask_cache[0][-1],
torch.tensor(float("-inf"), dtype=torch.float16))
def test_get_splitfuse_attn_mask(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=1024,
dtype=torch.float16)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (6, 100))
self.assertEqual(attention_mask_builder._seq_len_cached, 1024)
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 3000, 2000]),
position=torch.tensor([7, 8, 9, 2999, 1999]),
dtype=torch.float16,
device=torch.device("cpu"),
)
self.assertEqual(attn_mask.shape, (5, 3000))
self.assertEqual(attention_mask_builder._seq_len_cached, 3000)
# splitfuse_attn_mask now only supports data types: torch.float16 and torch.bfloat16
# otherwise raise ValueError
with self.assertRaises(ValueError):
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([10, 20, 100]),
position=torch.tensor([7, 8, 9, 18, 19, 99]),
dtype=torch.int8,
device=torch.device("cpu"),
)
def test_mask_value_cleanliness(self):
attention_mask_builder = AttentionMaskBuilder(max_seq_len=6,
dtype=torch.bfloat16)
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))
attn_mask = attention_mask_builder.get_splitfuse_attn_mask(
seq_lens=torch.tensor([6]),
position=torch.tensor([3, 4, 5]),
dtype=torch.bfloat16,
device=torch.device("cpu"),
)
self.assertEqual(
attn_mask[-2][-1],
torch.tensor(-10000, dtype=torch.bfloat16,
device=attn_mask.device))
self.assertEqual(attention_mask_builder.attn_mask_cache[-2][-1],
torch.tensor(1, dtype=torch.bfloat16))

View File

@@ -0,0 +1,578 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
AscendAttentionBackendImpl,
AscendAttentionMetadataBuilder,
AscendAttentionState,
AscendMetadata,
CommonAttentionState)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
class TestAscendAttentionBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendAttentionBackend.get_name(), "ASCEND")
def test_get_impl_cls(self):
self.assertEqual(AscendAttentionBackend.get_impl_cls(),
AscendAttentionBackendImpl)
def test_get_metadata_cls(self):
self.assertEqual(AscendAttentionBackend.get_metadata_cls(),
AscendMetadata)
def test_get_state_cls(self):
self.assertEqual(AscendAttentionBackend.get_state_cls(),
CommonAttentionState)
def test_get_builder_cls(self):
self.assertEqual(AscendAttentionBackend.get_builder_cls(),
AscendAttentionMetadataBuilder)
@patch('vllm_ascend.attention.attention_v1.is_310p')
def test_get_kv_cache_shape_310p(self, mock_is_310p):
mock_is_310p.return_value = True
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 30 * 40 // 16, 20, 16))
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_get_kv_cache_shape_not_310p(self, mock_is_310p):
result = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 20, 30, 40))
def test_get_bsh_kv_cache_shape(self):
result = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40)
self.assertEqual(result, (2, 10, 20, 30 * 40))
def test_swap_blocks(self):
src_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
dst_kv_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
src_to_dst = torch.tensor([[0, 1], [2, 3]])
AscendAttentionBackend.swap_blocks(src_kv_cache, dst_kv_cache,
src_to_dst)
self.assertTrue(torch.all(dst_kv_cache[0][1] == src_kv_cache[0][0]))
self.assertTrue(torch.all(dst_kv_cache[1][3] == src_kv_cache[1][2]))
def test_copy_blocks(self):
kv_caches = [torch.zeros((10, 20)), torch.zeros((10, 20))]
src_to_dists = torch.tensor([[0, 1], [2, 3]])
AscendAttentionBackend.copy_blocks(kv_caches, src_to_dists)
self.assertTrue(torch.all(kv_caches[0][1] == kv_caches[0][0]))
self.assertTrue(torch.all(kv_caches[1][3] == kv_caches[1][2]))
class TestAscendAttentionMetadataBuilder(TestBase):
def setUp(self):
self.mock_vllm_config = MagicMock()
self.mock_vllm_config.model_config.max_model_len = 640
self.mock_vllm_config.cache_config.block_size = 64
self.mock_device = 'cpu:0'
self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
self.mock_device)
def test_reorder_batch(self):
mock_input_batch = MagicMock()
mock_scheduler_output = MagicMock()
result = self.builder.reorder_batch(mock_input_batch,
mock_scheduler_output)
self.assertFalse(result)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@patch('vllm_ascend.utils.nd_to_nz_2d')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
mock_npu_format_cast,
mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 3, 7]),
query_start_loc_cpu=torch.tensor([0, 3, 7]),
seq_lens_cpu=torch.tensor([5, 6]),
num_reqs=2,
num_actual_tokens=10,
max_query_len=5,
decode_token_per_req=torch.tensor([1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((10, 10)),
spec_attn_mask=None,
attn_state=AscendAttentionState.PrefillNoCache)
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_2d.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('torch_npu.npu_format_cast')
@patch('vllm_ascend.utils.nd_to_nz_spec')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
@patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
def test_build_chunked_prefill(self, mock_ascend_attention_state,
mock_is_310p, mock_nd_to_nz_spec,
mock_npu_format_cast, mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_ascend_attention_state = MagicMock()
mock_ascend_attention_state.PrefillNoCache = 0
mock_nz_tensor = MagicMock()
mock_model = MagicMock()
mock_nd_to_nz_spec.return_value = mock_nz_tensor
mock_npu_format_cast.return_value = mock_nz_tensor
self.builder.build(common_attn_metadata, mock_model)
@patch('vllm_ascend.attention.attention_v1.AscendMetadata')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 2, 5, 9]),
query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
seq_lens_cpu=torch.tensor([4, 5, 6]),
num_reqs=3,
num_actual_tokens=15,
max_query_len=6,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([10, 10]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
mock_model = MagicMock()
self.builder.build(common_attn_metadata, mock_model)
class TestAscendAttentionBackendImpl(TestBase):
def setUp(self):
self.layer = MagicMock()
self.layer.layer_name = "test_layer"
self.layer._k_scale_float = 1.0
self.layer._v_scale_float = 1.0
self.attention_type = MagicMock()
self.attention_type.DECODER = "decoder"
self.attention_type.ENCODER = "encoder"
self.attn_metadata = MagicMock()
self.attn_metadata.return_value = "1"
self.layer_no_quant = MagicMock(
spec=['layer_name', '_k_scale_float', '_v_scale_float'])
self.layer_no_quant.layer_name = "test_layer"
self.layer_no_quant._k_scale_float = 1.0
self.layer_no_quant._v_scale_float = 1.0
self.impl = AscendAttentionBackendImpl(
num_heads=8,
head_size=64,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
self.impl_192 = AscendAttentionBackendImpl(
num_heads=8,
head_size=192,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
self.impl_error = AscendAttentionBackendImpl(
num_heads=8,
head_size=192,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None)
self.impl_swa = AscendAttentionBackendImpl(
num_heads=8,
head_size=64,
scale=1.0,
num_kv_heads=8,
alibi_slopes=None,
sliding_window=1024,
kv_cache_dtype="float16",
logits_soft_cap=None,
attn_type=self.attention_type.DECODER,
kv_sharing_target_layer_name=None)
@patch('torch.ops.vllm.unified_ascend_attention_with_output')
def test_forward_trace_flag_true(self, mock_unified_attention):
"""Test forward pass when trace_flag is True"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 0, 0, 8, 64)
metadata = self.attn_metadata
layer = self.layer
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=True)
mock_unified_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_with_quant_method(self, mock_paged_attention):
"""Test forward pass when layer has quant_method"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
kv_cache = [k_cache, v_cache]
ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)
metadata = MagicMock()
metadata.num_actual_tokens = torch.randn(10, 8 * 64)
metadata.block_tables = torch.randn(10, 8 * 64)
metadata.seq_lens = torch.randn(10, 8 * 64)
metadata.attn_mask = torch.randn(10, 8 * 64)
metadata.query_lens = torch.randn(10, 8 * 64)
layer = self.layer
layer.quant_method = MagicMock()
layer.quant_method.apply.return_value = ret_value
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
layer.quant_method.apply.assert_called_once()
assert output.shape == (10, 8 * 64)
def test_forward_no_attn_metadata(self):
"""Test forward pass when attn_metadata is None"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 0, 0, 8, 64)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
None,
trace_flag=False)
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache(self, mock_flash_attention,
mock_reshape_cache):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_reshape_cache.assert_called_once()
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention')
def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
mock_reshape_cache):
"""Test forward pass in PrefillNoCache state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillNoCache
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.seq_lens = torch.tensor([10])
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
# layer.quant_method.apply.return_value = metadata
print(self.layer_no_quant._v_scale_float)
output = self.impl_swa.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_reshape_cache.assert_called_once()
mock_flash_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_flash_attention_qlens')
def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
mock_npu_reshape_and_cache):
"""Test forward pass in PrefillCacheHit state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.PrefillCacheHit
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_flash_attention_qlens.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
def test_forward_decode_only(self, mock_paged_attention,
mock_npu_reshape_and_cache):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
mock_npu_reshape_and_cache):
"""Test forward pass in DecodeOnly state"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10] * 10)
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 100
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
64), 1)
output = self.impl_swa.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
print(output.shape)
mock_fused_infer_attention_score.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
@patch('torch_npu._npu_reshape_and_cache')
@patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
def test_forward_head_size_192(self, mock_vanilla_prefill,
mock_npu_reshape_and_cache, mock_is_310p):
"""Test forward pass when head_size is 192"""
self.impl.head_size = 192
query = torch.randn(10, 8 * 192)
key = torch.randn(10, 8 * 192)
value = torch.randn(10, 8 * 192)
kv_cache = torch.empty(2, 5, 128, 8, 192)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_vanilla_prefill.return_value = MagicMock()
output = self.impl_192.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_vanilla_prefill.assert_called_once()
assert output.shape == (10, 8 * 192)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
def test_forward_normal_v1_situation(self, mock_paged_attention,
mock_npu_reshape_and_cache):
"""Test forward pass in normal V1 situation"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu.npu_format_cast')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention_splitfuse')
@patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
mock_npu_reshape_and_cache,
mock_npu_format_cast):
"""Test forward pass on 310P device"""
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_npu_format_cast.return_value = metadata.attn_mask
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('torch_npu._npu_reshape_and_cache')
def test_forward_raise_error(self, mock_paged_attention):
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_mask = torch.randn(1, 1, 10, 10)
metadata.query_lens = torch.tensor([10])
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
with self.assertRaises(NotImplementedError):
self.impl_error.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)

View File

@@ -0,0 +1,631 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
AscendMLADecodeMetadata,
AscendMLAImpl, AscendMLAMetadata,
AscendMLAMetadataBuilder,
AscendMLAPrefillMetadata)
class TestAscendMLABackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendMLABackend.get_name(), "ASCEND_MLA")
def test_get_metadata_cls(self):
self.assertEqual(AscendMLABackend.get_metadata_cls(),
AscendMLAMetadata)
def test_get_builder_cls(self):
self.assertEqual(AscendMLABackend.get_builder_cls(),
AscendMLAMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendMLABackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendMLABackend.get_impl_cls()
self.assertEqual(result, AscendMLAImpl)
class TestAscendMLAPrefillMetadata(TestBase):
def test_ascend_mla_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendMLAPrefillMetadata(attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_mla_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
class TestAscendMLADecodeMetadata(TestBase):
def test_ascend_mla_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendMLADecodeMetadata(input_positions, block_table,
seq_lens, max_seq_lens,
seq_lens_list, attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
class TestAscendMLAMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendMLAMetadata(num_actual_tokens, slot_mapping,
query_start_loc, seq_lens, block_tables,
num_decodes, num_decode_tokens,
num_prefills, num_input_tokens,
query_lens, head_dim, attn_mask,
attn_state, decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
class TestAscendMLAMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config = MagicMock()
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
def test_reorder_batch(self):
ascend_config = MagicMock()
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
return_value=ascend_config):
builder = AscendMLAMetadataBuilder(mock_vllm_config, mock_device)
builder.decode_threshold = 1
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
scheduler_output.scheduled_spec_decode_tokens = {
0: [],
1: [1],
2: [],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
class TestAscendMLAImpl(TestBase):
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
@patch("vllm_ascend.attention.mla_v1.get_ascend_config")
def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
mock_tp):
mock_tp.world_size = 2
vllm_config = MagicMock()
speculative_config = MagicMock()
model_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
model_config.dtype = torch.float16
vllm_config.model_config = model_config
get_current_vllm_config.return_value = vllm_config
num_heads = 256
head_size = 1024
scale = 0.1
num_kv_heads = 8
kv_cache_dtype = "auto"
kv_a_layernorm = MagicMock()
kv_a_layernorm.weight = torch.randn(96)
kv_a_layernorm.variance_epsilon = 1e-6
kwargs = {
"q_lora_rank": 64,
"kv_lora_rank": 32,
"qk_nope_head_dim": 64,
"qk_rope_head_dim": 32,
"qk_head_dim": 96,
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
"kv_a_layernorm": kv_a_layernorm,
}
self.impl = AscendMLAImpl(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None,
**kwargs)
def test_init(self):
self.assertEqual(self.impl.num_heads, 256)
self.assertEqual(self.impl.head_size, 1024)
self.assertEqual(self.impl.scale, 0.1)
self.assertEqual(self.impl.num_kv_heads, 8)
self.assertEqual(self.impl.kv_cache_dtype, "auto")
self.assertEqual(self.impl.q_lora_rank, 64)
self.assertEqual(self.impl.kv_lora_rank, 32)
self.assertEqual(self.impl.qk_nope_head_dim, 64)
self.assertEqual(self.impl.qk_rope_head_dim, 32)
self.assertEqual(self.impl.qk_head_dim, 96)
self.assertEqual(self.impl.v_head_dim, 128)
self.assertIsNotNone(self.impl.rotary_emb)
self.assertIsNotNone(self.impl.q_proj)
self.assertIsNotNone(self.impl.kv_b_proj)
self.assertIsNotNone(self.impl.o_proj)
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
def test_v_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads,
self.impl.kv_lora_rank)
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
self.impl.W_UV = torch.randn(self.impl.num_heads,
self.impl.kv_lora_rank,
self.impl.v_head_dim)
result = self.impl._v_up_proj(x)
self.assertEqual(result.shape[0], batch_size)
self.assertEqual(result.shape[1],
self.impl.num_heads * self.impl.v_head_dim)
def test_q_proj_and_k_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
q_proj_output = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
self.impl.q_proj.return_value = (q_proj_output, )
if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
self.impl.W_UK_T = torch.randn(self.impl.num_heads,
self.impl.qk_nope_head_dim,
self.impl.kv_lora_rank)
result = self.impl._q_proj_and_k_up_proj(x)
ql_nope, q_pe = result
self.assertEqual(ql_nope.shape[0], batch_size)
self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
self.assertEqual(q_pe.shape[0], batch_size)
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
def test_process_weights_after_loading(self):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
self.impl.process_weights_after_loading(torch.bfloat16)
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
def test_compute_prefill_context_none(self):
batch_size = 4
kv_cache = torch.randn(10, 1, 1, 192)
query = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
metadata = MagicMock()
metadata.prefill = None
prefix_out = torch.randn(2, 16, 128)
prefix_lse = torch.randn(2, 16, 8)
q_pe = query[..., self.impl.qk_nope_head_dim:]
q_nope = query[..., :self.impl.qk_nope_head_dim]
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
32, metadata, prefix_out,
prefix_lse)
self.assertTrue(torch.equal(prefix_out, out))
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
num_blocks, block_size = 100, 20
query = torch.randn(S, N, D)
q_nope = query[..., :self.impl.qk_nope_head_dim]
q_pe = query[..., self.impl.qk_nope_head_dim:]
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8]
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
meta.prefill = prefill_meta
self.impl.prefill_mask = torch.triu(
torch.ones(512, 512, device=q_nope.device, dtype=q_nope.dtype), 1)
out, lse = self.impl._compute_prefill_context(q_nope, q_pe, kv_cache,
32, meta, prefix_out,
prefix_lse)
mock_load.assert_called_once()
mock_ring.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode_without_graph(self,
mock_npu_fused_infer_attention_score,
mock_up_proj):
num_tokens = 100
block_size = 4
q_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
q_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
k_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
k_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
metadata = MagicMock()
metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(num_tokens, self.impl.num_heads,
self.impl.kv_lora_rank), None
]
mock_up_proj.return_value = torch.randn(num_tokens,
self.impl.num_heads,
self.impl.v_head_dim)
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
block_size, metadata)
self.assertEqual(result.shape[0], num_tokens)
self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("vllm_ascend.attention.mla_v1.npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
magic_npu_fetch.return_value = MagicMock()
batch_size = 4
seq_len = 8
hidden_size = 1024
hidden_states = torch.randn(batch_size * seq_len, hidden_size)
kv_cache = MagicMock()
attn_metadata = MagicMock()
attn_metadata.num_decodes = 2
attn_metadata.num_prefills = 2
attn_metadata.num_decode_tokens = 2
attn_metadata.num_actual_tokens = 4
num_prefill_tokens = 2
attn_metadata.slot_mapping = torch.arange(4)
attn_metadata.decode.cos = torch.randn(2, 64)
attn_metadata.decode.sin = torch.randn(2, 64)
attn_metadata.prefill.cos = torch.randn(2, 64)
attn_metadata.prefill.sin = torch.randn(2, 64)
self.impl.q_a_proj = MagicMock()
self.impl.q_a_layernorm = MagicMock()
self.impl.q_a_layernorm.return_value = torch.randn(
attn_metadata.num_actual_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa = MagicMock()
self.impl.kv_a_proj_with_mqa.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
]
self.impl.q_proj = MagicMock()
self.impl.q_proj.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.qk_head_dim)
]
self.impl.kv_b_proj = MagicMock()
self.impl.kv_b_proj.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.v_head_dim + self.impl.qk_nope_head_dim)
]
self.impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
self.impl.exec_kv_decode = MagicMock()
self.impl.exec_kv_decode.return_value = [MagicMock(), MagicMock()]
self.impl.exec_kv_prefill = MagicMock()
self.impl.exec_kv_prefill.return_value = [
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim),
torch.randn(num_prefill_tokens, self.impl.num_heads,
self.impl.kv_lora_rank)
]
self.impl._q_proj_and_k_up_proj = MagicMock()
self.impl._q_proj_and_k_up_proj.return_value = [
MagicMock(), MagicMock()
]
self.impl.num_kv_heads = self.impl.num_heads
decode_res, prefill_res = self.impl._mla_preprocess(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
self.assertIsNotNone(decode_res)
self.assertIsNotNone(prefill_res)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
B = 2
N = self.impl.num_kv_heads
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
kv_no_split = torch.randn(B, N, D)
self.impl.enable_kv_nz = None
self.impl.kv_a_layernorm.weight = MagicMock()
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
cos = MagicMock()
sin = MagicMock()
slots = MagicMock()
kv_cache = [MagicMock(), MagicMock()]
mock_kv_rmsnorm_rope_cache.return_value = [
None, None,
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
torch.randn(B, N, 1, self.impl.kv_lora_rank)
]
k_pe, k_nope = self.impl.exec_kv_prefill(kv_no_split, cos, sin,
kv_cache, slots)
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
B = 2
N = self.impl.num_kv_heads
D = self.impl.kv_lora_rank + self.impl.qk_rope_head_dim
kv_no_split = torch.randn(B, N, D)
self.impl.enable_kv_nz = None
self.impl.kv_a_layernorm.weight = MagicMock()
self.impl.kv_a_layernorm.variance_epsilon = MagicMock()
cos = MagicMock()
sin = MagicMock()
slots = MagicMock()
kv_cache = [MagicMock(), MagicMock()]
mock_kv_rmsnorm_rope_cache.return_value = [
torch.randn(B, N, 1, self.impl.qk_rope_head_dim),
torch.randn(B, N, 1, self.impl.kv_lora_rank), None, None
]
k_pe, k_nope = self.impl.exec_kv_decode(kv_no_split, cos, sin,
kv_cache, slots)
self.assertEqual(k_pe.shape[-1], self.impl.qk_rope_head_dim)
self.assertEqual(k_nope.shape[-1], self.impl.kv_lora_rank)
@patch("torch.npu.stream")
@patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
@patch("torch_npu.npu_fused_infer_attention_score")
def test_forward_decode(self, mock_npu_fused_infer_attention_score,
mock_get_multistream_comm_context,
mock_npu_stream):
B = 2
N = self.impl.num_kv_heads
BS = 100
HD = self.impl.v_head_dim
self.impl.kv_lora_rank = 256
self.impl.spec_token_num = 1
self.impl._v_up_proj = MagicMock()
self.impl._v_up_proj.return_value = torch.randn(B, N, HD)
q_nope = torch.randn(B, N, self.impl.qk_nope_head_dim)
q_pe = torch.randn(B, N, self.impl.qk_rope_head_dim)
k_nope = torch.randn(BS, N, self.impl.kv_lora_rank)
k_pe = torch.randn(BS, N, self.impl.qk_rope_head_dim)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.SpecDecoding
attn_metadata.decode = MagicMock()
attn_metadata.decode.actual_seq_lengths_q = MagicMock()
attn_metadata.decode.seq_lens_list = MagicMock()
self.impl.enable_kv_nz = True
mock_npu_fused_infer_attention_score.return_value = [
torch.randn(B, N, self.impl.kv_lora_rank), None
]
mock_get_multistream_comm_context.return_value = None
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)
self.impl.enable_kv_nz = False
attn_metadata.attn_state = None
mock_return_value = MagicMock()
mock_get_multistream_comm_context.return_value = mock_return_value
mock_return_value.before_comm_event = MagicMock()
mock_return_value.comm_stream = MagicMock()
mock_npu_stream.return_value = MagicMock()
result = self.impl._forward_decode(q_nope, q_pe, k_nope, k_pe, BS,
attn_metadata)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], HD)

44
tests/ut/base.py Normal file
View File

@@ -0,0 +1,44 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import unittest
import pytest
from vllm_ascend.utils import adapt_patch, register_ascend_customop
class TestBase(unittest.TestCase):
def __init__(self, *args, **kwargs):
# adapt patch by default.
adapt_patch(True)
adapt_patch()
register_ascend_customop()
super().setUp()
super(TestBase, self).__init__(*args, **kwargs)
class PytestBase:
"""Base class for pytest-based tests.
because pytest mocker and parametrize usage are not compatible with unittest.
so we need to use a separate base class for pytest tests.
"""
@pytest.fixture(autouse=True)
def setup(self):
adapt_patch(True)
adapt_patch()
register_ascend_customop()

26
tests/ut/conftest.py Normal file
View File

@@ -0,0 +1,26 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from vllm_ascend.utils import adapt_patch # noqa E402
from vllm_ascend.utils import register_ascend_customop
adapt_patch()
adapt_patch(True)
# register Ascend CustomOp here because uts will use this
register_ascend_customop()

View File

@@ -0,0 +1,167 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from vllm.config import SchedulerConfig
from tests.ut.base import TestBase
from vllm_ascend.core.schedule_config import AscendSchedulerConfig
class TestAscendSchedulerConfig(TestBase):
def setUp(self):
self.basic_scheduler_config = SchedulerConfig(
max_num_batched_tokens=8192,
max_model_len=8192,
is_multimodal_model=False,
send_delta_data=False,
scheduler_delay_factor=0,
)
def test_initialize_from_config_with_default(self):
# No additional config given, check the default value here.
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, {})
self.assertEqual(ascend_config.enable_chunked_prefill, False)
self.assertEqual(ascend_config.policy, "fcfs")
self.assertEqual(ascend_config.num_scheduler_steps, 1)
self.assertEqual(ascend_config.scheduler_cls,
"vllm_ascend.core.scheduler.AscendScheduler")
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
self.assertEqual(ascend_config.encoder_cache_size, 8192)
def test_initialize_from_config_with_override(self):
# test override
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_chunked_prefill=False,
policy="fcfs",
num_scheduler_steps=1,
scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertEqual(ascend_config.enable_chunked_prefill, False)
self.assertEqual(ascend_config.policy, "fcfs")
self.assertEqual(ascend_config.num_scheduler_steps, 1)
self.assertEqual(ascend_config.scheduler_cls,
"vllm_ascend.core.scheduler.AscendScheduler")
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
self.assertEqual(ascend_config.encoder_cache_size, 2048)
def test_not_implemented_policy(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
policy="custom_policy",
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler only supports fcfs policy",
str(context.exception),
)
def test_not_implemented_multimodal(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
SchedulerConfig(is_multimodal_model=True), {})
self.assertIn("currently AscendScheduler only supports LLM models",
str(context.exception))
def test_not_implemented_multi_step(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
num_scheduler_steps=2,
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler doesn't support multi-step",
str(context.exception),
)
def test_not_implemented_send_delta_data(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
send_delta_data=True,
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler doesn't support send_delta_data",
str(context.exception),
)
def test_not_implemented_delay_factor(self):
with self.assertRaises(NotImplementedError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
delay_factor=1,
max_num_batched_tokens=2048,
max_model_len=2048,
),
)
self.assertIn(
"currently AscendScheduler doesn't support scheduler_delay_factor",
str(context.exception),
)
def test_no_override(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config, {})
self.assertEqual(ascend_config.max_num_encoder_input_tokens, 8192)
self.assertEqual(ascend_config.encoder_cache_size, 8192)
def test_valid_config_with_chunked_prefill(self):
ascend_config = AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_chunked_prefill=True,
max_num_batched_tokens=2048,
max_model_len=4096,
),
)
self.assertEqual(ascend_config.max_num_batched_tokens, 2048)
self.assertEqual(ascend_config.max_model_len, 4096)
self.assertTrue(ascend_config.enable_chunked_prefill)
def test_invalid_config_without_chunked_prefill(self):
with self.assertRaises(ValueError) as context:
AscendSchedulerConfig.initialize_from_config(
self.basic_scheduler_config,
AscendSchedulerConfig(
enable_chunked_prefill=False,
max_num_batched_tokens=2048,
max_model_len=4096,
),
)
self.assertIn(
"Ascend scheduler is enabled without chunked prefill feature",
str(context.exception),
)
self.assertIn("max_num_batched_tokens (2048)", str(context.exception))
self.assertIn("max_model_len (4096)", str(context.exception))

View File

@@ -0,0 +1,898 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import MagicMock, patch
import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from tests.ut.base import TestBase
from vllm_ascend.core.scheduler import AscendScheduler
from vllm_ascend.utils import vllm_version_is
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
from vllm.v1.outputs import DraftTokenIds
else:
DraftTokenIds = None
EOS_TOKEN_ID = 50256
MODEL = "Qwen3-0.6B"
ENABLE_PREFIX_CACHING = None
PROMPT_LOGPROBS = None
ENABLE_CHUNKED_PREFILL = False
MAX_NUM_BATCHED_TOKENS = 10000
LONG_PREFILL_TOKEN_THRESHOLD = 0
NUM_SPECULATIVE_TOKENS = None
MAX_NUM_SEQS = 16
def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[list[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[list[int]] = None,
block_size: int = 3,
hash_fn=hash,
):
init_none_hash(hash_fn)
prompt_logprobs = PROMPT_LOGPROBS
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids,
prompt_logprobs=prompt_logprobs)
requests = []
for i in range(num_requests):
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
else:
request = Request(request_id=f"{i}",
prompt_token_ids=[i] * num_tokens,
sampling_params=sampling_params,
eos_token_id=EOS_TOKEN_ID,
pooling_params=None,
block_hasher=get_request_block_hasher(
block_size, hash_fn))
requests.append(request)
return requests
def make_output(scheduler):
req_ids = [req.request_id for req in scheduler.running]
req_id_to_index = {
req.request_id: i
for i, req in enumerate(scheduler.running)
}
sampled_token_ids = [[1000]] * len(scheduler.running)
logprobs = None
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
else:
modelrunner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=logprobs,
prompt_logprobs_dict={},
pooler_output=[],
)
return modelrunner_output
class TestAscendScheduler(TestBase):
@patch("vllm.config.ModelConfig.__post_init__", MagicMock())
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget):
mock_compute_encoder_budget.return_value = [10, 20]
use_kv_connector = False
block_size = 16
scheduler_config = SchedulerConfig(
max_num_seqs=16,
max_model_len=MAX_NUM_BATCHED_TOKENS,
long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
disable_chunked_mm_input=False,
enable_chunked_prefill=ENABLE_CHUNKED_PREFILL,
max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
)
scheduler_config.max_num_encoder_input_tokens = 10000
scheduler_config.encoder_cache_size = 10000
scheduler_config.chunked_prefill_enabled = False
model_config = ModelConfig(
model=MODEL,
task="auto",
tokenizer=MODEL,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="float16",
seed=42,
max_model_len=MAX_NUM_BATCHED_TOKENS,
)
model_config.pooler_config = MagicMock()
model_config.multimodal_config = MagicMock()
model_config.hf_config = MagicMock()
model_config.hf_config.is_encoder_decoder = False
# Cache config, optionally force APC
kwargs_cache: Dict[str,
Any] = ({} if ENABLE_PREFIX_CACHING is None else {
'enable_prefix_caching':
ENABLE_PREFIX_CACHING
})
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
**kwargs_cache,
)
kv_transfer_config = KVTransferConfig(
kv_connector="SharedStorageConnector",
kv_role="kv_both",
kv_connector_extra_config={"shared_storage_path": "local_storage"},
) if use_kv_connector else None
speculative_config: Optional[SpeculativeConfig] = None
if NUM_SPECULATIVE_TOKENS is not None:
speculative_config = SpeculativeConfig(
model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)
vllm_config = VllmConfig(
scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
speculative_config=speculative_config,
)
kv_cache_config = KVCacheConfig(
num_blocks=10000, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1,
torch.float32, False))
],
)
cache_config.num_gpu_blocks = 10000
scheduler = AscendScheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=MagicMock(spec=StructuredOutputManager),
)
should_advance = MagicMock()
should_advance.return_value = False
scheduler.structured_output_manager.should_advance = should_advance
return scheduler
def test_add_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for i, request in enumerate(requests):
scheduler.add_request(request)
self.assertIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), i + 1)
def test_finish_request(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_ABORTED)
self.assertNotIn(request.request_id, scheduler.requests)
self.assertEqual(len(scheduler.waiting), 9 - i)
def test_get_num_unfinished_requests(self):
scheduler = self.create_scheduler()
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
for i, request in enumerate(requests):
scheduler.finish_requests(request.request_id,
RequestStatus.FINISHED_STOPPED)
self.assertEqual(scheduler.get_num_unfinished_requests(),
len(requests) - i - 1)
def test_schedule(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_schedule_enable_prefix_caching(self):
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
global ENABLE_PREFIX_CACHING
ENABLE_PREFIX_CACHING = True
global PROMPT_LOGPROBS
PROMPT_LOGPROBS = 5
scheduler = self.create_scheduler()
scheduler.scheduler_config.chunked_prefill_enabled = False
requests = create_requests(num_requests=10)
for request in requests:
scheduler.add_request(request)
# Test initial scheduling
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
self.assertEqual(len(output.finished_req_ids), 0)
# Verify all requests are scheduled.
for req_id, num_tokens in output.num_scheduled_tokens.items():
self.assertEqual(num_tokens,
len(requests[int(req_id)].prompt_token_ids))
# Verify requests moved from waiting to running
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), len(requests))
for i, request in enumerate(requests):
self.assertEqual(scheduler.running[i], request)
def test_stop_via_update_from_output(self):
"""Test stopping behavior through update_from_output"""
global NUM_SPECULATIVE_TOKENS
NUM_SPECULATIVE_TOKENS = 1
scheduler = self.create_scheduler()
# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [
10, 11
]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [
10, 11
]], # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped, second continues
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [EOS_TOKEN_ID])
self.assertEqual(list(requests[1].output_token_ids), [10, 11])
# Test case 2: Stop on custom stop token
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped on custom token
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status, RequestStatus.FINISHED_STOPPED)
self.assertEqual(requests[0].stop_reason, 42)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 42])
self.assertEqual(list(requests[1].output_token_ids), [13, 14])
# Test case 3: Stop on max tokens
NUM_SPECULATIVE_TOKENS = 2
scheduler = self.create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
req.status = RequestStatus.RUNNING
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify first request stopped due to length
self.assertEqual(len(scheduler.running), 1)
self.assertEqual(scheduler.running[0].request_id,
requests[1].request_id)
self.assertEqual(requests[0].status,
RequestStatus.FINISHED_LENGTH_CAPPED)
self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
self.assertEqual(list(requests[0].output_token_ids), [10, 11])
self.assertEqual(list(requests[1].output_token_ids), [13])
# Test case 4: Ignore EOS flag
scheduler = self.create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
grammar_bitmask=None)
model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output, model_output)
# Verify request continues past EOS
self.assertEqual(len(scheduler.running), 1)
self.assertFalse(requests[0].is_finished())
self.assertEqual(list(requests[0].output_token_ids),
[EOS_TOKEN_ID, 10, 11])
def test_schedule_concurrent_batches(self):
global MAX_NUM_BATCHED_TOKENS
global ENABLE_PREFIX_CACHING
global ENABLE_CHUNKED_PREFILL
global MAX_NUM_SEQS
global PROMPT_LOGPROBS
ENABLE_PREFIX_CACHING = None
MAX_NUM_BATCHED_TOKENS = 1024
MAX_NUM_SEQS = 2
ENABLE_CHUNKED_PREFILL = True
PROMPT_LOGPROBS = None
enable_prefix_caching_list = [None, True]
prompt_logprobs_list = [None, 5]
for i in range(len(enable_prefix_caching_list)):
ENABLE_PREFIX_CACHING = enable_prefix_caching_list[i]
PROMPT_LOGPROBS = prompt_logprobs_list[i]
scheduler = self.create_scheduler()
requests = create_requests(
num_requests=2,
num_tokens=512,
)
# Schedule the first request.
scheduler.add_request(requests[0])
scheduler_output0 = scheduler.schedule()
self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output0.num_scheduled_tokens[requests[0].request_id],
512)
# The first request is still running, so only schedule the second request.
scheduler.add_request(requests[1])
scheduler_output1 = scheduler.schedule()
self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
self.assertEqual(
scheduler_output1.num_scheduled_tokens[requests[1].request_id],
512)
# Model output of the first request.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output0,
model_runner_output)
# Schedule the next step.
# The first request can be scheduled again while the second
# request is still running.
scheduler.schedule()
# Model output of the second request.
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
scheduler.update_from_output(scheduler_output1,
model_runner_output)
def test_schedule_spec_decoding_stats(self):
"""Test scheduling behavior with speculative decoding.
This test verifies that:
1. Speculated tokens get scheduled correctly
2. Spec decoding stats properly count number of draft and accepted tokens
"""
spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
[[1, 2], [3]], [[1]], [[]],
[[1, 2, 3], [4, 5, 6]]]
output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
[[1, 2, 5], [3, 4]],
[[1, 2]], [[5]],
[[1, 2, 7], [4, 8]]]
expected_list: List[Tuple[int, int,
int, List[int]]] = [(1, 3, 3, [1, 1, 1]),
(1, 3, 1, [1, 0, 0]),
(2, 3, 3, [2, 1]),
(1, 1, 1, [1]),
(0, 0, 0, [0]),
(2, 6, 3, [2, 1, 0])]
global NUM_SPECULATIVE_TOKENS
for idx in range(len(spec_tokens_list)):
spec_tokens = spec_tokens_list[idx]
output_tokens = output_tokens_list[idx]
expected = expected_list[idx]
num_spec_tokens = max(1, max(len(t) for t in spec_tokens))
NUM_SPECULATIVE_TOKENS = num_spec_tokens
scheduler = self.create_scheduler()
requests = create_requests(num_requests=len(spec_tokens),
num_tokens=1)
req_ids = []
req_to_index = {}
for i, request in enumerate(requests):
scheduler.add_request(request)
req_ids.append(request.request_id)
req_to_index[request.request_id] = i
# Schedule a decode, which will also draft speculative tokens
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), len(requests))
self.assertEqual(output.total_num_scheduled_tokens, len(requests))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id], 1)
self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
spec_token_ids=spec_tokens,
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
if not (vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")):
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
# The prompt token
self.assertEqual(running_req.num_computed_tokens, 1)
# The prompt token and the sampled token
self.assertEqual(running_req.num_tokens, 2)
# The prompt token, the sampled token, and the speculated tokens
self.assertEqual(running_req.num_tokens_with_spec,
2 + len(spec_tokens[i]))
# No draft or accepted tokens counted yet
self.assertTrue(
not engine_core_outputs
or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
is None))
# Schedule the speculated tokens for validation
output = scheduler.schedule()
self.assertEqual(len(output.scheduled_new_reqs), 0)
# The sampled token and speculated tokens
self.assertEqual(
output.total_num_scheduled_tokens,
len(requests) + sum(len(ids) for ids in spec_tokens))
for i in range(len(requests)):
req_id = requests[i].request_id
self.assertEqual(output.num_scheduled_tokens[req_id],
1 + len(spec_tokens[i]))
if spec_tokens[i]:
self.assertEqual(
len(output.scheduled_spec_decode_tokens[req_id]),
len(spec_tokens[i]))
else:
self.assertNotIn(req_id,
output.scheduled_spec_decode_tokens)
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
engine_core_outputs = scheduler.update_from_output(
output, model_runner_output)
scheduler_stats = engine_core_outputs[0].scheduler_stats \
if engine_core_outputs else None
if expected[0] == 0:
self.assertIsNone(scheduler_stats.spec_decoding_stats)
else:
self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
stats = scheduler_stats.spec_decoding_stats
self.assertEqual(stats.num_drafts, expected[0])
self.assertEqual(stats.num_draft_tokens, expected[1])
self.assertEqual(stats.num_accepted_tokens, expected[2])
self.assertEqual(stats.num_accepted_tokens_per_pos,
expected[3])
def assert_scheduler_empty(self, scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
scheduler = self.create_scheduler()
self.assertEqual(len(scheduler.requests), 0)
self.assertEqual(len(scheduler.waiting), 0)
self.assertEqual(len(scheduler.running), 0)
self.assertEqual(len(scheduler.finished_req_ids), 0)
# EncoderCacheManager.
self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)
# KVCache Manager.
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks), 0)
self.assertEqual(
len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block), 0)
num_free_blocks = (scheduler.kv_cache_manager.block_pool.
free_block_queue.num_free_blocks)
self.assertEqual(
num_free_blocks,
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
self.assertEqual(block.ref_cnt, 0)
def test_memory_leak(self):
"""Test that we do not have a memory leak."""
scheduler = self.create_scheduler()
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
self.assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,188 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import PytestBase
from vllm_ascend.device_allocator.camem import (AllocationData, CaMemAllocator,
create_and_map,
find_loaded_library,
get_pluggable_allocator,
unmap_and_release)
def dummy_malloc(args):
pass
def dummy_free(ptr):
return (0, 0, 0, 0)
class TestCaMem(PytestBase):
def test_find_loaded_library_success_and_not_found(self):
path = find_loaded_library("libc")
assert path is not None, "Expected to find libc library"
assert path.endswith(".so.6") or ".so" in path
assert "libc" in path
path = find_loaded_library("non_existent_library")
assert path is None, "Expected to not find non-existent library"
@pytest.mark.parametrize("handle", [
(1, 2, 3),
("device", 99),
(None, ),
])
def test_create_and_map_calls_python_create_and_map(self, handle):
with patch("vllm_ascend.device_allocator.camem.python_create_and_map"
) as mock_create:
create_and_map(handle)
mock_create.assert_called_once_with(*handle)
@pytest.mark.parametrize("handle", [
(42, "bar"),
("foo", ),
])
def test_unmap_and_release_calls_python_unmap_and_release(self, handle):
with patch(
"vllm_ascend.device_allocator.camem.python_unmap_and_release"
) as mock_release:
unmap_and_release(handle)
mock_release.assert_called_once_with(*handle)
@patch("vllm_ascend.device_allocator.camem.init_module")
@patch(
"vllm_ascend.device_allocator.camem.torch.npu.memory.NPUPluggableAllocator"
)
def test_get_pluggable_allocator(self, mock_allocator_class,
mock_init_module):
mock_allocator_instance = MagicMock()
mock_allocator_class.return_value = mock_allocator_instance
def side_effect_malloc_and_free(malloc_fn, free_fn):
malloc_fn((1, 2, 3))
free_fn(123)
mock_init_module.side_effect = side_effect_malloc_and_free
allocator = get_pluggable_allocator(dummy_malloc, dummy_free)
mock_init_module.assert_called_once_with(dummy_malloc, dummy_free)
assert allocator == mock_allocator_instance
def test_singleton_behavior(self):
instance1 = CaMemAllocator.get_instance()
instance2 = CaMemAllocator.get_instance()
assert instance1 is instance2
def test_python_malloc_and_free_callback(self):
allocator = CaMemAllocator.get_instance()
# mock allocation_handle
handle = (1, 100, 1234, 0)
allocator.current_tag = "test_tag"
allocator.python_malloc_callback(handle)
# check pointer_to_data store data
ptr = handle[2]
assert ptr in allocator.pointer_to_data
data = allocator.pointer_to_data[ptr]
assert data.handle == handle
assert data.tag == "test_tag"
# check free callback with cpu_backup_tensor
data.cpu_backup_tensor = torch.zeros(1)
result_handle = allocator.python_free_callback(ptr)
assert result_handle == handle
assert ptr not in allocator.pointer_to_data
assert data.cpu_backup_tensor is None
@patch("vllm_ascend.device_allocator.camem.unmap_and_release")
@patch("vllm_ascend.device_allocator.camem.memcpy")
def test_sleep_offload_and_discard(self, mock_memcpy, mock_unmap):
allocator = CaMemAllocator.get_instance()
# prepare allocation one tag matchone not match
handle1 = (1, 10, 1000, 0)
data1 = AllocationData(handle1, "tag1")
handle2 = (2, 20, 2000, 0)
data2 = AllocationData(handle2, "tag2")
allocator.pointer_to_data = {
1000: data1,
2000: data2,
}
# mock is_pin_memory_available, return False as some machine only has cpu
with patch(
"vllm_ascend.device_allocator.camem.NPUPlatform.is_pin_memory_available",
return_value=False):
allocator.sleep(offload_tags="tag1")
# only offload tag1, other tag2 call unmap_and_release
assert data1.cpu_backup_tensor is not None
assert data2.cpu_backup_tensor is None
mock_unmap.assert_any_call(handle1)
mock_unmap.assert_any_call(handle2)
assert mock_unmap.call_count == 2
assert mock_memcpy.called
@patch("vllm_ascend.device_allocator.camem.create_and_map")
@patch("vllm_ascend.device_allocator.camem.memcpy")
def test_wake_up_loads_and_clears_cpu_backup(self, mock_memcpy,
mock_create_and_map):
allocator = CaMemAllocator.get_instance()
handle = (1, 10, 1000, 0)
tensor = torch.zeros(5, dtype=torch.uint8)
data = AllocationData(handle, "tag1", cpu_backup_tensor=tensor)
allocator.pointer_to_data = {1000: data}
allocator.wake_up(tags=["tag1"])
mock_create_and_map.assert_called_once_with(handle)
assert data.cpu_backup_tensor is None
assert mock_memcpy.called
def test_use_memory_pool_context_manager(self):
allocator = CaMemAllocator.get_instance()
old_tag = allocator.current_tag
# mock use_memory_pool_with_allocator
mock_ctx = MagicMock()
mock_ctx.__enter__.return_value = "data"
mock_ctx.__exit__.return_value = None
with patch(
"vllm_ascend.device_allocator.camem.use_memory_pool_with_allocator",
return_value=mock_ctx):
with allocator.use_memory_pool(tag="my_tag"):
assert allocator.current_tag == "my_tag"
# restore old tag after context manager exits
assert allocator.current_tag == old_tag
def test_get_current_usage(self):
allocator = CaMemAllocator.get_instance()
allocator.pointer_to_data = {
1: AllocationData((0, 100, 1, 0), "tag"),
2: AllocationData((0, 200, 2, 0), "tag"),
}
usage = allocator.get_current_usage()
assert usage == 300

View File

@@ -0,0 +1,84 @@
import os
from unittest.mock import MagicMock, patch
from vllm.distributed.utils import StatelessProcessGroup
from tests.ut.base import TestBase
from vllm_ascend.distributed.device_communicators.pyhccl import \
PyHcclCommunicator
class MockHcclLib:
pass
class MockUniqueId:
pass
class TestPyHcclCommunicator(TestBase):
@patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"})
def test_world_size_1_return_early(self):
comm = PyHcclCommunicator(
group=StatelessProcessGroup(0, 1, None, None),
device="npu:0",
)
self.assertTrue(comm.disabled)
self.assertFalse(comm.available)
@patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"})
def test_load_hccl_fail(self):
comm = PyHcclCommunicator(group=StatelessProcessGroup(
0, 2, None, None),
device="npu:0",
library_path="/not/exist/path/libhccl.so")
self.assertTrue(comm.disabled)
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
MockHcclLib)
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
MockUniqueId)
@patch("torch.npu.device")
@patch("vllm_ascend.utils.current_stream",
return_value=MagicMock(npu_stream=5678))
def test_stateless_group(self, *_):
group = StatelessProcessGroup(rank=3,
world_size=4,
store=None,
socket=None)
comm = PyHcclCommunicator(group=group, device=3)
self.assertEqual(comm.rank, 3)
self.assertEqual(comm.world_size, 4)
@patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"})
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
MockHcclLib)
@patch(
"vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
MockUniqueId)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="nccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.distributed.broadcast")
@patch("torch.npu.device")
@patch("vllm_ascend.utils.current_stream",
return_value=MagicMock(npu_stream=1234))
def test_multi_gpu_pg_torch(
self,
*_,
):
fake_pg = MagicMock()
comm = PyHcclCommunicator(group=fake_pg, device="npu:1")
self.assertEqual(comm.rank, 1)
self.assertEqual(comm.world_size, 2)
self.assertFalse(comm.available)
self.assertTrue(comm.disabled)

View File

@@ -0,0 +1,173 @@
from unittest.mock import MagicMock, patch
import torch
from torch.distributed import ReduceOp
from tests.ut.base import TestBase
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t,
hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum,
hcclResult_t, hcclUniqueId)
class TestHcclUniqueId(TestBase):
def test_construct(self):
uid = hcclUniqueId()
uid.internal[0] = 12
self.assertEqual(len(uid.internal), 4108)
self.assertEqual(uid.internal[0], 12)
class TestHcclDataTypeEnum(TestBase):
def test_torch_dtype_mapping(self):
expected = {
torch.int8: hcclDataTypeEnum.hcclInt8,
torch.uint8: hcclDataTypeEnum.hcclUint8,
torch.int32: hcclDataTypeEnum.hcclInt32,
torch.int64: hcclDataTypeEnum.hcclInt64,
torch.float16: hcclDataTypeEnum.hcclFloat16,
torch.float32: hcclDataTypeEnum.hcclFloat32,
torch.float64: hcclDataTypeEnum.hcclFloat64,
torch.bfloat16: hcclDataTypeEnum.hcclBfloat16,
}
for torch_dtype, expected_enum in expected.items():
with self.subTest(torch_dtype=torch_dtype):
self.assertEqual(hcclDataTypeEnum.from_torch(torch_dtype),
expected_enum)
def test_unsupported_dtype_raises(self):
with self.assertRaises(ValueError):
hcclDataTypeEnum.from_torch(torch.complex64)
class TestHcclRedOpTypeEnum(TestBase):
def test_torch_reduce_op_mapping(self):
expected = {
ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum,
ReduceOp.PRODUCT: hcclRedOpTypeEnum.hcclProd,
ReduceOp.MAX: hcclRedOpTypeEnum.hcclMax,
ReduceOp.MIN: hcclRedOpTypeEnum.hcclMin,
}
for torch_op, expected_enum in expected.items():
with self.subTest(torch_op=torch_op):
self.assertEqual(hcclRedOpTypeEnum.from_torch(torch_op),
expected_enum)
def test_unsupported_op_raises(self):
unsupported_op = "NOT_EXIST"
with self.assertRaises(ValueError):
hcclRedOpTypeEnum.from_torch(unsupported_op)
class TestFunction(TestBase):
def test_construct_with_valid_args(self):
func = Function(name="foo", restype=int, argtypes=[int, str, float])
self.assertEqual(func.name, "foo")
self.assertIs(func.restype, int)
self.assertEqual(func.argtypes, [int, str, float])
class TestHCLLLibrary(TestBase):
def test_init_with_nonexistent_so(self):
fake_path = "/definitely/not/exist/libhccl.so"
with self.assertRaises(OSError):
HCCLLibrary(fake_path)
def test_hccl_get_error_string(self):
lib = MagicMock(sepc=HCCLLibrary)
mock_fn = MagicMock()
mock_fn.return_value = "HCCL internal error"
lib.hcclGetErrorString = mock_fn
result = hcclResult_t(1)
msg = lib.hcclGetErrorString(result)
self.assertEqual(msg, "HCCL internal error")
mock_fn.assert_called_once()
def test_hccl_check(self):
lib = HCCLLibrary.__new__(HCCLLibrary)
mock_fn = MagicMock()
mock_fn.return_value = "fake error"
lib.hcclGetErrorString = mock_fn
result = hcclResult_t(123)
with self.assertRaises(RuntimeError) as cm:
lib.HCCL_CHECK(result)
self.assertEqual(str(cm.exception), "HCCL error: fake error")
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_get_uniqueId(self, mock_HCCL_CHECK):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclGetRootInfo": MagicMock(return_value=0)}
unique_id = lib.hcclGetUniqueId()
self.assertIsInstance(unique_id, hcclUniqueId)
lib._funcs["HcclGetRootInfo"].assert_called_once()
mock_HCCL_CHECK.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_comm_initRank(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclCommInitRootInfo": MagicMock(return_value=0)}
world_size = 4
unique_id = hcclUniqueId()
rank = 1
comm = lib.hcclCommInitRank(world_size, unique_id, rank)
self.assertIsInstance(comm, hcclComm_t)
lib._funcs["HcclCommInitRootInfo"].assert_called_once()
mock_hccl_check.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_all_reduce(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclAllReduce": MagicMock(return_value=0)}
sendbuff = buffer_type()
recvbuff = buffer_type()
count = 10
datatype = hcclDataType_t(1)
op = hcclRedOp_t(0)
comm = hcclComm_t()
stream = aclrtStream_t()
lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm,
stream)
lib._funcs["HcclAllReduce"].assert_called_once_with(
sendbuff, recvbuff, count, datatype, op, comm, stream)
mock_hccl_check.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hccl_broad_cast(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclBroadcast": MagicMock(return_value=0)}
buff = buffer_type()
count = 10
datatype = 1
root = 0
comm = hcclComm_t()
stream = aclrtStream_t()
lib.hcclBroadcast(buff, count, datatype, root, comm, stream)
lib._funcs["HcclBroadcast"].assert_called_once_with(
buff, count, datatype, root, comm, stream)
mock_hccl_check.assert_called_once_with(0)
@patch.object(HCCLLibrary, "HCCL_CHECK")
def test_hcclCommDestroy_success(self, mock_hccl_check):
lib = HCCLLibrary.__new__(HCCLLibrary)
lib._funcs = {"HcclCommDestroy": MagicMock(return_value=0)}
comm = hcclComm_t()
lib.hcclCommDestroy(comm)
lib._funcs["HcclCommDestroy"].assert_called_once_with(comm)
mock_hccl_check.assert_called_once_with(0)

View File

@@ -0,0 +1,89 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
import torch.distributed as dist
from vllm_ascend.distributed.communicator import NPUCommunicator
class TestNPUCommunicator(unittest.TestCase):
@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_process_group_ranks",
return_value={
0: 0,
1: 1
})
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_all_to_all_with_sizes(self, *_):
def patched_all_to_all(output_tensor_list,
input_tensor_list,
group=None,
async_op=False):
output_tensor_list[:] = ([
torch.tensor([10, 20]),
torch.tensor([50, 60])
])
torch.distributed.all_to_all = patched_all_to_all
scatter_sizes = [2, 2]
gather_sizes = [2, 2]
input_ = torch.tensor([10, 20, 30, 40])
comm = NPUCommunicator(cpu_group=dist.group.WORLD)
output = comm.all_to_all(input_,
scatter_sizes=scatter_sizes,
gather_sizes=gather_sizes)
assert output.tolist() == [10, 20, 50, 60]
@patch("vllm.config.get_current_vllm_config", return_value=None)
@patch("torch.npu.current_device", return_value=MagicMock())
@patch("torch.npu.set_device", return_value=MagicMock())
@patch("torch.distributed.get_process_group_ranks",
return_value={
0: 0,
1: 1
})
@patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.is_initialized", return_value=True)
@patch("torch.distributed.get_backend", return_value="hccl")
@patch("torch.distributed.get_rank", return_value=1)
@patch("torch.distributed.get_world_size", return_value=2)
@patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
@patch("torch.npu.device")
def test_all_to_all_without_sizes(self, *_):
def patched_all_to_all(output_tensor_list,
input_tensor_list,
group=None,
async_op=False):
output_tensor_list[:] = ([
torch.tensor([[10, 20]]),
torch.tensor([[50, 60]])
])
torch.distributed.all_to_all = patched_all_to_all
input_ = torch.tensor([[10, 20], [30, 40]])
comm = NPUCommunicator(cpu_group=dist.group.WORLD)
output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)
assert output.tolist() == [[10, 20], [50, 60]]

View File

@@ -0,0 +1,139 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import importlib
import pytest
import torch
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend.distributed.tensor_parallel import (
_gather_along_first_dim, _gather_along_last_dim,
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
all_to_all_hp2sp, all_to_all_sp2hp)
class TestDistributedCommunication(PytestBase):
@pytest.fixture(autouse=True)
def context(self, mocker: MockerFixture):
mocker.patch("torch.npu.current_device", return_value="cpu")
mocker.patch("torch.distributed.get_world_size", return_value=4)
mocker.patch("torch.distributed.get_rank", return_value=0)
@pytest.mark.parametrize("world_size, test_tensor, expected",
[(1, torch.randn(8, 16), (8, 16)),
(4, torch.randn(8, 16), (32, 16))])
def test_gather_along_first_dim(self, test_tensor, expected, world_size,
mocker: MockerFixture):
"""test _gather_along_first_dim"""
mocker.patch("torch.distributed.get_world_size",
return_value=world_size)
result = _gather_along_first_dim(test_tensor, mocker.MagicMock())
assert result.shape == expected
@pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
(torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
])
def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
output_split_sizes,
mocker: MockerFixture):
"""test _gather_along_first_dim"""
result = _gather_along_first_dim(test_tensor, mocker.MagicMock(),
output_split_sizes)
assert result.shape == expected
@pytest.mark.parametrize("world_size, test_tensor, expected",
[(1, torch.randn(8, 16, 32), (8, 16, 32)),
(4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
def test_gather_along_last_dim(self, test_tensor, expected, world_size,
mocker: MockerFixture):
"""test _gather_along_last_dim"""
mocker.patch("torch.distributed.get_world_size",
return_value=world_size)
result = _gather_along_last_dim(test_tensor, mocker.MagicMock())
assert result.shape == expected
@pytest.mark.parametrize("input_shape,expected_shape", [
((32, 16), (8, 16)),
((40, 10), (10, 10)),
])
def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_first_dim(input_tensor,
mocker.MagicMock())
assert result.shape == expected_shape
@pytest.mark.parametrize("input_shape,expected_shape", [
((8, 16, 32), (8, 16, 8)),
])
def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = _reduce_scatter_along_last_dim(input_tensor,
mocker.MagicMock())
assert result.shape == expected_shape
@pytest.mark.parametrize("func,input_shape,expected_shape", [
("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
(8, 16, 128)),
("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
(8, 16, 8)),
("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
])
def test_wrapper_functions(self, func, input_shape, expected_shape,
mocker: MockerFixture):
"""test wrapper funcs"""
mod = importlib.import_module(
'vllm_ascend.distributed.tensor_parallel')
globals = mod.__dict__
test_func = globals[func]
input_tensor = torch.randn(*input_shape)
result = test_func(input_tensor, mocker.MagicMock())
assert result.shape == expected_shape
@pytest.mark.parametrize(
"input_shape,output_shape",
[
((8, 16), (32, 4)), # [num_tokens/TP, H] -> [num_tokens, H/TP]
])
def test_all_to_all_sp2hp(self, input_shape, output_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = all_to_all_sp2hp(input_tensor, mocker.MagicMock())
assert result.shape == output_shape
@pytest.mark.parametrize(
"input_shape,output_shape",
[
((32, 4), (8, 16)), # [num_tokens, H/TP] -> [num_tokens/TP, H]
])
def test_all_to_all_hp2sp(self, input_shape, output_shape,
mocker: MockerFixture):
input_tensor = torch.randn(*input_shape)
result = all_to_all_hp2sp(input_tensor, mocker.MagicMock())
assert result.shape == output_shape

View File

@@ -0,0 +1,44 @@
from unittest.mock import MagicMock, patch
import pytest
from vllm.config import ParallelConfig
from vllm_ascend.distributed.parallel_state import (
_LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
get_mc2_group, init_ascend_model_parallel)
@pytest.fixture
def parallel_config():
return ParallelConfig(data_parallel_size=2,
tensor_parallel_size=2,
pipeline_parallel_size=2)
@pytest.fixture
def mock_distributed():
with patch('torch.distributed.is_initialized', return_value=True), \
patch('torch.distributed.get_world_size', return_value=8), \
patch('torch.distributed.get_backend', return_value='nccl'), \
patch('vllm_ascend.distributed.parallel_state.get_world_group') as mock_group:
mock_group.return_value.local_rank = 0
mock_group.return_value.device_group = MagicMock()
yield
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_ascend_config = MagicMock()
mock_ascend_config.lmhead_tensor_parallel_size = 2
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
init_ascend_model_parallel(parallel_config)
mc2_group = get_mc2_group()
assert mc2_group is not None
lmheadtp_group = get_lmhead_tp_group()
assert lmheadtp_group is not None
destroy_ascend_model_parallel()
assert _MC2 is None
assert _LMTP is None

View File

@@ -0,0 +1,28 @@
{
"_name_or_path": "facebook/opt-125m",
"activation_dropout": 0.0,
"activation_function": "relu",
"architectures": [
"OPTForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 2,
"do_layer_norm_before": true,
"dropout": 0.1,
"eos_token_id": 2,
"ffn_dim": 3072,
"hidden_size": 768,
"init_std": 0.02,
"layerdrop": 0.0,
"max_position_embeddings": 2048,
"model_type": "opt",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"prefix": "</s>",
"torch_dtype": "float16",
"transformers_version": "4.21.0.dev0",
"use_cache": true,
"vocab_size": 50272,
"word_embed_proj_dim": 768
}

View File

@@ -0,0 +1,96 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
import types
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
create_vllm_config)
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
def test_basic_inferface():
"""Unit test for basic LLMDataDistCMgrConnector interface functionality."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
request_id = request.request_id
scheduler.add_request(request)
# Remote Prefill, triggers LLMDataDistCMgrConnectorMetadata.
scheduler_output = scheduler.schedule()
kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, LLMDataDistCMgrConnectorMetadata)
assert len(kv_connector_metadata.requests) == 1
assert request_id in kv_connector_metadata.requests
req_meta = kv_connector_metadata.requests[request_id]
for block_id, block in zip(
req_meta.local_block_ids, scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id]):
assert block_id == block.block_id
def test_read_agent_metadata():
rank_table = {
"version":
"1.2",
"server_count":
"2",
"prefill_device_list": [{
"server_id": "192.168.1.1",
"device_id": "0",
"device_ip": "10.30.0.1",
"cluster_id": "0",
}, {
"server_id": "192.168.1.1",
"device_id": "1",
"device_ip": "10.30.0.2",
"cluster_id": "1",
}, {
"server_id": "192.168.1.2",
"device_id": "0",
"device_ip": "10.30.0.3",
"cluster_id": "2",
}, {
"server_id": "192.168.1.2",
"device_id": "1",
"device_ip": "10.30.0.4",
"cluster_id": "3",
}]
}
def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
old_visible_devices = os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "")
worker = types.SimpleNamespace()
worker.local_ip = worker_local_ip
worker.tp_rank = worker_tp_rank
worker.llm_datadist_role = LLMRole.PROMPT
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = worker_visible_devices
agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
worker, rank_table)
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = old_visible_devices
return agent_metadata.device_ip
assert get_device_ip("192.168.1.1", 0, "0") == "10.30.0.1"
assert get_device_ip("192.168.1.1", 0, "1") == "10.30.0.2"
assert get_device_ip("192.168.1.2", 0, "0") == "10.30.0.3"
assert get_device_ip("192.168.1.2", 0, "1") == "10.30.0.4"
assert get_device_ip("192.168.1.1", 0, "0,1") == "10.30.0.1"
assert get_device_ip("192.168.1.1", 1, "0,1") == "10.30.0.2"
assert get_device_ip("192.168.1.1", 0, "") == "10.30.0.1"
assert get_device_ip("192.168.1.1", 1, "") == "10.30.0.2"

View File

@@ -0,0 +1,998 @@
import os
import queue
import socket
import sys
import threading
import time
import types
import unittest
from collections import defaultdict, deque
from unittest.mock import MagicMock, patch
import msgspec
import zmq
from vllm.utils import make_zmq_path
fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
sys.modules["mooncake.engine"] = fake_engine
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
MooncakeConnectorMetadata, MooncakeConnectorScheduler,
MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send,
group_concurrent_contiguous, string_to_int64_hash, zmq_ctx)
GET_META_MSG = b"get_meta_msg"
DONE_RECVING_MSG = b"done_recving_msg"
class TestKVCacheTaskTrackerInit(unittest.TestCase):
def test_init_basic_properties(self):
tracker = KVCacheTaskTracker()
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
self.assertIsInstance(tracker.finished_requests, set)
self.assertIsInstance(tracker.delayed_free_requests, deque)
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
def setUp(self):
self.tracker = KVCacheTaskTracker()
self.tracker.finished_requests = set()
self.tracker.done_task_lock = threading.Lock()
def test_empty_requests(self):
result = self.tracker.get_and_clear_finished_requests()
self.assertEqual(result, set())
self.assertEqual(len(self.tracker.finished_requests), 0)
def test_single_request(self):
self.tracker.finished_requests = {"req_123"}
result = self.tracker.get_and_clear_finished_requests()
self.assertEqual(result, {"req_123"})
self.assertEqual(len(self.tracker.finished_requests), 0)
def test_multiple_requests(self):
self.tracker.finished_requests = {"req_1", "req_2", "req_3"}
result = self.tracker.get_and_clear_finished_requests()
self.assertSetEqual(result, {"req_1", "req_2", "req_3"})
self.assertEqual(len(self.tracker.finished_requests), 0)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_concurrent_access(self, mock_logger):
from concurrent.futures import ThreadPoolExecutor
self.tracker.finished_requests = {"req_1", "req_2"}
with ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(self.tracker.get_and_clear_finished_requests)
for _ in range(3)
]
results = [f.result() for f in futures]
self.assertEqual(sum(1 for r in results if r), 1)
self.assertEqual(len(self.tracker.finished_requests), 0)
class TestKVCacheSendingThreadInit(unittest.TestCase):
def setUp(self):
self.common_args = {
'tp_rank': 1,
'decode_tp_size': 4,
'local_engine_id': 'engine_1',
'side_channel_host': 'localhost',
'side_channel_port': 5555,
'metadata': MagicMock(),
'ready_event': threading.Event()
}
self.threads = []
def tearDown(self):
for thread in self.threads:
if hasattr(thread, 'task_tracker') and hasattr(
thread.task_tracker, 'socket'):
thread.task_tracker.socket.close()
if hasattr(thread, 'is_alive') and thread.is_alive():
thread.join(timeout=0.1)
def test_thread_daemon_property(self):
thread = KVCacheSendingThread(**self.common_args)
self.threads.append(thread)
self.assertTrue(thread.daemon)
def test_thread_name_format(self):
thread = KVCacheSendingThread(**self.common_args)
self.threads.append(thread)
self.assertEqual(thread.name, "KVCacheSendingThread")
def test_ready_event_reference(self):
custom_event = threading.Event()
args = self.common_args.copy()
args['ready_event'] = custom_event
thread = KVCacheSendingThread(**args)
self.threads.append(thread)
self.assertIs(thread.ready_event, custom_event)
class TestGetAndClearFinishedRequests(unittest.TestCase):
def setUp(self):
self.common_args = {
'tp_rank': 1,
'decode_tp_size': 4,
'local_engine_id': 'engine_1',
'side_channel_host': 'localhost',
'side_channel_port': 5555,
'metadata': {
"test": "metadata"
},
'ready_event': threading.Event()
}
self.thread = KVCacheSendingThread(**self.common_args)
@patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
def test_get_and_clear_finished_requests(self, mock_get_clear):
expected_requests = {'req1', 'req2'}
mock_get_clear.return_value = expected_requests
result = self.thread.get_and_clear_finished_requests()
mock_get_clear.assert_called_once()
self.assertEqual(result, expected_requests)
class TestKVCacheSendingThread(unittest.TestCase):
def test_run_handles_get_meta_and_done_recv_msgs(self):
ready_event = threading.Event()
metadata = MooncakeAgentMetadata(
engine_id="engine1",
te_rpc_port=9090,
kv_caches_base_addr=[12345678],
num_blocks=2,
)
host = "127.0.0.1"
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
free_port = s.getsockname()[1]
thread = KVCacheSendingThread(
tp_rank=0,
decode_tp_size=1,
local_engine_id="engine1",
side_channel_host=host,
side_channel_port=free_port,
metadata=metadata,
ready_event=ready_event,
)
thread.start()
self.assertTrue(ready_event.wait(timeout=3),
"Server thread startup timeout")
context = zmq.Context() # type: ignore
sock = context.socket(zmq.DEALER) # type: ignore
sock.connect(f"tcp://{host}:{free_port}")
encoder = msgspec.msgpack.Encoder()
decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata)
sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))])
frames = sock.recv_multipart()
self.assertEqual(frames[0], b"")
meta = decoder.decode(frames[1])
self.assertEqual(meta.engine_id, "engine1")
self.assertEqual(meta.kv_caches_base_addr, [12345678])
self.assertEqual(meta.num_blocks, 2)
req_id = "request_42"
sock.send_multipart(
[b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))])
frames = sock.recv_multipart()
self.assertEqual(frames[0], b"")
self.assertEqual(frames[1], b"ACK")
self.assertIn(req_id, thread.task_tracker.finished_requests)
sock.close()
context.term()
class TestKVCacheRecvingThreadBasic(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
def test_add_request(self):
test_req = {
"request_id": "req1",
"local_block_ids": [1, 2],
"remote_block_ids": [3, 4],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
}
self.thread.add_request(**test_req)
queued = self.thread.request_queue.get_nowait()
self.assertEqual(queued["request_id"], "req1")
self.assertEqual(queued["remote_host"], "localhost")
@patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
def test_get_finished_requests(self, mock_tracker):
mock_tracker.return_value = {"req1", "req2"}
result = self.thread.get_and_clear_finished_requests()
self.assertEqual(result, {"req1", "req2"})
class TestSocketManagement(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.thread.remote_sockets = defaultdict(deque)
self.thread.remote_poller = MagicMock()
@patch('vllm_ascend.distributed.mooncake_connector.zmq.Context')
@patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket')
def test_get_remote_socket(self, mock_make_socket, mock_context):
mock_sock = MagicMock()
mock_make_socket.return_value = mock_sock
test_host = "test_host"
test_port = 12345
sock = self.thread._get_remote_socket(test_host, test_port)
self.assertEqual(sock, mock_sock)
mock_make_socket.assert_called_once()
args, kwargs = mock_make_socket.call_args
self.assertEqual(kwargs.get('path'), 'tcp://test_host:12345')
self.assertEqual(kwargs.get('socket_type'), zmq.REQ) # type: ignore
self.assertFalse(kwargs.get('bind', True))
self.thread.remote_poller.register.assert_called_with(
mock_sock, zmq.POLLIN) # type: ignore
def test_return_socket_to_pool(self):
mock_sock = MagicMock()
test_host = "test_host"
test_port = 12345
test_path = make_zmq_path("tcp", test_host, test_port)
self.thread._return_remote_socket(mock_sock, test_host, test_port)
self.assertEqual(len(self.thread.remote_sockets[test_path]), 1)
self.assertEqual(self.thread.remote_sockets[test_path][0], mock_sock)
self.thread.remote_poller.register.assert_not_called()
class TestCoreFunctionality(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.mock_queue = MagicMock()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.thread.request_queue = self.mock_queue
self.test_req = {
"request_id": "req1",
"local_block_ids": [1, 2],
"remote_block_ids": [3, 4],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"remote_transfer_port": 7777
}
self.thread.task_tracker = MagicMock()
self.engine.batch_transfer_sync_read.return_value = 0
self.thread.remote_te_port = {"remote_engine": {6666: 7777}}
@patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
@patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
def test_handle_request(self, mock_send, mock_transfer):
self.thread._handle_request(self.test_req)
mock_transfer.assert_called_once_with(self.test_req)
mock_send.assert_called_once_with("req1", "localhost", 6666)
self.thread.task_tracker.update_done_task_count.assert_called_once_with(
"req1")
self.mock_queue.task_done.assert_called_once()
@patch.object(KVCacheRecvingThread, '_get_remote_metadata')
def test_transfer_kv_cache(self, mock_get_meta):
self.thread.kv_caches_base_addr["remote_engine"] = {
6666: [0x3000, 0x4000]
}
self.thread._transfer_kv_cache(self.test_req)
self.engine.batch_transfer_sync_read.assert_called_once()
call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args
self.assertEqual(call_args[0], "localhost:7777")
self.assertIsInstance(call_args[1], list)
self.assertIsInstance(call_args[2], list)
self.assertIsInstance(call_args[3], list)
self.assertEqual(len(call_args[1]), len(call_args[2]))
self.assertEqual(len(call_args[1]), len(call_args[3]))
mock_get_meta.assert_not_called()
def test_transfer_kv_cache_failure(self):
self.engine.batch_transfer_sync_read.return_value = -1
self.thread.kv_caches_base_addr["remote_engine"] = {
6666: [0x3000, 0x4000]
}
with self.assertRaises(RuntimeError):
self.thread._transfer_kv_cache(self.test_req)
class TestMetadataHandling(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.test_metadata = MooncakeAgentMetadata(
engine_id="remote_engine",
te_rpc_port=9090,
kv_caches_base_addr=[0x3000, 0x4000],
num_blocks=2)
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv')
def test_get_remote_metadata_success(self, mock_recv, mock_send):
mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata)
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
mock_socket = MagicMock()
mock_get_socket.return_value = mock_socket
self.thread._get_remote_metadata("host1", 5555)
mock_get_socket.assert_called_once_with("host1", 5555)
mock_return_socket.assert_called_once_with(mock_socket, "host1",
5555)
mock_send.assert_called_once_with(
mock_socket, self.thread.encoder.encode((GET_META_MSG, "")))
mock_recv.assert_called_once_with(mock_socket,
self.thread.remote_poller)
self.assertEqual(
self.thread.kv_caches_base_addr["remote_engine"][5555],
[0x3000, 0x4000])
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
@patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv',
side_effect=Exception("Network error"))
def test_get_remote_metadata_failure(self, mock_recv, mock_send):
with patch.object(self.thread, '_get_remote_socket') as mock_get_socket, \
patch.object(self.thread, '_return_remote_socket') as mock_return_socket:
mock_socket = MagicMock()
mock_get_socket.return_value = mock_socket
with self.assertRaises(Exception) as context:
self.thread._get_remote_metadata("host1", 5555)
self.assertEqual(str(context.exception), "Network error")
mock_return_socket.assert_called_once()
class TestMainThreadLoop(unittest.TestCase):
def setUp(self):
self.engine = MagicMock()
self.ready_event = threading.Event()
self.thread = KVCacheRecvingThread(
tp_rank=0,
tp_size=4,
engine=self.engine,
local_engine_id="local_engine",
local_handshake_port=5555,
local_kv_caches_base_addr=[0x1000, 0x2000],
block_len=[1024, 2048],
ready_event=self.ready_event)
self.thread.request_queue = queue.Queue()
@patch.object(KVCacheRecvingThread, '_handle_request')
def test_run_loop_normal(self, mock_handle):
test_request = {
"request_id": "req1",
"local_block_ids": [1, 2],
"remote_block_ids": [3, 4],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_handshake_port": 6666,
"remote_transfer_port": 7777
}
self.thread.request_queue.put(test_request)
self.thread.request_queue.put(None)
self.thread.start()
time.sleep(0.1)
self.thread.join(timeout=1.0)
self.assertTrue(self.thread.ready_event.is_set())
mock_handle.assert_called_once_with(test_request)
self.assertTrue(self.thread.request_queue.empty())
class MockVllmConfig:
def __init__(self):
self.model_config = MagicMock()
self.parallel_config = MagicMock()
self.cache_config = MagicMock()
self.kv_transfer_config = MagicMock()
self.model_config.use_mla = True
self.parallel_config.tensor_parallel_size = 2
self.parallel_config.data_parallel_rank_local = 0
self.parallel_config.data_parallel_size_local = 1
self.cache_config.block_size = 16
self.kv_transfer_config.kv_port = 5000
self.kv_transfer_config.kv_role = 'kv_producer'
self.kv_transfer_config.get_from_extra_config = MagicMock()
self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: {
"prefill": {
"tp_size": 2,
"dp_size": 1
},
"decode": {
"tp_size": 2,
"dp_size": 1
}
}.get(k, d)
class MockRequest:
def __init__(self,
request_id,
prompt_token_ids=None,
kv_transfer_params=None,
status=None):
self.request_id = request_id
self.prompt_token_ids = prompt_token_ids or [1, 2, 3, 4]
self.kv_transfer_params = kv_transfer_params or {}
self.status = status or "running"
self.output_token_ids = [101, 102]
class TestKVCacheTaskTracker(unittest.TestCase):
def setUp(self):
self.tracker = KVCacheTaskTracker()
def test_update_done_task_count(self):
self.assertEqual(len(self.tracker.finished_requests), 0)
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time)
result = self.tracker.delayed_free_requests
self.assertEqual(len(result), 1)
self.assertEqual(result[0], ("req_1", current_time))
self.tracker.update_done_task_count("req_1")
result_finished = self.tracker.finished_requests
result_delayed = self.tracker.delayed_free_requests
self.assertEqual(result_finished, {"req_1"})
self.assertEqual(len(result_delayed), 0)
def test_retrieve_expired_requests(self):
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time - 600)
self.tracker.add_delayed_request("req_2", current_time)
result = self.tracker._retrieve_expired_requests()
self.assertEqual(result, {
"req_1",
})
result_delay = self.tracker.delayed_free_requests
self.assertEqual(len(result_delay), 1)
self.assertEqual(result_delay[0], ("req_2", current_time))
def test_duplicate_task_update(self):
self.tracker.update_done_task_count("req1")
self.tracker.update_done_task_count("req1")
self.tracker.update_done_task_count("req1")
finished = self.tracker.get_and_clear_finished_requests()
self.assertEqual(finished, {"req1"})
class TestMooncakeConnectorMetadata(unittest.TestCase):
def test_add_new_req(self):
meta = MooncakeConnectorMetadata()
self.assertEqual(len(meta.requests), 0)
self.assertEqual(len(meta.requests_to_send), 0)
meta.add_new_req(request_id="req1",
local_block_ids=[1, 2, 3],
kv_transfer_params={
"remote_block_ids": [4, 5, 6],
"remote_engine_id": "remote_engine",
"remote_host": "localhost",
"remote_port": 5000
})
self.assertEqual(len(meta.requests), 1)
req_meta = meta.requests["req1"]
self.assertIsInstance(req_meta, ReqMeta)
self.assertEqual(req_meta.local_block_ids, [1, 2, 3])
self.assertEqual(req_meta.remote_block_ids, [4, 5, 6])
self.assertEqual(req_meta.remote_engine_id, "remote_engine")
self.assertEqual(req_meta.remote_host, "localhost")
self.assertEqual(req_meta.remote_port, 5000)
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
def setUp(self):
config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(config, "test_engine")
def test_get_num_new_matched_tokens(self):
request = MockRequest("req1")
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 0)
self.assertFalse(async_flag)
request.kv_transfer_params = {"do_remote_prefill": True}
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 3)
self.assertTrue(async_flag)
def test_build_connector_meta(self):
request = MockRequest("req1")
blocks_mock = MagicMock()
blocks_mock.get_unhashed_block_ids.return_value = [4, 5, 6]
self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
request.kv_transfer_params = {
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
"remote_host": "localhost",
"remote_port": 5000
}
meta = self.scheduler.build_connector_meta(MagicMock())
self.assertIsInstance(meta, MooncakeConnectorMetadata)
self.assertEqual(len(meta.requests), 1)
self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6])
self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_get_finished_count(self):
count = self.scheduler.get_finished_count()
self.assertEqual(count, 2)
class TestHelperFunctions(unittest.TestCase):
def test_group_concurrent_contiguous(self):
src: list[int] = [1, 2, 3, 5, 6]
dst: list[int] = [10, 11, 12, 14, 15]
src_groups, dst_groups = group_concurrent_contiguous(src, dst)
self.assertEqual(len(src_groups), 2)
self.assertEqual(src_groups[0], [1, 2, 3])
self.assertEqual(src_groups[1], [5, 6])
self.assertEqual(dst_groups[0], [10, 11, 12])
self.assertEqual(dst_groups[1], [14, 15])
def test_group_concurrent_contiguous_empty(self):
src: list[int] = []
dst: list[int] = []
src_groups, dst_groups = group_concurrent_contiguous(src, dst)
self.assertEqual(src_groups, [])
self.assertEqual(dst_groups, [])
def test_string_to_int64_hash(self):
hash1 = string_to_int64_hash("test_string")
hash2 = string_to_int64_hash("test_string")
self.assertEqual(hash1, hash2)
hash3 = string_to_int64_hash("different_string")
self.assertNotEqual(hash1, hash3)
class TestMooncakeConnectorForScheduler(unittest.TestCase):
def test_scheduler_role(self):
config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_scheduler_methods(self, mock_method):
config = MockVllmConfig()
connector = MooncakeConnector(config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0)
class MockKVCacheBlocks:
def get_unhashed_block_ids(self):
return [4, 5, 6]
class MockSchedulerOutput:
pass
class MockForwardContext:
pass
class TestMooncakeConnector(unittest.TestCase):
def setUp(self):
self.config = MockVllmConfig()
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
def test_scheduler_initialization(self):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
self.assertIsNotNone(connector.connector_scheduler)
self.assertIsNone(connector.connector_worker)
@patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
def test_get_num_new_matched_tokens(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.get_num_new_matched_tokens(request, 0)
mock_method.assert_called_once_with(request, 0)
@patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
def test_update_state_after_alloc(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
blocks = MockKVCacheBlocks()
connector.update_state_after_alloc(request, blocks, 3)
mock_method.assert_called_once_with(request, blocks, 3)
@patch.object(MooncakeConnectorScheduler, "build_connector_meta")
def test_build_connector_meta(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
scheduler_output = MockSchedulerOutput()
connector.build_connector_meta(scheduler_output)
mock_method.assert_called_once_with(scheduler_output)
@patch.object(MooncakeConnectorScheduler, "request_finished")
def test_request_finished(self, mock_method):
connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
request = MockRequest("req1")
connector.request_finished(request, [1, 2, 3])
mock_method.assert_called_once_with(request, [1, 2, 3])
class TestMooncakeConnectorScheduler(unittest.TestCase):
def setUp(self):
self.config = MockVllmConfig()
self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")
def test_get_num_new_matched_tokens_no_remote_prefill(self):
request = MockRequest("req1")
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 0)
self.assertFalse(async_flag)
def test_get_num_new_matched_tokens_with_remote_prefill(self):
request = MockRequest("req1",
kv_transfer_params={"do_remote_prefill": True})
tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
request, 0)
self.assertEqual(tokens, 3)
self.assertTrue(async_flag)
def test_update_state_after_alloc_no_remote_prefill(self):
request = MockRequest("req1")
blocks = MagicMock()
self.scheduler.update_state_after_alloc(request, blocks, 0)
self.assertEqual(len(self.scheduler._reqs_need_recv), 0)
def test_update_state_after_alloc_with_remote_prefill(self):
request = MockRequest("req1",
kv_transfer_params={
"do_remote_prefill": True,
"remote_block_ids": [1, 2, 3],
"remote_engine_id": "remote",
"remote_host": "localhost",
"remote_port": 5000
})
blocks = MockKVCacheBlocks()
self.scheduler.update_state_after_alloc(request, blocks, 3)
self.assertEqual(len(self.scheduler._reqs_need_recv), 1)
self.assertEqual(self.scheduler._reqs_need_recv["req1"][0], request)
self.assertEqual(self.scheduler._reqs_need_recv["req1"][1], [4, 5, 6])
def test_request_finished_no_remote_decode(self):
request = MockRequest("req1")
delay_free, params = self.scheduler.request_finished(
request, [1, 2, 3])
self.assertFalse(delay_free)
self.assertIsNone(params)
class TestUtils(unittest.TestCase):
def test_string_to_int64_hash(self):
h1 = string_to_int64_hash("hello")
h2 = string_to_int64_hash("hello")
h3 = string_to_int64_hash("world")
self.assertEqual(h1, h2)
self.assertNotEqual(h1, h3)
self.assertIsInstance(h1, int)
def test_group_concurrent_contiguous(self):
src: list[int] = [1, 2, 3, 5, 6]
dst: list[int] = [10, 11, 12, 20, 21]
src_g, dst_g = group_concurrent_contiguous(src, dst)
self.assertEqual(src_g, [[1, 2, 3], [5, 6]])
self.assertEqual(dst_g, [[10, 11, 12], [20, 21]])
def test_group_empty(self):
src_g, dst_g = group_concurrent_contiguous([], [])
self.assertEqual(src_g, [])
self.assertEqual(dst_g, [])
def test_zmq_ctx_invalid_type(self):
with self.assertRaises(ValueError):
with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"):
pass
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
def test_zmq_ctx_ok(self, mock_make_socket):
mock_socket = MagicMock()
mock_make_socket.return_value = mock_socket
with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore
self.assertEqual(s, mock_socket)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_send_success(self, mock_logger):
mock_socket = MagicMock()
ensure_zmq_send(mock_socket, b"hello")
mock_socket.send.assert_called_once_with(b"hello")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_send_retry_and_fail(self, mock_logger):
mock_socket = MagicMock()
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
"send failed")
with self.assertRaises(RuntimeError):
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
self.assertEqual(mock_socket.send.call_count, 2)
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_recv_success(self, mock_logger):
mock_socket = MagicMock()
mock_socket.recv.return_value = b"response"
mock_poller = MagicMock()
mock_poller.poll.return_value = [
(mock_socket, zmq.POLLIN) # type: ignore
]
data = ensure_zmq_recv(mock_socket, mock_poller)
self.assertEqual(data, b"response")
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger):
mock_socket = MagicMock()
mock_poller = MagicMock()
mock_poller.poll.return_value = []
with self.assertRaises(RuntimeError):
ensure_zmq_recv(mock_socket,
mock_poller,
timeout=0.01,
max_retries=2)
class MockMooncakeAgentMetadata:
def __init__(self, **kwargs):
pass
class MockMooncakeConnectorMetadata:
def __init__(self):
self.requests = {}
class MockKVCacheSendingThread(threading.Thread):
def __init__(self, *args, **kwargs):
super().__init__()
self.daemon = True
self._finished_requests = set()
def get_and_clear_finished_requests(self):
return self._finished_requests
def start(self):
pass
class MockKVCacheRecvingThread(threading.Thread):
def __init__(self, *args, **kwargs):
super().__init__()
self.daemon = True
self._finished_requests = set()
self.add_request = MagicMock()
def get_and_clear_finished_requests(self):
return self._finished_requests
def start(self):
pass
class MockTensor:
def __init__(self, *args, **kwargs):
self.size = MagicMock(return_value=(10, 16, 8, 16))
self.element_size = MagicMock(return_value=4)
self.shape = (10, 16, 8, 16)
self.data_ptr = MagicMock(return_value=0x1000)
mock_envs_ascend = MagicMock()
mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
mock_logger = MagicMock()
class MockTransferEngine:
def initialize(self, *args, **kwargs):
return 0
def register_memory(self, *args, **kwargs):
return 1
class MockEnvsAscend:
MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
PHYSICAL_DEVICES = "10,11"
def mock_get_tensor_model_parallel_rank():
return 0
def mock_get_tp_group():
return MagicMock()
def mock_get_ip():
return "127.0.0.1"
def mock_string_to_int64_hash(s):
return hash(s)
class TestMooncakeConnectorWorker(unittest.TestCase):
def setUp(self):
self.envs_ascend_mock = MockEnvsAscend()
self.mock_transfer_engine = MagicMock()
self.mock_transfer_engine.get_rpc_port.return_value = 9090
self.mock_transfer_engine.initialize.return_value = 0
self.mock_transfer_engine.register_memory.return_value = 0
self.patches = [
patch('os.getenv', return_value="10,11"),
patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
patch('torch.Tensor.element_size', return_value=4),
patch('torch.Tensor.data_ptr', return_value=0x1000),
patch('math.prod', return_value=128),
patch('random.Random'),
patch(
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
mock_get_tensor_model_parallel_rank),
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
mock_get_tp_group),
patch('vllm_ascend.distributed.mooncake_connector.get_ip',
mock_get_ip),
patch(
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
mock_string_to_int64_hash),
patch('vllm_ascend.distributed.mooncake_connector.TransferEngine',
return_value=self.mock_transfer_engine),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
MagicMock()),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread',
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.logger',
MagicMock()),
patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
MagicMock()),
patch.dict('sys.modules',
{'vllm_ascend.envs': self.envs_ascend_mock}),
]
for p in self.patches:
p.start() # type: ignore
self.vllm_config = MockVllmConfig()
self.engine_id = "test_engine"
self.kv_caches = {"layer1": (MagicMock(), MagicMock())}
def tearDown(self):
for p in self.patches:
p.stop() # type: ignore
def test_register_kv_caches_producer(self):
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.register_kv_caches(self.kv_caches)
self.assertEqual(len(worker.kv_caches), 1)
self.assertIsNotNone(worker.kv_send_thread)
self.assertIsNone(worker.kv_recv_thread)
def test_register_kv_caches_consumer(self):
self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.register_kv_caches(self.kv_caches)
self.assertIsNone(worker.kv_send_thread)
self.assertIsNotNone(worker.kv_recv_thread)
def test_register_kv_caches_mla_case(self):
mla_cache1 = MagicMock()
mla_cache1.size.return_value = (10, 16, 1, 16)
mla_cache2 = MagicMock()
mla_cache2.size.return_value = (10, 16, 1, 8)
mla_caches = {"layer1": (mla_cache1, mla_cache2)}
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
worker.register_kv_caches(mla_caches)
self.assertTrue(worker.use_mla)
self.assertEqual(len(worker.block_len), 2)
def test_device_id_selection_with_physical_devices(self):
# Test with physical devices set
worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
# Default tp_rank is 0, so device_id should be 10
self.assertEqual(worker.device_id, 10)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,169 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import FinishReason, RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a Remote Decode request."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request = create_request(request_id=1,
max_tokens=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Prefill.
# (1a): schedule()
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
assert len(scheduler_output.scheduled_new_reqs) == 1
# (1b): execute_model()
model_runner_output = create_model_runner_output(reqs=[request])
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
# Ensure the request is finished after 1 tokens.
assert request.is_finished()
assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
output = engine_core_outputs[0].outputs[0]
assert output.finish_reason == FinishReason.LENGTH
assert output.kv_transfer_params is not None
# Request freed in Scheduler and blocks should be freed
assert request_id in scheduler.finished_req_ids
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 0
# ... but blocks should not be freed.
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 1
assert request_id in scheduler_output.finished_req_ids
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (2b): execute_model()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (2c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# STEP (3): Finished sending.
# (3a): schedule() - pass finished request to PB.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 0
assert len(scheduler_output.finished_req_ids) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler.finished_req_ids) == 0
# (3b): execute_model()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_id])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm we do not have any memory leaks after req lifecycle.
assert_scheduler_empty(scheduler)
def test_prefix_cache_lifecycle():
"""Test that remote decode params still works with a prefix cache hit."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# Prime the KVCache.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 3
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)
scheduler.add_request(request_remote_a)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote_a],
use_eos=True)
scheduler.update_from_output(scheduler_output, model_runner_output)
scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
#####################
# Actual Test: confirm we send all blocks.
# Step (1): Send the KV Transfer.
NUM_EXTERNAL_FULL_BLOCKS -= 1
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
request_remote = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_decode=True)
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output(reqs=[request_remote])
eco = scheduler.update_from_output(scheduler_output, model_runner_output)
kv_transfer_params = eco[0].outputs[0].kv_transfer_params
# Ensure we send all block ids, even if there is a cache hit.
assert (len(
kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
1))
# STEP (2): Ensure it is freed.
scheduler_output = scheduler.schedule()
scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_sending=[request_remote.request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
_ = scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,239 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import copy
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
from vllm.v1.request import RequestStatus
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
create_model_runner_output,
create_request, create_scheduler,
create_vllm_config)
def test_basic_lifecycle():
"""Test lifecycle of a remote prefill."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
START_FREE_BLOCK_QUEUE_SIZE = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
block_size=BLOCK_SIZE)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1):
# (1a): schedule()
scheduler_output = scheduler.schedule()
# Nothing running and empty scheduler output.
assert len(scheduler.running) == 0
assert len(scheduler_output.scheduled_new_reqs) == 0
assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
assert len(scheduler_output.num_scheduled_tokens) == 0
assert scheduler_output.total_num_scheduled_tokens == 0
# Req waiting for KVs with no computed/scheduled toks ...
assert len(scheduler.waiting) == 1
assert request in scheduler.waiting
assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
assert (request.num_computed_tokens == 0)
# ... but should have (uncached) blocks allocated to it.
block_pool = scheduler.kv_cache_manager.block_pool
assert (block_pool.free_block_queue.num_free_blocks
< START_FREE_BLOCK_QUEUE_SIZE)
assert len(block_pool.cached_block_hash_to_block) == 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block._block_hash is None
# (1b): forward()
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
# (1c): update_from_output()
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert not engine_core_outputs or not engine_core_outputs[0].outputs
# STEP (2):
# (2a): schedule(): nothing happens!
scheduler_output = scheduler.schedule()
assert len(scheduler.waiting) == 1
assert len(scheduler.running) == 0
# (2b): forward(): request finishes recv.
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
# (2c): update_from_output():
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# STEP (3):
# (3a): schedule(): this should actually schedule.
scheduler_output = scheduler.schedule()
assert len(scheduler.running) == 1
# Confirm the block are actually allocated.
num_hashed_blocks = 0
blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_id]
for block in blocks:
assert block.ref_cnt == 1
num_hashed_blocks += (1 if block._block_hash is not None else 0)
assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS
# Confirm the rest of the prompt is scheduled in this step.
scheduled_req = scheduler_output.scheduled_new_reqs[0]
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
num_computed_tokens = scheduled_req.num_computed_tokens
total_prompt_tokens = len(scheduled_req.prompt_token_ids)
assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)
# (3b): execute_model()
model_runner_output = create_model_runner_output([request])
# (3c): update_from_output()
scheduler.update_from_output(scheduler_output, model_runner_output)
# Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
engine_core_outputs = scheduler.update_from_output(scheduler_output,
model_runner_output)
scheduler.schedule()
assert_scheduler_empty(scheduler)
def test_no_spurious_prefix_caching():
"""
With P/D, blocks can be allocated but uncomputed for
multiple engine steps. This test confirms that we do
not accidentally have cache hits against uncomputed
blocks.
"""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 and a half full external blocks.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote = create_request(
request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True,
use_all_1s_for_prompt_tokens=True,
)
# Schedule the remote prefill request. This should not
# cause any blocks to be cached.
scheduler.add_request(request_remote)
scheduler_output = scheduler.schedule()
scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
assert len(scheduler.waiting) == 1
remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
0].req_to_blocks[request_remote.request_id]
# Remote blocks should not be cached.
for block in remote_blocks:
assert block.ref_cnt == 1
assert block._block_hash is None
def test_full_block_prompt():
"""Test that we handle a prompt that is the full block size."""
vllm_config = create_vllm_config()
scheduler = create_scheduler(vllm_config)
# 2 Full Blocks and 1 Half Block.
BLOCK_SIZE = vllm_config.cache_config.block_size
NUM_EXTERNAL_FULL_BLOCKS = 2
NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)
request = create_request(request_id=1,
num_tokens=NUM_TOKENS,
do_remote_prefill=True)
scheduler.add_request(request)
request_id = request.request_id
# STEP (1): Initialize a recv.
scheduler_output = scheduler.schedule()
# All blocks should be allocated.
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
scheduler.update_from_output(scheduler_output, model_runner_output)
# # STEP (2): Recv.
scheduler_output = scheduler.schedule()
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
model_runner_output.kv_connector_output = KVConnectorOutput(
finished_recving=[request_id])
scheduler.update_from_output(scheduler_output, model_runner_output)
assert len(scheduler.waiting) == 1
assert (request_id in scheduler.finished_recving_kv_req_ids)
# # STEP (3): Run as usual.
scheduler_output = scheduler.schedule()
# We need to recompute the final token of the prompt to generate
# the first new token, so we should not have a new block.
num_blocks = len(scheduler.kv_cache_manager.coordinator.
single_type_managers[0].req_to_blocks[request_id])
assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
NUM_TOKENS - 1)
assert (scheduler_output.num_scheduled_tokens[request_id] == 1)
model_runner_output = create_model_runner_output([request])
scheduler.update_from_output(scheduler_output, model_runner_output)
# # Step (4): Hit EOS.
scheduler_output = scheduler.schedule()
model_runner_output = create_model_runner_output([request], use_eos=True)
scheduler.schedule()
assert_scheduler_empty(scheduler)

View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
import os
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
ModelConfig, SchedulerConfig, VllmConfig)
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
init_none_hash)
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request
from vllm.v1.structured_output import StructuredOutputManager
from vllm_ascend.utils import vllm_version_is
EOS_TOKEN_ID = 50256
os.environ["VLLM_USE_V1"] = "1"
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler.finished_recving_kv_req_ids) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.coordinator.single_type_managers[0].
num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
def create_vllm_config(
max_num_seqs: int = 16,
max_num_batched_tokens: int = 1024,
block_size: int = 128,
) -> VllmConfig:
"""Initialize VllmConfig For Testing."""
scheduler_config = SchedulerConfig(
max_num_seqs=max_num_seqs,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_num_batched_tokens,
)
fake_weight_path = os.path.join(os.path.dirname(__file__), "..",
"fake_weight")
model_config = ModelConfig(
model=fake_weight_path,
skip_tokenizer_init=True,
)
# Cache config, optionally force APC
cache_config = CacheConfig(
block_size=block_size,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
enable_prefix_caching=True,
)
kv_transfer_config = KVTransferConfig(
kv_connector="LLMDataDistCMgrConnector",
kv_role="kv_both",
kv_connector_module_path=
"vllm_ascend.distributed.llmdatadist_c_mgr_connector")
return VllmConfig(scheduler_config=scheduler_config,
model_config=model_config,
cache_config=cache_config,
kv_transfer_config=kv_transfer_config,
device_config=DeviceConfig("cpu"))
def create_scheduler(
vllm_config: VllmConfig,
num_blocks: int = 10000,
) -> Scheduler:
"""Initialize Scheduler For Testing."""
block_size = vllm_config.cache_config.block_size
kv_cache_config = KVCacheConfig(
num_blocks=num_blocks, # A large number of blocks to hold all requests
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float16,
False))
],
)
vllm_config.cache_config.num_gpu_blocks = num_blocks
return Scheduler(
vllm_config=vllm_config,
kv_cache_config=kv_cache_config,
log_stats=True,
structured_output_manager=StructuredOutputManager(vllm_config),
)
_none_hash_initialized = False
def create_request(
request_id: int,
num_tokens: int = 10,
max_tokens: int = 128,
do_remote_decode: bool = False,
do_remote_prefill: bool = False,
use_all_1s_for_prompt_tokens: bool = False,
num_remote_blocks: int = 3,
block_size: int = 16,
) -> Request:
"""Make dummy request for testing."""
global _none_hash_initialized
if not _none_hash_initialized:
init_none_hash(hash)
_none_hash_initialized = True
block_hasher = get_request_block_hasher(block_size, hash)
kv_transfer_params: Optional[dict[str, Any]] = None
if do_remote_decode:
assert not do_remote_prefill
kv_transfer_params = dict(do_remote_prefill=False,
do_remote_decode=True)
elif do_remote_prefill:
kv_transfer_params = dict(do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id="my-engine-id",
remote_block_ids=list(
range(num_remote_blocks)),
remote_host="my-host",
remote_port=1234,
remote_tp_size=1)
max_tokens = 1 if do_remote_decode else max_tokens
sampling_params = SamplingParams(max_tokens=max_tokens)
if use_all_1s_for_prompt_tokens:
prompt_token_ids = [1] * num_tokens
else:
prompt_token_ids = [i * request_id for i in range(num_tokens)]
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
multi_modal_kwargs=None,
multi_modal_placeholders=None,
multi_modal_hashes=None,
pooling_params=[],
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
else:
req = Request(
request_id=f"id-{request_id}",
prompt_token_ids=prompt_token_ids,
sampling_params=sampling_params,
pooling_params=[],
eos_token_id=EOS_TOKEN_ID,
block_hasher=block_hasher,
)
req.kv_transfer_params = kv_transfer_params
return req
def create_model_runner_output(
reqs: list[Request],
finished_sending: Optional[list[str]] = None,
finished_recving: Optional[list[str]] = None,
use_eos: bool = False,
) -> ModelRunnerOutput:
"""Make dummy model runner output for testing."""
# Make request data.
req_ids = [req.request_id for req in reqs]
req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}
# Make sampled tokens.
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]
# Make output data structure.
extra_args = {}
from vllm.v1.worker.kv_connector_model_runner_mixin import \
KVConnectorOutput # type: ignore # noqa
kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
finished_recving=finished_recving)
extra_args = {"kv_connector_output": kv_connector_output}
if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
**extra_args,
)
else:
model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
**extra_args,
)
return model_runner_output

View File

View File

@@ -0,0 +1,195 @@
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.models.deepseek_mtp import (
CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor,
CustomDeepSeekMultiTokenPredictorLayer)
class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
@pytest.fixture
def setup_mtp_layer(self, mocker: MockerFixture):
config = PretrainedConfig(vocab_size=1000,
hidden_size=768,
rms_norm_eps=1e-5)
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
return_value=None)
mocker.patch(
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
return_value=None)
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.models.deepseek_v2.CustomDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
return mtp_layer
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
assert isinstance(mtp_layer, CustomDeepSeekMultiTokenPredictorLayer)
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch.object(mtp_layer,
'eh_proj',
return_value=torch.randn(2, 3, 768))
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
kv_cache = torch.randn(2, 3, 768)
previous_hidden_states = torch.randn(2, 3, 768)
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0)
assert output.shape == (2, 3, 768)
class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
@pytest.fixture
def setup_predictor(self, mocker: MockerFixture):
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_model_config = mocker.MagicMock(spec=ModelConfig)
mock_hf_config = mocker.MagicMock()
mock_hf_config.num_hidden_layers = 12
mock_hf_config.num_nextn_predict_layers = 3
mock_hf_config.vocab_size = 30000
mock_model_config.hf_config = mock_hf_config
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = CacheConfig()
mock_vllm_config.quant_config = mocker.MagicMock()
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
predictor = CustomDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
return predictor
def test_init(self, mocker: MockerFixture, setup_predictor):
predictor = setup_predictor
assert predictor.num_mtp_layers == 3
assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor)
@pytest.mark.parametrize(
'kv_caches, inputs_embeds',
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
inputs_embeds):
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
# todo: need or not?
# predictor.num_mtp_layers = 1
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
output = predictor.forward(input_ids, positions, kv_caches, None, None,
inputs_embeds, 0)
mock_layer.assert_called_once()
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
return_value=None)
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
result_logits = predictor.compute_logits(hidden_states=hidden_states,
sampling_metadata=None)
predictor.logits_processor.assert_called_once()
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
class TestCustomDeepSeekMTP(PytestBase):
@pytest.fixture
def setup_mtp(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.model_config.hf_config.num_hidden_layers = 12
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
vllm_config.cache_config = mocker.MagicMock()
vllm_config.quant_config = mocker.MagicMock()
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
return mtp
def test_init(self, mocker: MockerFixture, setup_mtp):
mtp = setup_mtp
assert isinstance(mtp, CustomDeepSeekMTP)
def test_forward(self, mocker: MockerFixture, setup_mtp):
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
spec_step_idx = 0
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

View File

@@ -0,0 +1,295 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
import torch
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.models.deepseek_v2 import (
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
CustomDeepseekV2RowParallelLinear,
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
@pytest.fixture
def base_config():
config = PretrainedConfig(
hidden_size=128,
num_attention_heads=8,
num_hidden_layers=2,
intermediate_size=256,
hidden_act="silu",
rms_norm_eps=1e-6,
rope_theta=10000.0,
max_position_embeddings=2048,
n_routed_experts=4,
n_shared_experts=1,
moe_intermediate_size=256,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
first_k_dense_replace=0,
moe_layer_freq=1,
kv_lora_rank=16,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
topk_method="noaux_tc",
scoring_func="softmax",
norm_topk_prob=True,
n_group=1,
topk_group=1,
vocab_size=10000,
)
return config
@pytest.fixture
def vllm_config(base_config):
model_config = SimpleNamespace(
hf_config=base_config,
tensor_parallel_size=1,
dtype=torch.float32,
use_mla=False,
quant_config=None,
max_model_len=2048,
)
cache_config = CacheConfig()
vllm_config = Mock()
vllm_config.model_config = model_config
vllm_config.cache_config = cache_config
vllm_config.quant_config = None
return vllm_config
@pytest.fixture
def mock_distributed():
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
pp_group = Mock(spec=GroupCoordinator)
pp_group.rank_in_group = 0
pp_group.world_size = 1
mock_vllm_config = Mock()
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \
patch("vllm_ascend.models.deepseek_v2.get_ep_group", return_value=ep_group), \
patch("vllm_ascend.models.deepseek_v2.get_dp_group", return_value=dp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.models.deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
patch("torch.npu.current_device", return_value=0):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch("vllm_ascend.models.deepseek_v2.get_forward_context",
return_value=forward_context):
yield
def test_custom_deepseek_v2_silu_and_mul():
torch.set_default_device("cpu")
silu = CustomDeepseekV2SiluAndMul()
assert silu.weight_scale is None
x = torch.randn(2, 4)
output = silu.forward_oot(x)
assert output.shape == (2, 2)
weight_scale = Mock(return_value=torch.tensor(0.1))
silu = CustomDeepseekV2SiluAndMul(weight_scale=weight_scale)
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
dynamic_scale = torch.randn(2, 1)
with patch("torch_npu.npu_dequant_swiglu_quant",
return_value=torch.randn(2, 4)):
output = silu.forward_oot((quant_x, dynamic_scale))
assert output.shape == (2, 4)
def test_custom_deepseek_v2_merged_replicated_linear(mock_distributed):
linear = CustomDeepseekV2MergedReplicatedLinear(input_size=128,
output_sizes=[64, 64],
bias=False,
quant_config=None)
assert linear.output_sizes == [64, 64]
param = Mock()
param.data = torch.zeros(128, 128)
param.output_dim = 1
param.is_gguf_weight = False
param.is_gguf_weight_type = False
loaded_weight = torch.randn(128, 64)
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
with pytest.raises(AssertionError):
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
@pytest.mark.parametrize("cls", [
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
CustomDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed):
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
linear.quant_method = Mock()
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
input_ = torch.randn(2, 4, 128)
with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim",
return_value=[torch.randn(2, 4, 64)]):
linear.input_is_parallel = False
output = linear(input_, is_prefill=True)
assert output[0].shape == (2, 4, 64)
linear.input_is_parallel = True
output = linear(input_, is_prefill=False)
assert output[0].shape == (2, 4, 64)
def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
mlp = CustomDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=None)
assert isinstance(mlp.act_fn, CustomDeepseekV2SiluAndMul)
x = torch.randn(2, 4, 128)
output = mlp(x)
assert output.shape == (2, 4, 128)
with patch("vllm_ascend.models.deepseek_v2.QuantizationConfig"
) as mock_quant_config:
mock_quant_config.name = "w8a8dynamic"
with pytest.raises(NotImplementedError):
CustomDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=mock_quant_config,
force_replicate=False)
with pytest.raises(ValueError):
CustomDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="relu",
quant_config=None)
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = CustomDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=16,
kv_lora_rank=16,
cache_config=CacheConfig(),
quant_config=None,
prefix="layers.0.self_attn")
assert attn.debug_layer_idx == 0
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(attn.mla_attn,
"__call__",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=None,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
# 创建一个简单的配置对象
class SimpleConfig:
def __init__(self):
self.vocab_size = 10000
self.hidden_size = 128
config = SimpleConfig()
# 直接创建lmhead和logits_processor
lmhead = ParallelLMHead(config.vocab_size, config.hidden_size)
logits_processor = LogitsProcessor(config.vocab_size)
# 创建模拟输出
mock_output = torch.randn(2, 4, config.hidden_size)
mock_logits = torch.randn(2, 4, config.vocab_size)
# 直接测试logits_processor
with patch.object(lmhead.quant_method, "apply", return_value=mock_logits):
with patch.object(logits_processor,
"_gather_logits",
return_value=mock_logits):
logits = logits_processor(lmhead, mock_output)
assert logits.shape == (2, 4, config.vocab_size)

View File

@@ -0,0 +1,424 @@
import pytest
import torch
import torch.nn.functional as F
from pytest_mock import MockerFixture
from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_5_vl import (
AscendQwen2_5_VisionAttention, AscendQwen2_5_VisionBlock,
AscendQwen2_5_VisionPatchEmbed, AscendQwen2_5_VisionRotaryEmbedding,
AscendQwen2_5_VisionTransformer, AscendQwen2_5_VLForConditionalGeneration)
class TestAscendQwen2_5_VisionAttention(PytestBase):
def init_attention(
self,
mocker,
embed_dim=1000,
num_heads=10,
projection_size=100,
quant_config=None,
prefix="",
):
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionAttention.__init__")
attention = AscendQwen2_5_VisionAttention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
args, kwargs = mocker_attn.call_args
assert args == (embed_dim, num_heads, projection_size, None, "")
assert not kwargs
attention.num_attention_heads_per_partition = num_heads
return attention
def test_attn_init_should_normal(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 10
projection_size = 100
quant_config = None
prefix = ""
vit = self.init_attention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
mocker=mocker,
)
assert vit.embed_dim == 1000
assert vit.hidden_size_per_attention_head == 10
def test_attn_init_should_raise_error(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 7
projection_size = 100
quant_config = None
prefix = ""
with pytest.raises(AssertionError):
# projection_size should divided by num heads
self.init_attention(
mocker=mocker,
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
def test_split_qkv(self, mocker: MockerFixture):
attention = self.init_attention(mocker=mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
q, k, v = attention.split_qkv(torch.rand((100, 10, 300)))
assert q.shape == (100, 10, 10, 10)
assert k.shape == (100, 10, 10, 10)
assert v.shape == (100, 10, 10, 10)
def test_attn_forward(self, mocker: MockerFixture):
attention = self.init_attention(mocker=mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
qkv = lambda x: (x, 0) # noqa
split_qkv = lambda x: [ #noqa
torch.rand((100, 3, 10, 128)) for i in range(3)
] # noqa
npu_rotary_mul = lambda q, cos, sin: q # noqa
_npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa
proj = lambda x: (x, 0) # noqa
mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
mocker_split_qkv = mocker.patch.object(
attention,
"split_qkv",
side_effect=split_qkv,
)
mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
side_effect=npu_rotary_mul)
mocker_npu_flash_attention_unpad = mocker.patch(
"torch_npu._npu_flash_attention_unpad",
side_effect=_npu_flash_attention_unpad,
)
mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
attention.__dict__["qkv"] = mocker_qkv
attention.__dict__["split_qkv"] = mocker_split_qkv
attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
attention.__dict__["_npu_flash_attention_unpad"] = (
mocker_npu_flash_attention_unpad)
attention.__dict__["proj"] = mocker_proj
output = attention.forward(
x=x,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin,
)
qkv_args, qkv_kwargs = mocker_qkv.call_args
assert qkv_args == (x, )
assert not qkv_kwargs
split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
assert split_qkv_args == (x, )
assert not split_qkv_kwargs
npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
assert npu_rotary_mul_args[1:] == (cos, sin)
assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
assert not npu_rotary_mul_kwargs
assert output.shape == torch.Size([100, 3, 1280])
class TestAscendQwen2_5_VisionBlock(PytestBase):
def init_vision_block(
self,
mocker,
dim=100,
num_heads=10,
mlp_hidden_dim=100,
):
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__",
return_value=None,
)
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionAttention.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_block = AscendQwen2_5_VisionBlock(
dim=dim,
num_heads=num_heads,
mlp_hidden_dim=mlp_hidden_dim,
)
args, kwargs = mocker_vit.call_args
assert args == (dim, num_heads, mlp_hidden_dim, F.silu, None, None, "")
assert not kwargs
args1, kwargs1 = mocker_attn.call_args
assert not args1
assert kwargs1 == {
"embed_dim": dim,
"num_heads": num_heads,
"projection_size": dim,
"quant_config": None,
"prefix": ".attn",
}
return vision_block
def test_init_vision_block_should_normal(
self,
mocker: MockerFixture,
):
vision_block = self.init_vision_block(mocker)
assert isinstance(vision_block, AscendQwen2_5_VisionBlock)
def test_vision_block_forward(self, mocker: MockerFixture):
x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
vision_block = self.init_vision_block(mocker)
mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
vision_block.__dict__["attn"] = mocker_attn
vision_block.__dict__["mlp"] = mocker_mlp
output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)
_, attn_kwargs = mocker_attn.call_args
assert attn_kwargs == {
"cu_seqlens": cu_seqlens,
"cos": cos,
"sin": sin,
}
assert torch.all(x * 3 == output)
class TestAscendQwen2_5_VisionPatchEmbed(PytestBase):
def test_forward(self):
patch_embed = AscendQwen2_5_VisionPatchEmbed()
ret = patch_embed(torch.rand((120, 1176)))
assert ret.shape == (120, 1152)
class TestAscendQwen2_5_VisionRotaryEmbedding(PytestBase):
def init_rotary_embedding(
self,
mocker,
dim=128,
):
mocker_ebed = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
rotary_embedding = AscendQwen2_5_VisionRotaryEmbedding(dim=dim, )
args, kwargs = mocker_ebed.call_args
assert args == (dim, 10000.0)
assert not kwargs
return rotary_embedding
def test_init_rotary_embedding_should_normal(self, mocker: MockerFixture):
rotary_embedding = self.init_rotary_embedding(mocker)
assert isinstance(rotary_embedding,
AscendQwen2_5_VisionRotaryEmbedding)
class TestAscendQwen2_5_VisionTransformer(PytestBase):
input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
def init_vision_transformer(
self,
mocker,
):
norm_eps = 1e-6
vision_config = mocker.MagicMock()
vision_config.patch_size = 16
vision_config.temporal_patch_size = 2
vision_config.in_channels = 3
vision_config.hidden_act = "gelu"
vision_config.depth = 0
vision_config.num_heads = 10
vision_config.hidden_size = 300
mocker.patch(
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
mocker.patch("vllm.distributed.utils.divide", return_value=100)
mocker.patch(
"vllm.model_executor.layers.linear.get_tensor_model_parallel_world_size",
return_value=2,
)
mocker.patch(
"vllm.model_executor.layers.linear.divide",
return_value=2,
)
mocker.patch(
"vllm.model_executor.layers.linear.get_tensor_model_parallel_rank",
return_value=0)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl.parallel_state.get_tensor_model_parallel_world_size",
return_value=2,
)
vision_transformer = AscendQwen2_5_VisionTransformer(
vision_config,
norm_eps,
)
assert not vision_transformer.interleaved
return vision_transformer
def test_init_vision_transformer(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
assert isinstance(vision_transformer, AscendQwen2_5_VisionTransformer)
@pytest.mark.parametrize(
"interleaved, expected",
[
(
False,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
]),
),
(
True,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 1].cos(),
]),
),
],
)
def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_transformer.__dict__["interleaved"] = interleaved
vision_transformer.__dict__["hidden_size_per_attention_head"] = 2
vision_transformer.hidden_size_per_attention_head = 4
cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
assert cos_new.shape == (1, 32, 1, 2)
def test_forward(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
x = torch.randn(1, 3, 224, 224)
grid_thw = torch.tensor([[1, 4, 4]])
mocker_patch_embed = mocker.patch.object(
vision_transformer,
"patch_embed",
side_effect=lambda _: torch.randn(16, 512), # noqa
)
mocker_rot_pos_emb = mocker.patch.object(
vision_transformer,
"rot_pos_emb",
side_effect=lambda _: torch.randn(16, 64), # noqa
)
mocker_get_window_index = mocker.patch.object(
vision_transformer,
"get_window_index",
side_effect=lambda _: (torch.arange(8), [4, 8, 12, 16]), # noqa
)
mocker_cal_cos_sin = mocker.patch.object(
vision_transformer,
"cal_cos_sin",
side_effect=lambda _:
(torch.randn(16, 32), torch.randn(16, 32)), # noqa
)
mocker_merger = mocker.patch.object(
vision_transformer,
"merger",
side_effect=lambda _: torch.randn(16, 256), # noqa
)
vision_transformer.__dict__["vision_blocks"] = [
lambda *args, **kwargs: torch.randn(16, 1, 512) # noqa
]
vision_transformer.__dict__["patch_embed"] = mocker_patch_embed
vision_transformer.__dict__["rot_pos_emb"] = mocker_rot_pos_emb
vision_transformer.__dict__[
"get_window_index"] = mocker_get_window_index
vision_transformer.__dict__["cal_cos_sin"] = mocker_cal_cos_sin
vision_transformer.__dict__["merger"] = mocker_merger
vision_transformer.__dict__["fullatt_block_indexes"] = [0, 2]
vision_transformer.__dict__["spatial_merge_unit"] = 2
ret = vision_transformer.forward(x, grid_thw)
assert ret.shape == (8, 256)
mocker_patch_embed.assert_called_with(x)
mocker_rot_pos_emb.assert_called_with(grid_thw)
mocker_get_window_index.assert_called_with(grid_thw)
mocker_cal_cos_sin.assert_called_once()
mocker_merger.assert_called_once()
class TestAscendQwen2_5_VLForConditionalGeneration(PytestBase):
def test_init_vl_for_conditional_generation(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.vision_config = "vision_config"
vllm_config.rms_norm_eps = 1e-5
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker_vl = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__",
return_value=None,
)
mocker_vit = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionTransformer.__init__",
return_value=None,
)
vl_for_conditional_generation = AscendQwen2_5_VLForConditionalGeneration(
vllm_config=vllm_config)
args, kwargs = mocker_vl.call_args
assert not args
assert kwargs == {"vllm_config": vllm_config, "prefix": ""}
mocker_vit.assert_called_once()
assert isinstance(
vl_for_conditional_generation,
AscendQwen2_5_VLForConditionalGeneration,
)

View File

@@ -0,0 +1,422 @@
import pytest
import torch
import torch.nn.functional as F
from pytest_mock import MockerFixture
from vllm.model_executor.models.qwen2_5_vl import \
Qwen2_5_VLForConditionalGeneration
from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_5_vl_without_padding import (
AscendQwen2_5_VisionAttention_Without_Padding,
AscendQwen2_5_VisionBlock_Without_Padding,
AscendQwen2_5_VisionPatchEmbed_Without_Padding,
AscendQwen2_5_VisionTransformer_Without_Padding,
AscendQwen2_5_VLForConditionalGeneration_Without_Padding)
class TestAscendQwen2_5_VisionAttention_Without_Padding(PytestBase):
def init_attention(
self,
mocker,
embed_dim=1000,
num_heads=10,
projection_size=100,
quant_config=None,
prefix="",
):
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.Qwen2_5_VisionAttention.__init__"
)
attention = AscendQwen2_5_VisionAttention_Without_Padding(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
args, kwargs = mocker_attn.call_args
assert args == (embed_dim, num_heads, projection_size, None, "")
assert not kwargs
attention.num_attention_heads_per_partition = num_heads
return attention
def test_vit_init_should_normal(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 10
projection_size = 100
quant_config = None
prefix = ""
vit = self.init_attention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
mocker=mocker,
)
assert vit.embed_dim == 1000
assert vit.hidden_size_per_attention_head == 10
def test_vit_init_should_raise_error(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 7
projection_size = 100
quant_config = None
prefix = ""
with pytest.raises(AssertionError):
# projection_size should divided by num heads
self.init_attention(
mocker=mocker,
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
def test_vit_forward(self, mocker: MockerFixture):
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
attention = self.init_attention(mocker=mocker)
x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
qkv = lambda x: (x, 0) # noqa
split_qkv = lambda x: [ #noqa
torch.rand((100, 3, 10, 128)) for i in range(3)
] # noqa
npu_rotary_mul = lambda q, cos, sin: q # noqa
_npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa
proj = lambda x: (x, 0) # noqa
mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
mocker_split_qkv = mocker.patch.object(
attention,
"split_qkv",
side_effect=split_qkv,
)
mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
side_effect=npu_rotary_mul)
mocker_npu_flash_attention_unpad = mocker.patch(
"torch_npu._npu_flash_attention_unpad",
side_effect=_npu_flash_attention_unpad,
)
mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
attention.__dict__["qkv"] = mocker_qkv
attention.__dict__["split_qkv"] = mocker_split_qkv
attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
attention.__dict__["_npu_flash_attention_unpad"] = (
mocker_npu_flash_attention_unpad)
attention.__dict__["proj"] = mocker_proj
output = attention.forward(
x=x,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin,
)
qkv_args, qkv_kwargs = mocker_qkv.call_args
assert qkv_args == (x, )
assert not qkv_kwargs
split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
assert split_qkv_args == (x, )
assert not split_qkv_kwargs
npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
assert npu_rotary_mul_args[1:] == (cos, sin)
assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
assert not npu_rotary_mul_kwargs
assert output.shape == torch.Size([100, 3, 1280])
class TestAscendQwen2_5_VisionBlock_Without_Padding(PytestBase):
def init_vision_block(
self,
mocker,
dim=100,
num_heads=10,
mlp_hidden_dim=100,
):
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__",
return_value=None,
)
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionAttention_Without_Padding.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_block = AscendQwen2_5_VisionBlock_Without_Padding(
dim=dim,
num_heads=num_heads,
mlp_hidden_dim=mlp_hidden_dim,
)
args, kwargs = mocker_vit.call_args
assert args == (dim, num_heads, mlp_hidden_dim, F.silu, None, None, "")
assert not kwargs
args1, kwargs1 = mocker_attn.call_args
assert not args1
assert kwargs1 == {
"embed_dim": dim,
"num_heads": num_heads,
"projection_size": dim,
"quant_config": None,
"prefix": ".attn",
}
return vision_block
def test_init_vision_block_should_normal(
self,
mocker: MockerFixture,
):
vision_block = self.init_vision_block(mocker)
assert isinstance(vision_block,
AscendQwen2_5_VisionBlock_Without_Padding)
def test_vision_block_forward(self, mocker: MockerFixture):
x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
vision_block = self.init_vision_block(mocker)
mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
vision_block.__dict__["attn"] = mocker_attn
vision_block.__dict__["mlp"] = mocker_mlp
output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)
_, attn_kwargs = mocker_attn.call_args
assert attn_kwargs == {
"cu_seqlens": cu_seqlens,
"cos": cos,
"sin": sin,
}
assert torch.all(x * 3 == output)
class TestAscendQwen2_5_VisionPatchEmbed_Without_Padding(PytestBase):
def test_forward(self):
patch_embed = AscendQwen2_5_VisionPatchEmbed_Without_Padding()
ret = patch_embed(torch.rand((120, 1176)))
assert ret.shape == (120, 1152)
class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
def init_vision_transformer(
self,
mocker,
):
norm_eps = 1e-6
vision_config = mocker.MagicMock()
vision_config.patch_size = 16
vision_config.temporal_patch_size = 2
vision_config.in_channels = 3
vision_config.hidden_act = "gelu"
vision_config.depth = 0
vision_config.hidden_size = 1280
vision_config.num_heads = 16
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__",
return_value=None,
)
mocker_vision_rotary_embedding = mocker.patch(
"vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionPatchEmbed_Without_Padding.__init__",
return_value=None,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_world_size",
return_value=1,
)
mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_rank",
return_value=0,
)
mocker.patch("vllm.distributed.utils.divide", return_value=100)
vision_transformer = AscendQwen2_5_VisionTransformer_Without_Padding(
vision_config,
norm_eps,
)
args, kwargs = mocker_vit.call_args
assert args == (vision_config, norm_eps, None, "")
assert not kwargs
mocker_vision_rotary_embedding.assert_called_once()
return vision_transformer
def test_init_vision_transformer(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
assert isinstance(vision_transformer,
AscendQwen2_5_VisionTransformer_Without_Padding)
@pytest.mark.parametrize(
"interleaved, expected",
[
(
False,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
]),
),
(
True,
torch.tensor([
input_data[0, 0].cos(),
input_data[0, 0].cos(),
input_data[0, 1].cos(),
input_data[0, 1].cos(),
input_data[1, 0].cos(),
input_data[1, 0].cos(),
input_data[1, 1].cos(),
input_data[1, 1].cos(),
]),
),
],
)
def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
vision_transformer.__dict__["interleaved"] = interleaved
vision_transformer.__dict__["hidden_size_per_attention_head"] = 2
vision_transformer.hidden_size_per_attention_head = 4
cos_new, _ = vision_transformer.cal_cos_sin(self.input_data)
assert cos_new.shape == (1, 4, 1, 2)
assert torch.allclose(cos_new.view(-1), expected)
def test_forward(self, mocker: MockerFixture):
vision_transformer = self.init_vision_transformer(mocker)
x = torch.randn(1, 3, 224, 224)
grid_thw = torch.tensor([[1, 4, 4]])
mocker_patch_embed = mocker.patch.object(
vision_transformer,
"patch_embed",
side_effect=lambda _: torch.randn(16, 512), # noqa
)
mocker_rot_pos_emb = mocker.patch.object(
vision_transformer,
"rot_pos_emb",
side_effect=lambda _: torch.randn(16, 64), # noqa
)
mocker_get_window_index = mocker.patch.object(
vision_transformer,
"get_window_index",
side_effect=lambda _: (torch.arange(8), [4, 8, 12, 16]), # noqa
)
mocker_cal_cos_sin = mocker.patch.object(
vision_transformer,
"cal_cos_sin",
side_effect=lambda _:
(torch.randn(16, 32), torch.randn(16, 32)), # noqa
)
mocker_merger = mocker.patch.object(
vision_transformer,
"merger",
side_effect=lambda _: torch.randn(16, 256), # noqa
)
vision_transformer.__dict__["vision_blocks"] = [
lambda *args, **kwargs: torch.randn(16, 1, 512) # noqa
]
vision_transformer.__dict__["patch_embed"] = mocker_patch_embed
vision_transformer.__dict__["rot_pos_emb"] = mocker_rot_pos_emb
vision_transformer.__dict__[
"get_window_index"] = mocker_get_window_index
vision_transformer.__dict__["cal_cos_sin"] = mocker_cal_cos_sin
vision_transformer.__dict__["merger"] = mocker_merger
vision_transformer.__dict__["fullatt_block_indexes"] = [0, 2]
vision_transformer.__dict__["spatial_merge_unit"] = 2
ret = vision_transformer.forward(x, grid_thw)
assert ret.shape == (8, 256)
mocker_patch_embed.assert_called_with(x)
mocker_rot_pos_emb.assert_called_with(grid_thw)
mocker_get_window_index.assert_called_with(grid_thw)
mocker_cal_cos_sin.assert_called_once()
mocker_merger.assert_called_once()
class TestAscendQwen2_5_VLForConditionalGeneration_Without_Padding(PytestBase):
def test_init_vl_for_conditional_generation(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.vision_config = "vision_config"
vllm_config.rms_norm_eps = 1e-5
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker_vl = mocker.patch(
"vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__",
return_value=None,
)
mocker_vit = mocker.patch(
"vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionTransformer_Without_Padding.__init__",
return_value=None,
)
vl_for_conditional_generation = AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
vllm_config=vllm_config)
args, kwargs = mocker_vl.call_args
assert not args
assert kwargs == {"vllm_config": vllm_config, "prefix": ""}
mocker_vit.assert_called_once()
assert isinstance(
vl_for_conditional_generation,
AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
)
def test_overridden_methods(self):
self.assert_method_overridden(
AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
Qwen2_5_VLForConditionalGeneration,
"_process_image_input",
)
self.assert_method_overridden(
AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
Qwen2_5_VLForConditionalGeneration,
"_process_video_input",
)
@staticmethod
def assert_method_overridden(subclass, parent, method_name: str):
"""assert subclass override parent method"""
parent_func = parent.__dict__.get(method_name)
child_func = subclass.__dict__.get(method_name)
assert child_func is not None, f"{subclass.__name__} should defined {method_name}"
assert child_func is not parent_func, f"{method_name} should override in {subclass.__name__}"

View File

@@ -0,0 +1,200 @@
import pytest
import torch
from pytest_mock import MockerFixture
from vllm.model_executor.layers.activation import QuickGELU
from tests.ut.base import PytestBase
from vllm_ascend.models.qwen2_vl import (AscendQwen2VisionAttention,
AscendQwen2VisionBlock)
class TestAscendQwen2VisionAttention(PytestBase):
def init_attention(
self,
mocker,
embed_dim=1000,
num_heads=10,
projection_size=100,
quant_config=None,
prefix="",
):
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_vl.Qwen2VisionAttention.__init__")
attention = AscendQwen2VisionAttention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
args, kwargs = mocker_attn.call_args
assert args == (embed_dim, num_heads, projection_size, None, "")
assert not kwargs
attention.num_attention_heads_per_partition = num_heads
return attention
def test_attn_init_should_normal(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 10
projection_size = 100
quant_config = None
prefix = ""
vit = self.init_attention(
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
mocker=mocker,
)
assert vit.hidden_size_per_attention_head == 10
def test_attn_init_should_raise_error(self, mocker: MockerFixture):
embed_dim = 1000
num_heads = 7
projection_size = 100
quant_config = None
prefix = ""
with pytest.raises(AssertionError):
# projection_size should divided by num heads
self.init_attention(
mocker=mocker,
embed_dim=embed_dim,
num_heads=num_heads,
projection_size=projection_size,
quant_config=quant_config,
prefix=prefix,
)
def test_attn_forward(self, mocker: MockerFixture):
attention = self.init_attention(mocker=mocker)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
x = torch.rand((100, 3, 10 * 3 * 128)) # s,b, head*3*head_dim
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
qkv = lambda x: (x, 0) # noqa
split_qkv = lambda x: [ #noqa
torch.rand((100, 3, 10, 128)) for i in range(3)
] # noqa
npu_rotary_mul = lambda q, cos, sin: q # noqa
_npu_flash_attention_unpad = lambda **kwargs: kwargs["out"] # noqa
proj = lambda x: (x, 0) # noqa
mocker_qkv = mocker.patch.object(attention, "qkv", side_effect=qkv)
mocker_split_qkv = mocker.patch.object(
attention,
"split_qkv",
side_effect=split_qkv,
)
mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
side_effect=npu_rotary_mul)
mocker_npu_flash_attention_unpad = mocker.patch(
"torch_npu._npu_flash_attention_unpad",
side_effect=_npu_flash_attention_unpad,
)
mocker_proj = mocker.patch.object(attention, "proj", side_effect=proj)
attention.__dict__["qkv"] = mocker_qkv
attention.__dict__["split_qkv"] = mocker_split_qkv
attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
attention.__dict__["_npu_flash_attention_unpad"] = (
mocker_npu_flash_attention_unpad)
attention.__dict__["proj"] = mocker_proj
output = attention.forward(
x=x,
cu_seqlens=cu_seqlens,
cos=cos,
sin=sin,
)
qkv_args, qkv_kwargs = mocker_qkv.call_args
assert qkv_args == (x, )
assert not qkv_kwargs
split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
assert split_qkv_args == (x, )
assert not split_qkv_kwargs
npu_rotary_mul_args, npu_rotary_mul_kwargs = mocker_npu_rotary_mul.call_args
assert npu_rotary_mul_args[1:] == (cos, sin)
assert npu_rotary_mul_args[0].shape == torch.Size([3, 100, 10, 128])
assert not npu_rotary_mul_kwargs
assert output.shape == torch.Size([100, 3, 1280])
class TestAscendQwen2VisionBlock(PytestBase):
def init_vision_block(
self,
mocker,
dim=100,
num_heads=10,
mlp_ratio=0.5,
):
mocker_vit = mocker.patch(
"vllm.model_executor.models.qwen2_vl.Qwen2VisionBlock.__init__",
return_value=None,
)
mocker_attn = mocker.patch(
"vllm_ascend.models.qwen2_vl.AscendQwen2VisionAttention.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
vision_block = AscendQwen2VisionBlock(
dim=dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
)
args, kwargs = mocker_vit.call_args
assert args == (dim, num_heads, mlp_ratio, QuickGELU, None, None, "")
assert not kwargs
args1, kwargs1 = mocker_attn.call_args
assert not args1
assert kwargs1 == {
"embed_dim": dim,
"num_heads": num_heads,
"projection_size": dim,
"quant_config": None,
"prefix": ".attn",
}
return vision_block
def test_init_vision_block_should_normal(
self,
mocker: MockerFixture,
):
vision_block = self.init_vision_block(mocker)
assert isinstance(vision_block, AscendQwen2VisionBlock)
def test_vision_block_forward(self, mocker: MockerFixture):
x = torch.randint(1, 100, (100, 3, 1280)) # s,b,d
cu_seqlens = torch.tensor([10, 50, 100])
cos = torch.rand((1, 100, 1, 128))
sin = torch.rand((1, 100, 1, 128))
vision_block = self.init_vision_block(mocker)
mocker_attn = mocker.patch.object(vision_block, "attn", return_value=x)
mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
vision_block.__dict__["attn"] = mocker_attn
vision_block.__dict__["mlp"] = mocker_mlp
output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)
_, attn_kwargs = mocker_attn.call_args
assert attn_kwargs == {
"cu_seqlens": cu_seqlens,
"cos": cos,
"sin": sin,
}
assert torch.all(x * 3 == output)

View File

@@ -0,0 +1,98 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
import unittest
import pytest
import torch
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
class TestCustomQwen3MoeForCausalLM:
def test_class_inheritance(self):
assert issubclass(CustomQwen3MoeForCausalLM, Qwen3MoeForCausalLM)
@pytest.mark.parametrize("key, expected", [
("qkv_proj", ["q_proj", "k_proj", "v_proj"]),
("gate_up_proj", ["gate_proj", "up_proj"]),
("experts",
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]),
])
def test_packed_modules_mapping(self, key, expected):
assert CustomQwen3MoeForCausalLM.packed_modules_mapping[
key] == expected
def test_packed_modules_mapping_structure(self):
expected_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"],
"experts": [
"experts.0.gate_proj", "experts.0.up_proj",
"experts.0.down_proj"
]
}
assert CustomQwen3MoeForCausalLM.packed_modules_mapping == expected_mapping
class DummyRMSNorm:
def __init__(self, dim: int, eps: float = 1e-6):
self.dim = dim
self.eps = eps
def __call__(self, x):
mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
denom = (mean_sq + self.eps).sqrt()
return x / denom
class TestCustomQwen3MoeAttention(unittest.TestCase):
def setUp(self):
self.batch = 2
self.seq_len = 3
self.q_size = 8
self.kv_size = 8
self.head_dim = 4
self.rms_eps = 1e-6
total_dim = self.q_size + 2 * self.kv_size
self.qkv = torch.arange(self.batch * self.seq_len * total_dim,
dtype=torch.float32).reshape(
self.batch, self.seq_len, total_dim)
def test_constant_input_normalization(self):
ones_qkv = torch.ones((1, 1, self.q_size + 2 * self.kv_size),
dtype=torch.float32)
q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
q, k, v = CustomQwen3MoeAttention.normalize_qkv(
ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm, k_norm)
norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
expected_q = torch.full((1, 1, self.q_size), norm_val)
expected_k = torch.full((1, 1, self.kv_size), norm_val)
expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)
self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
self.assertTrue(torch.equal(v, expected_v))

View File

@@ -0,0 +1,32 @@
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
MSEventKey)
class Testbase(TestBase):
def test_ms_event_key(self):
self.assertEqual(MSEventKey.ATTN_COM_FINISH.value, 0)
self.assertEqual(MSEventKey.ATTN_AR_FINISH.value, 1)
self.assertEqual(MSEventKey.FFN_COM_FINISH.value, 2)
self.assertEqual(MSEventKey.FFN_AR_FINISH.value, 3)
self.assertEqual(MSEventKey.MOE_BEFORE_COMM.value, 4)
self.assertEqual(MSEventKey.MOE_AFTER_COMM.value, 5)
self.assertEqual(MSEventKey.MOE_SE_COMM_FINISH.value, 6)
self.assertEqual(MSEventKey.MOE_SE_COMP_FINISH.value, 7)
self.assertEqual(MSEventKey.MOE_GATE_FINISH.value, 8)
def test_ms_attention_metadata_split_config_default(self):
config = MSAttentionMetadataSplitConfig()
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
def test_ms_attention_metadata_split_config_custom(self):
config = MSAttentionMetadataSplitConfig(
num_micro_batches=4,
min_total_tokens_to_split=512,
min_prefill_tokens_to_split=128)
self.assertEqual(config.num_micro_batches, 4)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)

View File

@@ -0,0 +1,47 @@
import pytest
from pytest_mock import MockFixture
from tests.ut.base import PytestBase
from vllm_ascend.multistream.decorator import set_multistream_support
class Context:
def __init__(self, attn_metadata=None):
self.attn_metadata = attn_metadata
class TestDecorator(PytestBase):
@pytest.mark.parametrize(
'layer_context, microbatch_context, expected_metadata', [
((-1, None, None), -1, {
"original": True
}),
((-1, None, None), 0, {
"original": True
}),
((0, None, None), -1, {
"original": True
}),
((0, None, [{
"new": True
}]), 0, {
"new": True
}),
])
def test_decorator(self, mocker: MockFixture, layer_context,
microbatch_context, expected_metadata):
def context_func():
return Context(attn_metadata={"original": True})
mocker.patch(
'vllm_ascend.multistream.decorator.get_multistream_layer_context',
return_value=layer_context)
mocker.patch(
'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
return_value=microbatch_context)
context = set_multistream_support()(context_func)()
assert context.attn_metadata == expected_metadata

View File

@@ -0,0 +1,198 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import PytestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
MultiStreamPreTransformerLayer)
from vllm_ascend.multistream.metadata import MultiStreamMetadata
# === fixture: mock tensor input ===
@pytest.fixture
def input_tensors():
return [torch.randn(2, 128), torch.randn(2, 128)]
# === mock get_forward_context ===
class DummyContext:
def __init__(self, attn_metadata):
self.attn_metadata = attn_metadata
class TestMultiStreamPreTransformerLayer(PytestBase):
# === test when multistream_metadata is None ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx,
input_tensors):
mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta")
layer = MultiStreamPreTransformerLayer(multistream_metadata=None)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == "dummy_meta"
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when attn_metadata is None ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx,
input_tensors):
mock_get_ctx.return_value = DummyContext(attn_metadata=None)
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out is None
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when do_ms=False (no split needed) ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
dummy_attn = "original_attn"
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.split_micro_batch.return_value = (False, "same_attn",
input_tensors, None)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == "same_attn"
assert input_out == input_tensors
mock_set_ctx.assert_called_once_with(-1, None, None)
# === test when do_ms=True (split occurred) ===
@patch("vllm_ascend.multistream.layers.get_forward_context")
@patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
dummy_attn = "original_attn"
mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)
split_inputs = [[t[:1], t[1:]] for t in input_tensors]
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.start_layer = 2
dummy_metadata.split_micro_batch.return_value = (True,
["attn1", "attn2"],
split_inputs, None)
layer = MultiStreamPreTransformerLayer(
multistream_metadata=dummy_metadata)
attn_out, input_out = layer.forward(input_tensors)
assert attn_out == ["attn1", "attn2"]
assert input_out == split_inputs
mock_set_ctx.assert_called_once_with(2, dummy_metadata,
["attn1", "attn2"])
class TestMultiStreamPostTransformerLayer(PytestBase):
def test_post_forward_metadata_none(self, input_tensors):
layer = MultiStreamPostTransformerLayer(multistream_metadata=None)
output = layer.forward(input_tensors)
assert output == input_tensors
dummy_metadata = MagicMock(spec=MultiStreamMetadata)
dummy_metadata.ms_config = None
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors)
assert output == input_tensors
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx,
input_tensors):
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
calculate_stream=MagicMock(),
communicate_stream=MagicMock(),
start_layer=0,
end_layer=1,
event_keys=[],
multistream_config=None,
)
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
dummy_metadata.ms_config.num_micro_batches = 4
dummy_metadata.end_layer = 10
mock_get_ctx.return_value = (
5, # layer_index
dummy_metadata, # ms_metadata
"dummy_attn_metadata" # ms_attn_metadata
)
dummy_metadata.merge_micro_batches.return_value = "merged_result"
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors)
# check wait_event
dummy_metadata.try_wait_event.assert_called_once_with(
9, # end_layer - 1
3, # num_micro_batches - 1
MSEventKey.FFN_AR_FINISH)
mock_reset_ctx.assert_called_once()
assert output == "merged_result"
@patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
@patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx,
mock_get_ctx, input_tensors):
A_instance_of_MultiStreamMetadata = MultiStreamMetadata(
calculate_stream=MagicMock(),
communicate_stream=MagicMock(),
start_layer=0,
end_layer=1,
event_keys=[],
multistream_config=None,
)
dummy_metadata = MagicMock(spec=A_instance_of_MultiStreamMetadata)
dummy_metadata.ms_config.num_micro_batches = 4
dummy_metadata.end_layer = 10
mock_get_ctx.return_value = (
3, # layer_index
dummy_metadata,
"dummy_attn_metadata")
dummy_metadata.merge_micro_batches.return_value = "merged_result"
layer = MultiStreamPostTransformerLayer(
multistream_metadata=dummy_metadata)
output = layer.forward(input_tensors, wait_layer_index=7)
dummy_metadata.try_wait_event.assert_called_once_with(
7, 3, MSEventKey.FFN_AR_FINISH)
mock_reset_ctx.assert_called_once()
assert output == "merged_result"

View File

@@ -0,0 +1,246 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
MultiStreamMetadata,
MultiStreamStepMetadata,
split_micro_batches_tensors)
class TestMetaData(TestBase):
def setUp(self):
self.test_tensors_list = [torch.randn(100, 1024) for i in range(3)]
self.test_tensors = torch.randn(100, 1024)
self.test_tensors_dict = {
'query': torch.randn(100, 1024),
'key': torch.randn(100, 1024),
'value': torch.randn(100, 1024)
}
self.split_index = 50
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
self.metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
def test_split_micro_batches_tensors(self):
test_tensors_list_res = split_micro_batches_tensors(
self.test_tensors_list, self.split_index)
test_tensors_res = split_micro_batches_tensors(self.test_tensors,
self.split_index)
keys = ['query', 'key', 'value']
test_tensors_dict_res = split_micro_batches_tensors(
self.test_tensors_dict, self.split_index, keys)
for i in range(3):
self.assertEqual(len(test_tensors_list_res[i][0]),
self.split_index)
self.assertEqual(
len(test_tensors_list_res[i][0]) +
len(test_tensors_list_res[i][1]), 100)
self.assertEqual(len(test_tensors_res[0]), self.split_index)
self.assertEqual(
len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)
for key in keys:
self.assertEqual(len(test_tensors_dict_res[0][key]),
self.split_index)
self.assertEqual(
len(test_tensors_dict_res[0][key]) +
len(test_tensors_dict_res[1][key]), 100)
def test_default_init_multistream_step_metadata(self):
metadata = MultiStreamStepMetadata()
self.assertIsNone(metadata.comm_stream)
self.assertIsNone(metadata.before_comm_event)
self.assertIsNone(metadata.after_comm_event)
def test_custom_init_multistream_step_metadata(self):
mockStream = MagicMock(spec=torch.npu.Stream)
mockEvent1 = MagicMock(spec=torch.npu.Event)
mockEvent2 = MagicMock(spec=torch.npu.Event)
metadata = MultiStreamStepMetadata(mockStream, mockEvent1, mockEvent2)
self.assertEqual(metadata.comm_stream, mockStream)
self.assertEqual(metadata.before_comm_event, mockEvent1)
self.assertEqual(metadata.after_comm_event, mockEvent2)
def test_default_init_multistream_config(self):
config = MultiStreamConfig()
self.assertEqual(config.min_total_tokens_to_split, 256)
self.assertEqual(config.min_prefill_tokens_to_split, 64)
self.assertEqual(config.num_micro_batches, 2)
self.assertEqual(config.imbalance_ratio, 0.1)
def test_custom_init_multistream_config(self):
config = MultiStreamConfig(512, 128, 1, 0.2)
self.assertEqual(config.min_total_tokens_to_split, 512)
self.assertEqual(config.min_prefill_tokens_to_split, 128)
self.assertEqual(config.num_micro_batches, 1)
self.assertEqual(config.imbalance_ratio, 0.2)
def test_init_multistream_metadata(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock()]
multistream_config = MagicMock(spec=MultiStreamConfig)
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertEqual(metadata.calculate_stream, mock_stream)
self.assertEqual(metadata.communicate_stream, mock_stream)
self.assertEqual(metadata.start_layer, 1)
self.assertEqual(metadata.end_layer, 3)
self.assertEqual(metadata.ms_config, multistream_config)
self.assertTrue(metadata.causal_lm)
def test_build_events(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
with patch('torch.npu.Event', return_value=mock_event):
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MultiStreamConfig(
num_micro_batches=2,
min_total_tokens_to_split=256,
min_prefill_tokens_to_split=64)
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
expected_events = {
0: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
1: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
},
2: {
0: {
event_keys[0]: mock_event
},
1: {
event_keys[0]: mock_event
}
}
}
self.assertEqual(metadata.ms_events, expected_events)
def test_build_ms_split_config(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
multistream_config.num_micro_batches = 2
multistream_config.min_total_tokens_to_split = 256
multistream_config.min_prefill_tokens_to_split = 64
metadata = MultiStreamMetadata(calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
self.assertIsNotNone(metadata.ms_split_config)
self.assertEqual(metadata.ms_split_config.num_micro_batches,
multistream_config.num_micro_batches)
self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
multistream_config.min_total_tokens_to_split)
self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
multistream_config.min_prefill_tokens_to_split)
def test_try_wait_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_wait_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.wait.assert_called_once()
def test_try_record_event(self):
mock_stream = MagicMock(spec=torch.npu.Stream)
mock_event = MagicMock(spec=torch.npu.Event)
event_keys = [MagicMock(spec=MSEventKey)]
multistream_config = MagicMock(spec=MultiStreamConfig)
with patch('torch.npu.Event', return_value=mock_event):
metadata = MultiStreamMetadata(
calculate_stream=mock_stream,
communicate_stream=mock_stream,
start_layer=1,
end_layer=3,
event_keys=event_keys,
multistream_config=multistream_config)
metadata.try_record_event(layer_index=1,
micro_batch_index=0,
event_key=event_keys[0])
mock_event.record.assert_called_once()
def test_merge_batches_none_input(self):
input_tensors = None
result = self.metadata.merge_micro_batches(input_tensors)
self.assertIsNone(result)
def test_merge_batches_single_tensor_input(self):
input_tensors = [torch.tensor([1, 2, 3])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 1)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))
def test_merge_batches_list_of_tensors_input(self):
input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertEqual(result, input_tensors)
def test_merge_batches_nested_list_input(self):
input_tensors = [[torch.tensor([1, 2]),
torch.tensor([3, 4])],
[torch.tensor([5, 6]),
torch.tensor([7, 8])]]
result = self.metadata.merge_micro_batches(input_tensors)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))

View File

@@ -0,0 +1,147 @@
from unittest.mock import MagicMock
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
model_input_split_v1_mla_attn,
split_attn_int_type,
split_attn_tensor_type)
class TestMsSplit(TestBase):
def test_decode_only(self):
result = compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.DecodeOnly,
num_tokens=10)
self.assertEqual(result, [5, 5])
def test_perfect_balance(self):
query_lens = [2, 3, 5]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [5, 2])
def test_imbalance(self):
query_lens = [1, 2, 3, 4]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_query_lens_none(self):
with self.assertRaises(AssertionError):
compute_split_seq_index(
query_lens=None,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
def test_empty_query_lens(self):
query_lens: list[int] = []
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_single_query_len(self):
query_lens = [10]
result = compute_split_seq_index(
query_lens=query_lens,
attn_state=AscendAttentionState.PrefillNoCache,
num_tokens=10)
self.assertEqual(result, [0, 0])
def test_split_attn_tensor_type_middle(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 3
expected_result = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_start(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 0
expected_result = [torch.tensor([]), torch.tensor([1, 2, 3, 4, 5])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_end(self):
input_tensor = torch.tensor([1, 2, 3, 4, 5])
index = 5
expected_result = [torch.tensor([1, 2, 3, 4, 5]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_tensor_type_empty_tensor(self):
input_tensor = torch.tensor([])
index = 0
expected_result = [torch.tensor([]), torch.tensor([])]
result = split_attn_tensor_type(input_tensor, index)
self.assertEqual(len(result), 2)
self.assertTrue(torch.equal(result[0], expected_result[0]))
self.assertTrue(torch.equal(result[1], expected_result[1]))
def test_split_attn_int_type_index_greater_than_var(self):
var = 5
index = 10
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_equal_to_var(self):
var = 5
index = 5
expected_result = [5, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_less_than_var(self):
var = 10
index = 5
expected_result = [5, 5]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_index_zero(self):
var = 10
index = 0
expected_result = [0, 10]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_var_zero(self):
var = 0
index = 5
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_attn_int_type_both_zero(self):
var = 0
index = 0
expected_result = [0, 0]
result = split_attn_int_type(var, index)
self.assertEqual(result, expected_result)
def test_split_v1_mla_attn_input_none(self):
attn_metadata = None
ascendMLAPrefillMetadata = MagicMock()
ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
result = model_input_split_v1_mla_attn(attn_metadata,
ascendMLAPrefillMetadata,
ms_split_config)
self.assertEqual(result, [None])

View File

@@ -0,0 +1,17 @@
{
"moe_layer_count":
1,
"layer_list": [{
"layer_id":
0,
"device_count":
2,
"device_list": [{
"device_id": 0,
"device_expert": [7, 2, 0, 3, 5]
}, {
"device_id": 1,
"device_expert": [6, 1, 4, 7, 2]
}]
}]
}

View File

@@ -0,0 +1,61 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
@patch("torch_npu.npu_fast_gelu", side_effect=lambda x: x + 1)
def test_QuickGELU_forward(mock_gelu, dummy_tensor):
layer = QuickGELU()
out = layer.forward(dummy_tensor)
expected_out = dummy_tensor + 1
assert torch.allclose(out, expected_out)
mock_gelu.assert_called_once()
@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = SiluAndMul()
out = layer.forward(dummy_tensor)
if is_310p_return:
expected_arg = dummy_tensor.to(torch.float32)
else:
expected_arg = dummy_tensor
# assert mock_swiglu.call_count == 1
mock_swiglu.assert_called_once()
actual_arg = mock_swiglu.call_args[0][0]
assert torch.allclose(
actual_arg,
expected_arg), "npu_swiglu called with unexpected input"
expected_out = dummy_tensor + 1
assert torch.allclose(out, expected_out)

View File

@@ -0,0 +1,69 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
class TestFusedExpertsMoGE(TestBase):
def test_fused_experts_moge(self):
with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
patch('torch_npu.npu_swiglu') as mock_swiglu, \
patch('vllm_ascend.utils.is_310p') as mock_is_310p:
mock_is_310p.return_value = False
mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
torch.randn(x[0].shape[0], weight[0].shape[1])
]
mock_swiglu.side_effect = lambda x: x
hidden_states = torch.randn(4, 128)
w1 = torch.randn(4, 256, 128)
w2 = torch.randn(4, 128, 128)
topk_weights = torch.rand(4, 1)
topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
top_k = 1
global_num_experts = 4
moe_parallel_config = type(
'MockConfig', (), {
'ep_size': 1,
'tp_size': 1,
'dp_size': 1,
'tp_rank': 0,
'dp_rank': 0,
'ep_rank': 0,
'use_ep': True
})()
output = fused_experts_moge(
hidden_states=hidden_states,
w1=w1,
w2=w2,
moe_parallel_config=moe_parallel_config,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
apply_router_weight_on_input=True,
)
self.assertEqual(output.shape, (4, 128))

View File

@@ -0,0 +1,141 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import json
import os
from typing import List, TypedDict
from unittest import mock
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class TestExpertLoadBalancer(TestBase):
def setUp(self):
_TEST_DIR = os.path.dirname(__file__)
json_file = _TEST_DIR + "/expert_map.json"
with open(json_file, 'r') as f:
self.expert_map: MockData = json.load(f)
self.expert_load_balancer = ExpertLoadBalancer(json_file,
global_expert_num=8)
def test_init(self):
self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
torch.Tensor)
self.assertEqual(self.expert_load_balancer.layers_num,
self.expert_map["moe_layer_count"])
self.assertEqual(self.expert_load_balancer.ranks_num,
self.expert_map["layer_list"][0]["device_count"])
def test_generate_index_dicts(self):
tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
expected_result = [{
7: 0,
2: 1,
0: 2,
3: 3,
5: 4
}, {
6: 5,
1: 6,
4: 7,
7: 8,
2: 9
}]
self.assertEqual(result, expected_result)
def test_generate_expert_placement_map(self):
expert_placement_map = self.expert_load_balancer.generate_expert_placement_map(
)
self.assertEqual(expert_placement_map.shape,
(self.expert_load_balancer.layers_num,
self.expert_load_balancer.ranks_num, 8))
self.assertTrue(torch.all(expert_placement_map >= -1))
def test_generate_log2phy_expert_map(self):
layer_id = 0
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
layer_id)
self.assertEqual(log2phy_map.shape,
(self.expert_load_balancer.ranks_num, 8))
self.assertTrue(torch.all(log2phy_map >= -1))
@mock.patch("torch_npu.npu._lazy_init")
@mock.patch("torch.npu.current_device", return_value="cpu")
def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
layer_id = 0
rank_id = 0
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
self.assertEqual(rank_local_expert_num, 5)
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
rank_id = 1
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
layer_id, rank_id)
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
dtype=torch.int32).to(
rank_expert_map.device)
self.assertTrue(rank_expert_map.equal(expected_tensor))
def test_get_rank_log2phy_map(self):
layer_id = 0
rank_id = 0
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))
rank_id = 1
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
layer_id, rank_id)
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
dtype=torch.int32).to(
log2phy_map.device)
self.assertTrue(log2phy_map.equal(expected_tensor))
def test_get_global_redundant_expert_num(self):
redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num(
)
expected_redundant_expert_num = len(self.expert_map["layer_list"][0]["device_list"][0]["device_expert"]) * \
self.expert_map["layer_list"][0]["device_count"] - 8
self.assertEqual(redundant_expert_num, expected_redundant_expert_num)

View File

@@ -0,0 +1,741 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import (FusedMoEState,
_get_fused_moe_state)
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch
adapt_patch(True)
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
mock_group.world_size = 4
mock_group.device_group = "mock_group_ep"
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
return mock_group
def mock_dp_and_tp_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
return mock_group
def mock_npu_format_cast(weight_data, format):
return weight_data
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
mock_setup_token_dispatchers = MagicMock()
mock_token_dispatcher_with_allgather = MagicMock()
mock_token_dispatcher_with_all2allv = MagicMock()
mock_token_dispatcher_with_mc2 = MagicMock()
mock_dispatch_result_allgather = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([8, 16], dtype=torch.int64),
"group_list_type": 0,
}
mock_combine_result_allgather = torch.randn(16, 2)
mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather
mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather
mock_dispatch_result_all2allv = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
}
mock_combine_result_all2allv = torch.randn(16, 2)
mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv
mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv
mock_dispatch_result_mc2 = {
"hidden_states": torch.randn(16, 2),
"group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
"group_list_type": 1,
"dynamic_scale": None,
"assist_info_for_combine": torch.randn(16, 2),
"ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
}
mock_combine_result_mc2 = torch.randn(16, 2)
mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2
captured_dispatchers = {}
def capture_register(dispatcher_instance):
key = dispatcher_instance.__class__.__name__
captured_dispatchers[key] = dispatcher_instance
if key == 'TokenDispatcherWithAllGather':
captured_dispatchers[key] = mock_token_dispatcher_with_allgather
elif key == 'TokenDispatcherWithAll2AllV':
captured_dispatchers[key] = mock_token_dispatcher_with_all2allv
elif key == 'TokenDispatcherWithMC2':
captured_dispatchers[key] = mock_token_dispatcher_with_mc2
mock_register_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher',
side_effect=capture_register)
mock_get_token_dispatcher_patcher = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher',
side_effect=lambda name: captured_dispatchers.get(name))
default_mock_token_dispatcher = mock_token_dispatcher_with_allgather
mock_forward_context_obj = MagicMock(
fused_moe_state=FusedMoEState.AllGather,
token_dispatcher=default_mock_token_dispatcher,
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
mc2_mask=torch.zeros(16, dtype=torch.bool),
padded_num_tokens=16,
with_quant=False)
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather'), \
patch('torch.distributed.all_to_all_single'), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
expert_map_path=None
)), \
patch('vllm_ascend.ops.fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.ops.fused_moe.get_forward_context',
return_value=mock_forward_context_obj), \
patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
return_value=mock_forward_context_obj):
yield {
'mock_forward_context_obj': mock_forward_context_obj,
'mock_token_dispatcher_with_allgather':
mock_token_dispatcher_with_allgather,
'mock_token_dispatcher_with_all2allv':
mock_token_dispatcher_with_all2allv,
'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
}
mock_register_token_dispatcher_patcher.stop()
mock_get_token_dispatcher_patcher.stop()
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
None
)), \
patch('torch_npu.npu_moe_init_routing', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
torch.randn(8, 2)
)), \
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_distribute_combine", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
def default_moe_config():
return {
'num_experts': 8,
'top_k': 2,
'hidden_size': 512,
'intermediate_size': 1024
}
@pytest.fixture
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
return AscendUnquantizedFusedMoEMethod(moe)
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class MockQuantMethod(nn.Module):
def __init__(self, shared_experts, num_tokens):
super().__init__()
if shared_experts:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
torch.randn(num_tokens, 10)))
else:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
class TestAscendFusedMoe:
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = AscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = AscendFusedMoE(**error_config)
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = AscendFusedMoE(**error_config)
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
moe = AscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert moe.quant_method == mock_quant_method
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = AscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = AscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
patch('vllm_ascend.utils.is_310p', return_value=False):
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
expected_shape = (16, 2)
assert result.shape == expected_shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
ep_size = others_param
is_prefill = False
if ep_size == 1:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_allgather']
elif ep_size < 16:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_all2allv']
else:
selected_token_dispatcher = mock_dist_env[
'mock_token_dispatcher_with_mc2']
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, True),
with_quant=False,
token_dispatcher=selected_token_dispatcher)
with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
local_num_experts = 2
hidden_size = 2
intermediate_size_per_partition = 4
layer.w13_weight = torch.randn(local_num_experts,
intermediate_size_per_partition * 2,
hidden_size)
layer.w2_weight = torch.randn(local_num_experts, hidden_size,
intermediate_size_per_partition)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
expected_shape = (16, 2)
assert result.shape == expected_shape
class TestExpertsSelector:
@pytest.mark.parametrize("global_num_experts", [[256], [128]])
def test_select_experts(self, mock_dist_env, mock_moe_env,
global_num_experts):
x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids, _ = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
use_grouped_topk=False,
renormalize=True,
topk_group=None,
num_expert_group=None,
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
global_num_experts=global_num_experts)
assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)
class TestUnifiedApplyMLP(TestBase):
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_dynamic_quant')
@patch('torch_npu.npu_dequant_swiglu_quant')
def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
mock_npu_dynamic_quant,
mock_npu_grouped_matmul,
mock_is_310p,
mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.fused_moe_state = FusedMoEState.MC2
mock_get_forward_context.return_value = mock_forward_context
mock_is_310p.return_value = False
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 20),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
mock_npu_grouped_matmul.side_effect = [[
torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_dequant.return_value = (torch.randn(10,
40,
dtype=torch.bfloat16),
torch.randn(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
w1_scale = torch.randn(5, 40, dtype=torch.float32)
w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_forward_context.fused_moe_state,
FusedMoEState.MC2)
mock_npu_dynamic_quant.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_dequant.assert_called_once()
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization(self,
mock_npu_dynamic_quant,
mock_npu_swiglu,
mock_npu_grouped_matmul,
mock_is_310p):
mock_is_310p.return_value = False
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.float16)
], [torch.randn(10, 20, dtype=torch.float16)]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)
@patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_get_forward_context):
mock_forward_context = MagicMock()
mock_forward_context.with_quant = True
mock_forward_context.fused_moe_state = "NOT_MC2"
mock_get_forward_context.return_value = mock_forward_context
mock_npu_grouped_matmul.side_effect = [[
torch.randn(10, 40, dtype=torch.bfloat16)
], [torch.randn(10, 20, dtype=torch.bfloat16)]]
mock_npu_swiglu.return_value = torch.randn(10,
40,
dtype=torch.bfloat16)
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
127, (10, 40),
dtype=torch.int8),
torch.rand(10,
1,
dtype=torch.float32))
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=provided_dynamic_scale,
group_list_type=1,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=None,
with_quant=True)
mock_get_forward_context.assert_called()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
mock_npu_dynamic_quant.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.bfloat16)
@patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_dynamic_quant')
def test_unified_apply_mlp_without_quantization_310p(
self, mock_npu_dynamic_quant, mock_npu_swiglu,
mock_npu_grouped_matmul, mock_is_310p):
mock_is_310p.return_value = True
mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
[mock_gmm2_out]]
mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())
hidden_states = torch.randn(10, 20, dtype=torch.float16)
w1 = torch.randn(5, 20, 40, dtype=torch.float16)
w2 = torch.randn(5, 40, 20, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
topk_scales = torch.randn(10, 1, dtype=torch.float16)
result = unified_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=None,
w2=w2,
w2_scale=None,
group_list=group_list,
dynamic_scale=None,
group_list_type=1,
w1_scale_bias=None,
w2_scale_bias=None,
topk_scales=topk_scales,
with_quant=False)
mock_is_310p.assert_called_once()
self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
mock_npu_swiglu.assert_called_once()
self.assertEqual(result.shape, hidden_states.shape)
self.assertEqual(result.dtype, torch.float16)

View File

@@ -0,0 +1,53 @@
from unittest.mock import patch
import pytest
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
@pytest.fixture
def dummy_tensor():
return torch.randn(4, 8, dtype=torch.float16)
def mock_rms_norm(x, weight, eps):
return x + 1, None
def mock_add_rms_norm(x, residual, weight, eps):
return 2 * x, None, 2 * residual
@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
[None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
residual, dummy_tensor):
with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
layer = RMSNorm(hidden_size=32, eps=1e-05)
if residual is not None:
out_x, out_residual = layer.forward_oot(dummy_tensor, residual)
if is_310p_return:
expected_arg_x = dummy_tensor + residual.to(dummy_tensor.dtype)
expected_out_x = expected_arg_x + 1
expected_out_residual = expected_arg_x.to(residual.dtype)
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
expected_out_x = 2 * dummy_tensor
expected_out_residual = 2 * residual
mock_add_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)
assert torch.allclose(out_residual, expected_out_residual)
else:
out_x = layer.forward(dummy_tensor, residual)
expected_out_x = dummy_tensor + 1
mock_rmsnorm.assert_called_once()
assert torch.allclose(out_x, expected_out_x)

363
tests/ut/ops/test_linear.py Normal file
View File

@@ -0,0 +1,363 @@
import os
import unittest
from unittest import mock
import torch
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
AscendMlpMergedColumnParallelLinear,
AscendMlpRowParallelLinear, LinearBase,
QuantizationConfig)
class TestAscendMlpRowParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
self.tensor_parallel_world_size = 2
self.tensor_parallel_rank = 0
self.mlp_tensor_parallel_world_size = 2
self.mlp_tensor_parallel_rank = 1
self.get_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
return_value=self.tensor_parallel_world_size)
self.get_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
return_value=self.tensor_parallel_rank)
self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
return_value=self.mlp_tensor_parallel_world_size)
self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
return_value=self.mlp_tensor_parallel_rank)
self.get_tensor_model_parallel_world_size_mock = \
self.get_tensor_model_parallel_world_size_patch.start()
self.get_tensor_model_parallel_rank_mock = \
self.get_tensor_model_parallel_rank_patch.start()
self.get_mlp_tensor_model_parallel_world_size_mock = \
self.get_mlp_tensor_model_parallel_world_size_patch.start()
self.get_mlp_tensor_model_parallel_rank_mock = \
self.get_mlp_tensor_model_parallel_rank_patch.start()
self.split_tensor_along_last_dim_patch = mock.patch(
'vllm_ascend.ops.linear.split_tensor_along_last_dim',
return_value=(torch.randn(10, 8), torch.randn(10, 8)))
self.tensor_model_parallel_all_reduce_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_reduce_mock = \
self.tensor_model_parallel_all_reduce_patch.start()
self.split_tensor_along_last_dim_mock = \
self.split_tensor_along_last_dim_patch.start()
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
mock.MagicMock()
def tearDown(self):
self.get_tensor_model_parallel_world_size_patch.stop()
self.get_tensor_model_parallel_rank_patch.stop()
self.get_mlp_tensor_model_parallel_world_size_patch.stop()
self.get_mlp_tensor_model_parallel_rank_patch.stop()
self.split_tensor_along_last_dim_patch.stop()
self.tensor_model_parallel_all_reduce_patch.stop()
self.get_mlp_tp_group_patch.stop()
def test_init_with_down_proj_prefix(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
prefix="down_proj")
self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
self.assertTrue(layer.enable_mlp_optimze)
def test_forward_with_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
input_size=16,
output_size=8,
prefix="down_proj",
input_is_parallel=False,
)
input_tensor = torch.randn(16, 8) # (batch_size, input_size)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
def test_forward_without_mlp_optimize(self):
layer = AscendMlpRowParallelLinear(
input_size=16,
output_size=8,
prefix="other",
input_is_parallel=False,
)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once_with(
input_tensor, num_partitions=layer.tp_size)
self.tensor_model_parallel_all_reduce_mock.assert_called_once()
def test_skip_bias_add(self):
layer = AscendMlpRowParallelLinear(
input_size=16,
output_size=8,
skip_bias_add=True,
)
input_tensor = torch.randn(16, 8)
output, bias = layer(input_tensor)
self.assertIsNotNone(bias)
def test_no_reduce_results(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.tensor_model_parallel_all_reduce_mock.assert_not_called()
def test_input_not_parallel(self):
layer = AscendMlpRowParallelLinear(input_size=16,
output_size=8,
input_is_parallel=False)
input_tensor = torch.randn(16, 8)
layer(input_tensor)
self.split_tensor_along_last_dim_mock.assert_called_once()
def test_exception_when_reduce_false_and_bias(self):
with self.assertRaises(ValueError):
AscendMlpRowParallelLinear(input_size=16,
output_size=8,
reduce_results=False,
bias=True,
skip_bias_add=False)
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock distributed functions
self.mlp_tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
self.mlp_tp_size_mock = self.mlp_tp_size_patch.start()
self.mlp_tp_size_mock.return_value = 2 # Simulate 2 GPUs in MLP TP group
self.mlp_tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
self.mlp_tp_rank_mock = self.mlp_tp_rank_patch.start()
self.mlp_tp_rank_mock.return_value = 0 # Current GPU rank
self.tp_size_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
self.tp_size_mock = self.tp_size_patch.start()
self.tp_size_mock.return_value = 4 # Simulate 4 GPUs in regular TP group
self.tp_rank_patch = \
mock.patch('vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
self.tp_rank_mock = self.tp_rank_patch.start()
self.tp_rank_mock.return_value = 1 # Current GPU rank
# Mock divide function (assumed to be in your module)
self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
self.divide_mock = self.divide_patch.start()
self.divide_mock.side_effect = lambda x, y: x // y # Simulate division
# Mock QuantizationConfig and QuantMethod
self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
# Mock LinearBase initialization
self.linear_base_init_patch = mock.patch.object(
LinearBase, "__init__", side_effect=self.mock_linear_base_init)
self.linear_base_init_patch.start()
self.quant_method_mock = mock.MagicMock()
def mock_linear_base_init(self, instance, *args, **kwargs):
instance.quant_method = self.quant_method_mock
instance.params_dtype = mock.MagicMock()
instance.input_size = 16
instance.output_size = 8
instance.output_size_per_partition = 4
instance.params_dtype = torch.float32
def tearDown(self):
self.mlp_tp_size_patch.stop()
self.mlp_tp_rank_patch.stop()
self.tp_size_patch.stop()
self.tp_rank_patch.stop()
self.divide_patch.stop()
self.linear_base_init_patch.stop()
def test_mlp_optimize_initialization(self):
# Test when prefix contains "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.gate_up_proj",
bias=False,
)
# Verify MLP optimization flags
self.assertTrue(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 2)
self.assertEqual(layer.tp_rank, 0)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_regular_parallel_initialization(self):
# Test when prefix does NOT contain "gate_up_proj"
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
prefix="model.layers.0.q_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify regular TP flags
self.assertFalse(layer.enable_mlp_optimze)
self.assertEqual(layer.tp_size, 4)
self.assertEqual(layer.tp_rank, 1)
self.assertEqual(layer.input_size_per_partition, 16)
self.assertEqual(layer.output_size_per_partition, 4)
# Check quant_method.create_weights was called
self.quant_method_mock.create_weights.assert_called_once()
def test_output_sizes_handling(self):
# Test when output_sizes is provided
with mock.patch.object(torch.nn.Module, 'register_parameter'):
layer = AscendMlpColumnParallelLinear(
input_size=16,
output_size=8,
output_sizes=[4, 4],
prefix="model.layers.0.qkv_proj",
quant_config=self.quant_config_mock,
bias=False,
)
# Verify output_partition_sizes
self.assertEqual(layer.output_partition_sizes, [2])
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
def setUp(self):
os.environ["VLLM_ASCEND_ENABLE_MLP_OPTIMIZE"] = "1"
# Mock get_mlp_tensor_model_parallel_world_size and get_tensor_model_parallel_world_size
self.mlp_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size", return_value=2)
self.tensor_world_size_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_world_size", return_value=2)
self.mlp_world_size_patch.start()
self.tensor_world_size_patch.start()
# Mock get_mlp_tensor_model_parallel_rank and get_tensor_model_parallel_rank
self.mlp_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank", return_value=0)
self.tensor_rank_patch = \
mock.patch("vllm_ascend.ops.linear.get_tensor_model_parallel_rank", return_value=0)
self.mlp_rank_patch.start()
self.tensor_rank_patch.start()
# Mock all_gather methods
self.get_mlp_tp_group_patch = \
mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
self.get_mlp_tp_group_mock = self.get_mlp_tp_group_patch.start()
self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
self.tensor_model_parallel_all_gather_patch = mock.patch(
'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
return_value=torch.randn(10, 8))
self.tensor_model_parallel_all_gather_mock = \
self.tensor_model_parallel_all_gather_patch.start()
# Mock AscendMlpColumnParallelLinear's __init__
self.linear_init_patch = mock.patch.object(
AscendMlpColumnParallelLinear,
"__init__",
side_effect=self.mock_linear_init)
self.linear_init_patch.start()
# Create mock objects
self.quant_method_mock = mock.MagicMock()
self.apply_output = torch.randn(2, 8)
self.quant_method_mock.apply.return_value = self.apply_output
def mock_linear_init(self, instance, *args, **kwargs):
torch.nn.Module.__init__(instance)
# Set quant_method and other attributes
instance.quant_method = self.quant_method_mock
instance.bias = torch.nn.Parameter(torch.randn(8)) # Example bias
instance.input_size = 16
instance.output_size = 8
instance.gather_output = False
instance.skip_bias_add = False
instance.return_bias = True
def test_forward_with_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix "gate_up_proj" to trigger enable_mlp_optimze = True
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.assertEqual(output.shape, self.apply_output.shape)
def test_forward_without_enable_mlp_optimze(self):
# Setup input
input_tensor = torch.randn(1, 16)
# Create instance with prefix not containing "gate_up_proj"
layer = AscendMlpMergedColumnParallelLinear(input_size=16,
output_sizes=[8],
bias=True,
gather_output=False,
skip_bias_add=False,
params_dtype=torch.float32,
quant_config=None,
prefix="other_proj")
# Call forward
output, bias = layer(input_tensor)
# Validate calls
self.quant_method_mock.apply.assert_called_once_with(
layer, input_tensor, layer.bias)
self.tensor_model_parallel_all_gather_mock.assert_not_called()
self.assertEqual(output.shape, self.apply_output.shape)
def tearDown(self):
self.linear_init_patch.stop()
self.mlp_world_size_patch.stop()
self.tensor_world_size_patch.stop()
self.mlp_rank_patch.stop()
self.tensor_rank_patch.stop()
self.get_mlp_tp_group_mock.stop()
self.tensor_model_parallel_all_gather_mock.stop()

View File

@@ -0,0 +1,318 @@
import math
import unittest
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from tests.ut.base import TestBase
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = _custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=True):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
return_value=False):
result = _custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestAscendRotaryEmbedding(unittest.TestCase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.cos_sin_cache = torch.randn(3, 1, 32)
self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
self.max_position, self.rope_theta,
self.is_neox_style, torch.float16)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = self.is_neox_style
@patch('torch.ops._C')
@patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = self.layer.forward(self.positions, self.query,
self.key)
mock__c.rotary_embedding.assert_called_once()
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = self.layer.forward(self.positions,
non_contig_query,
non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
def test_rope_forward_oot_with_offsets(self):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
self.layer.forward(self.positions, self.query, self.key, offsets)
@patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
# Test neox_style override
result_q, result_k = self.layer.forward(self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 1, 32, dtype=torch.float16)
self.key = torch.randn(3, 1, 32, dtype=torch.float16)
self.head_size = 32
self.rotary_dim = self.head_size
self.max_position = 16
self.rope_theta = 10000
self.is_neox_style = True
self.scaling_factor = 1
self.layer = None
def _create_layer(self):
self.layer = DeepseekScalingRotaryEmbedding(
self.head_size, self.rotary_dim, self.max_position,
self.rope_theta, self.is_neox_style, self.scaling_factor,
torch.float16)
return self.layer
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
with patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
return_value=(self.query,
self.key)) as mock_rope_forward_oot:
q_pe, k_pe = self.layer.forward(self.positions, self.query,
self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_cache_handling(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
self.layer.max_seq_len = 1024
# Test cache situation is true
with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
mock_rope_forward_oot.return_value = (self.query, self.key)
q_pe, k_pe = self.layer.forward(self.positions,
self.query,
self.key,
max_seq_len=2048)
mock_set_cache.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
key = torch.randn(1, 32)
mock_rope_forward_oot.return_value = (self.query, key)
q_pe, k_pe = self.layer.forward(self.positions, self.query, key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == key.shape
@patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_npuplatform, mock_rope_forward_oot):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
mock_rope_forward_oot.return_value = (self.query, self.key)
q_pe, k_pe = self.layer.forward(self.positions, self.query, self.key)
mock_rope_forward_oot.assert_called_once()
assert q_pe.shape == self.query.shape
assert k_pe.shape == self.key.shape
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_basic_case(self, mock_npuplatform):
# Test with standard values
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
@patch("vllm.platforms.current_platform.device_type",
new=torch.device("cpu"))
@patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
new_callable=PropertyMock)
def test_yarn_get_mscale(self, mock_npuplatform):
mock_npuplatform.device_type = torch.device("cpu")
self.layer = self._create_layer()
# test_scale_less_than_or_equal_1
self.assertEqual(self.layer._yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(self.layer._yarn_get_mscale(scale=0.999), 1.0)
# test_scale_greater_than_1:
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = self.layer._yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")

View File

@@ -0,0 +1,606 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from unittest.mock import MagicMock, PropertyMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
AscendSocVersion, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
class TestTokenDispatcherWithMC2(TestBase):
def setUp(self):
self.mc2_group = MagicMock()
self.mc2_group.device_group.return_value._get_backend.return_value.get_hccl_comm_name.return_value = "hccl_123"
self.mc2_group.rank_in_group = 0
self.mc2_group.world_size = 8
self.mc2_group_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
return_value=self.mc2_group)
self.mc2_group_patch.start()
self.rank_group_patch = patch("torch.distributed.get_rank",
return_value=0)
self.rank_group_patch.start()
# Mock get_forward_context().mc2_mask
self.forward_context = MagicMock()
self.forward_context.mc2_mask = torch.tensor([1, 0, 1])
self.forward_context_patch = patch(
"vllm.forward_context.get_forward_context",
return_value=self.forward_context)
self.forward_context_patch.start()
# Mock get_ascend_soc_version()
self.ascend_soc_version_patch = patch(
"vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
return_value=AscendSocVersion.A3)
self.ascend_soc_version_patch.start()
kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128}
self.dispatcher = TokenDispatcherWithMC2(**kwargs)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.mc2_group_patch.stop()
self.forward_context_patch.stop()
self.ascend_soc_version_patch.stop()
def test_init(self):
self.assertEqual(self.dispatcher.ep_rank_id, 0)
self.assertEqual(self.dispatcher.ep_world_size, 8)
self.assertFalse(self.dispatcher.with_quant)
self.assertTrue(self.dispatcher.enable_dispatch_v2)
self.assertTrue(self.dispatcher.need_extra_args)
self.assertTrue(self.dispatcher.a3_need_extra_args)
def test_get_dispatch_mc2_kwargs_without_quant(self):
hidden_states = torch.randn(10, 128)
topk_ids = torch.randint(0, 8, (10, 1))
topk_weights = torch.randn(10, 1)
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map)
self.assertIn("x", kwargs)
self.assertIn("expert_ids", kwargs)
self.assertEqual(kwargs["moe_expert_num"], 8)
def test_token_permutation_dispatch(self):
hidden_states = torch.randn(10, 128)
topk_weights = torch.randn(10, 1)
topk_ids = torch.randint(0, 8, (10, 1))
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
output = self.dispatcher.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, expert_map)
mock_dispatch.assert_called_once()
self.assertEqual(output["group_list_type"],
1) # group_list_type == 1
def test_token_dispatch_with_shared_experts_and_quant(self):
self.shared_experts = MagicMock()
self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
torch.tensor(1.0))
self.shared_experts.act_fn.return_value = torch.randn(10, 128)
self.dispatcher.with_quant = False
self.dispatcher.shared_act = torch.randn(10, 128)
self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
self.hidden_states = torch.randn(10, 128)
self.topk_weights = torch.randn(10, 1)
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
return_value=(torch.randn(10, 128), ) * 5):
self.dispatcher.token_dispatch(self.hidden_states,
self.topk_weights,
torch.randint(0, 8, (10, 1)),
self.row_idx,
torch.tensor(
[0, 1, 2, 3, 4, 5, 6, 7]),
shared_experts=self.shared_experts)
def test_get_combine_mc_kwargs_with_quant(self):
self.dispatcher.with_quant = True
hidden_states = torch.randn(10, 128)
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.output = torch.randint(0, 8, (10, 1))
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
self.assertIn("tp_send_counts", kwargs)
def test_token_combine_with_shared_experts(self):
self.dispatcher.shared_experts = MagicMock()
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
10, 128), torch.tensor(1.0))
self.dispatcher.shared_act = torch.randn(10, 128)
self.dispatcher.with_quant = True
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
self.dispatcher.need_extra_args = True
self.dispatcher.enable_dispatch_v2 = True
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
self.dispatcher.output = torch.randint(0, 8, (10, 1))
self.hidden_states = torch.randn(10, 128)
with patch("torch_npu.npu_moe_distribute_combine_v2",
return_value=torch.randn(10, 128)):
self.dispatcher.token_combine(self.hidden_states)
class TestTokenDispatcherWithAllGather(TestBase):
def setUp(self):
# Mock dependencies
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
"with_quant": False,
}
self.dispatcher = TokenDispatcherWithAllGather(**kwargs)
# Mock NPU functions
self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
self.mock_moe_init_routing.return_value = (
torch.randn(6, 128), # sorted_hidden_states
torch.tensor([0, 1, 2, 3, 4, 5]), # expanded_row_idx
torch.tensor([0, 1, 0, 1, 0, 1]) # expanded_expert_idx
)
self.patcher_moe_compute_expert_tokens = patch(
'torch_npu.npu_moe_compute_expert_tokens')
self.mock_moe_compute_expert_tokens = self.patcher_moe_compute_expert_tokens.start(
)
self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
[3, 3]) # expert_tokens
self.patcher_moe_finalize_routing = patch(
'torch_npu.npu_moe_finalize_routing')
self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start(
)
self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
self.row_idx = torch.arange(10, dtype=torch.int32)
def tearDown(self):
self.patcher_moe_init_routing.stop()
self.patcher_moe_compute_expert_tokens.stop()
self.patcher_moe_finalize_routing.stop()
def test_token_dispatch_without_expert_map(self):
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, self.row_idx, None)
# Verify npu_moe_init_routing is called
self.mock_moe_init_routing.assert_called_once()
args, kwargs = self.mock_moe_init_routing.call_args
self.assertEqual(results["group_list_type"], 0)
def test_token_dispatch_with_quant(self):
kwargs = {
"apply_router_weight_on_input": False,
"top_k": 2,
"max_num_tokens": 100,
"ep_size": 2,
"num_experts": 128,
}
self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
results = self.dispatcher_quant.token_dispatch(hidden_states,
topk_weights, topk_ids,
self.row_idx, None)
self.assertEqual(results["group_list_type"], 0)
def test_token_combine_with_expert_map(self):
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify index_add_ is applied correctly
self.assertEqual(final_hidden_states.shape, (3, 128))
def test_token_combine_without_expert_map(self):
self.dispatcher.with_quant = False
self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
self.dispatcher.sorted_weights = torch.tensor(
[0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
self.dispatcher.original_shape = (3, 128)
self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
hidden_states = torch.randn(6, 128)
final_hidden_states = self.dispatcher.token_combine(hidden_states)
# Verify npu_moe_finalize_routing is called
self.mock_moe_finalize_routing.assert_called_once()
args, kwargs = self.mock_moe_finalize_routing.call_args
self.assertEqual(final_hidden_states.shape, (3, 128))
def test_token_dispatch_with_router_weight(self):
self.dispatcher.apply_router_weight_on_input = True
hidden_states = torch.randn(3, 128)
topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1
topk_ids = torch.tensor([[0], [1], [2]])
results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
topk_ids, None)
self.assertEqual(results["hidden_states"].shape, (6, 128))
class TestTokenDispatcherWithAll2AllV(TestBase):
def setUp(self):
# Patch properties
patcher1 = patch.object(TokenDispatcherWithAll2AllV,
'ep_group',
new_callable=PropertyMock,
return_value=MagicMock())
patcher2 = patch.object(TokenDispatcherWithAll2AllV,
'ep_rank',
new_callable=PropertyMock,
return_value=0)
patcher3 = patch.object(TokenDispatcherWithAll2AllV,
'ep_size',
new_callable=PropertyMock,
return_value=2)
self.addCleanup(patcher1.stop)
self.addCleanup(patcher2.stop)
self.addCleanup(patcher3.stop)
self.mock_ep_group_prop = patcher1.start()
self.mock_ep_rank_prop = patcher2.start()
self.mock_ep_size_prop = patcher3.start()
# Mock torch_npu.npu_moe_token_permute
patcher4 = patch('torch_npu.npu_moe_token_permute')
self.mock_npu_moe_token_permute = patcher4.start()
self.addCleanup(patcher4.stop)
self.mock_npu_moe_token_permute.return_value = (torch.randn(16, 16),
torch.arange(16))
# Mock torch_npu.npu_moe_token_unpermute
patcher5 = patch('torch_npu.npu_moe_token_unpermute')
self.mock_npu_moe_token_unpermute = patcher5.start()
self.addCleanup(patcher5.stop)
self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)
# Mock async_all_to_all
patcher6 = patch('vllm_ascend.ops.comm_utils.async_all_to_all')
self.mock_async_all_to_all = patcher6.start()
self.addCleanup(patcher6.stop)
self.mock_async_all_to_all.return_value = (None, torch.randn(16, 16),
MagicMock())
# Mock gather_from_sequence_parallel_region
patcher7 = patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
)
self.mock_gather_from_sequence_parallel_region = patcher7.start()
self.addCleanup(patcher7.stop)
self.mock_gather_from_sequence_parallel_region.return_value = torch.tensor(
[[2, 2, 2, 2], [2, 2, 2, 2]], dtype=torch.int64)
# Mock torch.histc
patcher8 = patch('torch.histc')
self.mock_histc = patcher8.start()
self.addCleanup(patcher8.stop)
self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
dtype=torch.int64)
# Mock torch.npu.current_device
patcher9 = patch('torch.npu.current_device')
self.mock_current_device = patcher9.start()
self.addCleanup(patcher9.stop)
self.mock_current_device.return_value = 'cpu'
# Mock torch_npu.npu_dynamic_quant
patcher10 = patch('torch_npu.npu_dynamic_quant')
self.mock_npu_dynamic_quant = patcher10.start()
self.addCleanup(patcher10.stop)
self.mock_npu_dynamic_quant.return_value = (torch.randn(16, 16),
torch.randn(16))
# Mock torch_npu.npu_moe_init_routing_v2
patcher11 = patch('torch_npu.npu_moe_init_routing_v2')
self.mock_npu_moe_init_routing_v2 = patcher11.start()
self.addCleanup(patcher11.stop)
self.mock_npu_moe_init_routing_v2.return_value = (torch.randn(
16, 16), torch.arange(16), None, torch.randn(16))
# Mock torch.repeat_interleave
patcher12 = patch('torch.repeat_interleave')
self.mock_repeat_interleave = patcher12.start()
self.addCleanup(patcher12.stop)
self.mock_repeat_interleave.return_value = torch.arange(16)
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2,
with_quant=False)
self.row_idx = torch.arange(10, dtype=torch.int32)
def test_token_dispatch(self):
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
def test_token_combine(self):
self.dispatcher.hidden_shape = (8, 16)
self.dispatcher.hidden_shape_before_permute = (8, 16)
self.dispatcher.reversed_local_input_permutation_mapping = torch.arange(
8)
self.dispatcher.topk_weights = torch.rand(8, 4)
self.dispatcher.input_splits = [4, 4]
self.dispatcher.output_splits = [4, 4]
self.dispatcher.reversed_global_input_permutation_mapping = torch.arange(
16)
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
self.dispatcher.num_global_tokens_per_local_expert = torch.tensor(
[[2, 2], [2, 2]], dtype=torch.int64)
expert_output = torch.randn(16, 16)
output = self.dispatcher.token_combine(expert_output)
self.assertIsNotNone(output)
self.assertEqual(output.shape, (8, 16))
def test_token_dispatch_with_quant(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2)
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
def test_token_dispatch_with_quant_no_active_tokens(self):
self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
num_experts=4,
num_local_experts=2)
self.mock_repeat_interleave.return_value = torch.tensor(
[], dtype=torch.long)
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
with_quant=True)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertIsNotNone(result["dynamic_scale"])
self.assertEqual(result["group_list_type"], 1)
def test_token_dispatch_with_log2phy(self):
hidden_states = torch.randn(8, 16)
topk_weights = torch.rand(8, 4)
topk_ids = torch.randint(0, 4, (8, 2)).long()
expert_map = torch.tensor([0, 1, 2, 3])
log2phy = torch.tensor([1, 0, 3, 2])
self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
[0, 1], dtype=torch.int32)
self.dispatcher.local_expert_indices = [0, 1]
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
row_idx=self.row_idx,
expert_map=expert_map,
log2phy=log2phy)
self.assertIsNotNone(result["hidden_states"])
self.assertIsNotNone(result["group_list"])
self.assertEqual(result["group_list_type"], 1)
class TestDispatcherRegistry(TestBase):
def setUp(self):
_Dispatchers.clear()
def tearDown(self):
_Dispatchers.clear()
def test_register_and_get_token_dispatcher(self):
mock_dispatcher = MagicMock()
mock_dispatcher.__class__.__name__ = "MockDispatcher"
_register_token_dispatcher(mock_dispatcher)
self.assertIn("MockDispatcher", _Dispatchers)
self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher)
retrieved_dispatcher = get_token_dispatcher("MockDispatcher")
self.assertIs(retrieved_dispatcher, mock_dispatcher)
self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_1_creates_allgather(
self, mock_register, mock_allgather_class):
kwargs = {"top_k": 2, "num_experts": 8}
mock_instance = MagicMock()
mock_allgather_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)
setup_token_dispatchers(ep_size=1, **kwargs)
mock_allgather_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_2_creates_all2allv(
self, mock_register, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2}
mock_instance = MagicMock()
mock_all2allv_class.return_value = mock_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
setup_token_dispatchers(ep_size=2, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_register.assert_called_once_with(mock_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_all2allv_instance = MagicMock()
mock_mc2_instance = MagicMock()
mock_all2allv_class.return_value = mock_all2allv_instance
mock_mc2_class.return_value = mock_mc2_instance
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_called_once_with(**kwargs)
mock_mc2_class.assert_called_once_with(**kwargs)
self.assertEqual(mock_register.call_count, 2)
mock_register.assert_any_call(mock_all2allv_instance)
mock_register.assert_any_call(mock_mc2_instance)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
)
@patch(
'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
)
def test_setup_token_dispatchers_ep_size_16_skips_if_exist(
self, mock_register, mock_mc2_class, mock_all2allv_class):
kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2}
mock_existing_all2allv = MagicMock()
mock_existing_mc2 = MagicMock()
_Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv
_Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2
setup_token_dispatchers(ep_size=16, **kwargs)
mock_all2allv_class.assert_not_called()
mock_mc2_class.assert_not_called()
mock_register.assert_not_called()
self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
mock_existing_all2allv)
self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
mock_existing_mc2)

View File

@@ -0,0 +1,232 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/lora/test_layers.py
import unittest
from unittest.mock import MagicMock, patch
import torch
from vllm_ascend.ops.vocab_parallel_embedding import (
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
class TestCustomVocabParallelEmbedding(unittest.TestCase):
def setUp(self):
self.num_embeddings = 50
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
def _create_layer(self):
# Patch methods and dependencies for VocabParallelEmbedding
mock_group = MagicMock()
mock_group.world_size = 2
mock_group.rank_in_group = 0
with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=mock_group), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):
# Create an instance of VocabParallelEmbedding
layer = AscendVocabParallelEmbedding(
num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
org_num_embeddings=self.org_num_embeddings,
padding_size=self.padding_size,
quant_config=None, # Mock quantization config
prefix="")
layer.shard_indices = MagicMock()
layer.shard_indices.org_vocab_start_index = 10
layer.shard_indices.org_vocab_end_index = 20
layer.shard_indices.num_org_vocab_padding = 5
layer.shard_indices.added_vocab_start_index = 30
layer.shard_indices.added_vocab_end_index = 40
# Mock the quantization method
layer.quant_method.embedding = MagicMock(
side_effect=lambda _, x: torch.randn(x.shape[0], self.
embedding_dim))
return layer
def test_get_masked_input_and_mask(self):
"""Test the mask and offset calculation helper function."""
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45])
masked_input, mask = layer._get_masked_input_and_mask(
input_,
org_vocab_start_index=10,
org_vocab_end_index=20,
num_org_vocab_padding=5,
added_vocab_start_index=30,
added_vocab_end_index=40)
expected_mask = torch.tensor([True, False, True, False, True])
self.assertTrue(
torch.equal(mask, expected_mask),
f"Mask mismatch. Expected {expected_mask}, got {mask}")
expected_masked = torch.tensor([0, 5, 0, 20, 0])
self.assertTrue(
torch.equal(masked_input, expected_masked),
f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
)
def test_forward_with_tp_size_1(self):
"""Test forward pass without tensor parallelism."""
# Create a fresh mock embedding with tp_size=1
layer = self._create_layer()
layer.tp_size = 1
layer.quant_method.embedding = MagicMock(
return_value=torch.randn(3, layer.embedding_dim))
input_ = torch.tensor([1, 2, 3])
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp1:
output = layer.forward(input_)
# Should just pass through without masking
layer.quant_method.embedding.assert_called_once_with(
layer, input_.long())
self.assertEqual(output.shape, (3, layer.embedding_dim))
# Verify all_reduce was called once
mock_reduce_tp1.assert_called_once()
def test_forward_with_tp(self):
layer = self._create_layer()
layer.tp_size = 2
input_ = torch.tensor([15, 35]) # one org vocab, one added vocab
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x) as mock_reduce_tp:
# Call the forward method
output = layer.forward(input_)
# Check that masking was applied correctly
layer.quant_method.embedding.assert_called_once()
called_input = layer.quant_method.embedding.call_args[0][1]
expected_input = torch.tensor([5, 20]) # after offset calculation
self.assertTrue(torch.all(called_input == expected_input))
# Check that all reduce was called
mock_reduce_tp.assert_called_once()
self.assertEqual(output.shape, (2, self.embedding_dim))
def test_forward_with_invalid_vocab(self):
"""Test that invalid vocab indices are properly masked out."""
# Create a fresh embedding layer
layer = self._create_layer()
input_ = torch.tensor([5, 15, 25, 35, 45]) # includes invalid cases
# Create predictable mock output
mock_output = torch.randn(5, self.embedding_dim)
layer.quant_method.embedding = MagicMock(
return_value=mock_output.clone())
# Patch tensor_model_parallel_all_reduce to mock its behavior
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
# Check that invalid positions (0, 2, 4) were zeroed out
self.assertTrue(torch.all(output[0] == 0))
self.assertTrue(torch.all(output[2] == 0))
self.assertTrue(torch.all(output[4] == 0))
self.assertTrue(torch.all(output[1] == mock_output[1]))
self.assertTrue(torch.all(output[3] == mock_output[3]))
self.assertEqual(output.shape, (5, self.embedding_dim))
def test_output_shape(self):
"""Test that output shape is correct."""
# Create a fresh embedding layer
layer = self._create_layer()
test_cases = [
(torch.tensor([15]), (1, self.embedding_dim)),
(torch.tensor([15, 35]), (2, self.embedding_dim)),
(torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
]
for input_, expected_shape in test_cases:
with self.subTest(input=input_):
with patch(
"vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
side_effect=lambda x: x):
# Call the forward method
output = layer.forward(input_)
self.assertEqual(output.shape, expected_shape)
class TestAscendLogitsProcessor(unittest.TestCase):
def setUp(self):
self.vocab_size = 50
self.num_embeddings = 50
self.embedding_dim = 10
self.org_num_embeddings = 40
self.padding_size = 8
self.mock_group = MagicMock()
self.mock_group.world_size = 2
self.mock_group.rank_in_group = 0
self.mock_ascend_config = MagicMock()
self.mock_quant_method = MagicMock()
self.mock_quant_method.apply = MagicMock(
return_value=torch.randn(1, self.vocab_size))
self.patches = [
patch("vllm_ascend.ascend_config.get_ascend_config",
return_value=self.mock_ascend_config),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
return_value=self.mock_group),
patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
return_value=True),
patch(
"vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
return_value=torch.randn(1, self.vocab_size))
]
for p in self.patches:
p.start()
def tearDown(self):
for p in self.patches:
p.stop()
def test_create_processor(self):
processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
self.assertEqual(processor.vocab_size, self.vocab_size)
def test_get_logits(self):
processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
embedding_dim=self.embedding_dim,
prefix="lm_head")
lmhead.quant_method = self.mock_quant_method
lmhead.quant_method.apply = self.mock_quant_method.apply
hidden_state = torch.randn(1, self.org_num_embeddings)
processor._get_logits(hidden_state, lmhead)
self.mock_quant_method.apply.assert_called_once()

View File

@@ -0,0 +1,112 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock, patch
import torch
from vllm.distributed.parallel_state import GroupCoordinator
from tests.ut.base import TestBase
from vllm_ascend.patch.worker.patch_common.patch_distributed import \
GroupCoordinatorPatch
class TestPatchDistributed(TestBase):
def setUp(self):
self.mock_group_ranks = [[0, 1]]
self.mock_local_rank = 0
self.mock_backend = "hccl"
self.mock_use_device_comm = True
patcher_get_rank = patch("torch.distributed.get_rank", return_value=0)
patcher_new_group = patch("torch.distributed.new_group",
return_value=MagicMock())
patcher_is_cuda_alike = patch(
"vllm.platforms.current_platform.is_cuda_alike", return_value=True)
patcher_device_comm_cls = patch(
"vllm.distributed.parallel_state.resolve_obj_by_qualname",
return_value=MagicMock())
self.mock_get_rank = patcher_get_rank.start()
self.mock_new_group = patcher_new_group.start()
self.mock_is_cuda_alike = patcher_is_cuda_alike.start()
self.mock_resolve_obj = patcher_device_comm_cls.start()
self.addCleanup(patcher_get_rank.stop)
self.addCleanup(patcher_new_group.stop)
self.addCleanup(patcher_is_cuda_alike.stop)
self.addCleanup(patcher_device_comm_cls.stop)
self.group_coordinator = GroupCoordinatorPatch(
group_ranks=self.mock_group_ranks,
local_rank=self.mock_local_rank,
torch_distributed_backend=self.mock_backend,
use_device_communicator=self.mock_use_device_comm)
def test_GroupCoordinator_patched(self):
self.assertIs(GroupCoordinator, GroupCoordinatorPatch)
def test_all_to_all_returns_input_when_world_size_1(self):
self.group_coordinator.world_size = 1
input_tensor = torch.randn(2, 3)
output = self.group_coordinator.all_to_all(input_tensor)
self.assertTrue(torch.equal(output, input_tensor))
def test_all_to_all_raises_assertion_on_invalid_scatter_dim(self):
input_tensor = torch.randn(2, 3)
with self.assertRaises(AssertionError) as cm:
self.group_coordinator.all_to_all(input_tensor, scatter_dim=2)
self.assertIn("Invalid scatter dim", str(cm.exception))
def test_all_to_all_raises_assertion_on_invalid_gather_dim(self):
input_tensor = torch.randn(2, 3)
with self.assertRaises(AssertionError) as cm:
self.group_coordinator.all_to_all(input_tensor, gather_dim=2)
self.assertIn("Invalid gather dim", str(cm.exception))
def test_all_to_all_calls_device_communicator_with_correct_args(self):
mock_communicator = MagicMock()
self.group_coordinator.device_communicator = mock_communicator
input_tensor = torch.randn(2, 3)
scatter_dim = 0
gather_dim = 1
scatter_sizes = [1, 1]
gather_sizes = [1, 1]
self.group_coordinator.all_to_all(input_tensor,
scatter_dim=scatter_dim,
gather_dim=gather_dim,
scatter_sizes=scatter_sizes,
gather_sizes=gather_sizes)
mock_communicator.all_to_all.assert_called_once_with(
input_tensor, scatter_dim, gather_dim, scatter_sizes, gather_sizes)
def test_all_to_all_calls_device_communicator_without_sizes(self):
mock_communicator = MagicMock()
self.group_coordinator.device_communicator = mock_communicator
input_tensor = torch.randn(2, 3)
scatter_dim = 0
gather_dim = 1
self.group_coordinator.all_to_all(input_tensor,
scatter_dim=scatter_dim,
gather_dim=gather_dim)
mock_communicator.all_to_all.assert_called_once_with(
input_tensor, scatter_dim, gather_dim, None, None)

View File

@@ -0,0 +1,167 @@
from importlib import reload
import pytest
import torch
import vllm
from pytest_mock import MockerFixture
import vllm_ascend.envs as envs_ascend
from tests.ut.base import PytestBase
from vllm_ascend.patch.worker.patch_common import patch_linear
class TestAscendRowParallelLinear(PytestBase):
def init_row_parallel_linear(self, mocker: MockerFixture):
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
return_value=None,
)
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
return patch_linear.AscendRowParallelLinear(
input_size=128,
output_size=256,
)
@pytest.mark.parametrize(
"version, expected",
[
("1.0.0", 1),
("2.1.0", 1),
],
)
def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
mock_group = mocker.MagicMock()
backend = mocker.MagicMock()
backend.get_hccl_comm_name = lambda x: x
mock_group._get_backend = lambda x: backend
mock_group.get_hccl_comm_name = lambda x: x
mocker.patch("torch.distributed.get_rank", return_value=1)
mocker.patch(
"torch.distributed.get_global_rank",
return_value=0,
)
mocker.patch("torch.__version__", new=version)
hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
mock_group)
assert hcomm_info == expected
@pytest.mark.parametrize(
"skip_bias_add, return_bias, bias, expected",
[
(True, False, torch.tensor(1.0), torch.tensor(14.0)),
(False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
(
True,
True,
torch.tensor(1.0),
(torch.tensor(14.0), torch.tensor(1.0)),
),
],
)
def test_forward(
self,
skip_bias_add,
return_bias,
bias,
expected,
mocker: MockerFixture,
):
mocker_tp_group = mocker.MagicMock()
mocker_tp_group.device_group = mocker.MagicMock()
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["tp_rank"] = 0
row_parallel_linear.__dict__["skip_bias_add"] = skip_bias_add
row_parallel_linear.__dict__["return_bias"] = return_bias
row_parallel_linear.__dict__["bias"] = bias
row_parallel_linear.__dict__["qyuant_method"] = mocker.MagicMock()
row_parallel_linear.__dict__["calc_input"] = lambda x: x # noqa
row_parallel_linear.__dict__[
"calc_output"] = lambda x: x.matmul( # noqa
torch.tensor([1.0, 2.0]))
ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))
if isinstance(ret, tuple):
assert torch.allclose(ret[0], expected[0])
if ret[1] is None:
assert ret[1] == expected[1]
else:
assert torch.allclose(ret[1], expected[1])
else:
assert torch.allclose(ret, expected)
@pytest.mark.parametrize(
"input_is_parallel, expected",
[
(True, torch.tensor([10.0, 2.0])),
(False, torch.tensor([10.0])),
],
)
def test_calc_input(
self,
input_is_parallel,
expected,
mocker: MockerFixture,
):
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
input_tensor = torch.Tensor([10, 2])
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank", # noqa
return_value=0,
)
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim", # noqa
return_value=[torch.Tensor([10]),
torch.Tensor([2])],
)
input_parallel = row_parallel_linear.calc_input(input_tensor)
assert torch.allclose(input_parallel, expected)
@pytest.mark.parametrize(
"reduce_results, tp_size, expected",
[
(True, 2, torch.tensor(56.0)),
(True, 1, torch.tensor(14.0)),
(False, 2, torch.tensor(14.0)),
],
)
def test_calc_output(
self,
reduce_results,
tp_size,
expected,
mocker: MockerFixture,
):
quant_method = mocker.MagicMock()
quant_method.apply = lambda self, x, bias=None: x.matmul( # noqa
torch.tensor([1.0, 2.0]))
row_parallel_linear = self.init_row_parallel_linear(mocker)
row_parallel_linear.__dict__["reduce_results"] = reduce_results
row_parallel_linear.__dict__["tp_size"] = tp_size
row_parallel_linear.__dict__["quant_method"] = quant_method
row_parallel_linear.__dict__["tp_rank"] = 0
row_parallel_linear.__dict__["get_hcomm_info"] = lambda x: None # noqa
mocker.patch(
"vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
)
mocker.patch(
"torch_npu.npu_mm_all_reduce_base",
side_effect=lambda input_, weight, hccl_info, bias: input_.
matmul( # noqa
torch.tensor([4.0, 8.0])),
) # noqa
ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
assert torch.allclose(ret, expected)
def test_enable_allreduce_matmul(self, mocker: MockerFixture):
mocker.patch.object(envs_ascend,
"VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
new=True)
reload(patch_linear)
assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
assert id(vllm.model_executor.layers.linear.RowParallelLinear) == id(
patch_linear.AscendRowParallelLinear)

View File

@@ -0,0 +1,77 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import MagicMock
import torch
from tests.ut.base import TestBase
from vllm_ascend.patch.worker.patch_common.patch_minicpm import forward
class TestPatchMiniCPM(TestBase):
def setUp(self):
self.mock_self = MagicMock()
self.mock_self.q_size = 128
self.mock_self.kv_size = 128
self.mock_self.qkv_proj = MagicMock()
self.mock_self.rotary_emb = MagicMock()
self.mock_self.attn = MagicMock()
self.mock_self.o_proj = MagicMock()
self.positions = torch.tensor([1, 2, 3])
self.hidden_states = torch.randn(3, 256)
self.mock_qkv = torch.randn(3, 384)
self.mock_q = self.mock_qkv[:, :128]
self.mock_k = self.mock_qkv[:, 128:256]
self.mock_v = self.mock_qkv[:, 256:]
self.mock_self.qkv_proj.return_value = (self.mock_qkv, None)
self.mock_self.rotary_emb.return_value = (self.mock_q, self.mock_k)
self.mock_self.attn.return_value = torch.randn(3, 256)
self.mock_self.o_proj.return_value = (torch.randn(3, 256), None)
def test_forward_patched(self):
from vllm.model_executor.models.minicpm import MiniCPMAttention
self.assertIs(MiniCPMAttention.forward, forward)
def test_forward_function(self):
result = forward(self.mock_self, self.positions, self.hidden_states)
self.mock_self.qkv_proj.assert_called_once_with(self.hidden_states)
args, _ = self.mock_self.rotary_emb.call_args
self.assertEqual(len(args), 3)
self.assertTrue(torch.equal(args[0], self.positions))
self.assertTrue(torch.equal(args[1], self.mock_q))
self.assertTrue(torch.equal(args[2], self.mock_k))
args, _ = self.mock_self.attn.call_args
self.assertEqual(len(args), 3)
self.assertTrue(torch.equal(args[0], self.mock_q))
self.assertTrue(torch.equal(args[1], self.mock_k))
self.assertTrue(torch.equal(args[2], self.mock_v))
self.mock_self.o_proj.assert_called_once_with(
self.mock_self.attn.return_value)
self.assertEqual(result.shape, (3, 256))
self.assertTrue(
torch.equal(result, self.mock_self.o_proj.return_value[0]))

View File

@@ -0,0 +1,134 @@
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
class MockRMSNorm:
def __init__(self, hidden_size: int, **extra_args):
self.hidden_size = hidden_size
self.weight = torch.ones(hidden_size)
self.input_scale = 1.0
self.input_offset = 0.0
self.variance_epsilon = 1e-6
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
self.ignore_anti = extra_args.get('ignore_anti', True)
class TestFuncWrapper(TestBase):
def test_wrapper_rmsnorm_init(self):
@wrapper_rmsnorm_init
def init(self, hidden_size: int, **extra_args) -> None:
self.hidden_size = hidden_size
hidden_size = 128
extra_args = {'arg1': 'value1'}
rms_norm = MockRMSNorm(hidden_size, **extra_args)
init(rms_norm, hidden_size, **extra_args)
self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
self.assertTrue(rms_norm.ignore_anti)
self.assertTrue(hasattr(rms_norm, 'bias'))
self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
self.assertFalse(rms_norm.bias.requires_grad)
@patch('torch_npu._npu_quant_rms_norm')
def test_wrapper_rmsnorm_forward_oot_with_residual(
self, mock_npu_quant_rms_norm):
hidden_size = 128
x = torch.randn(hidden_size)
residual = torch.randn(hidden_size)
expected_out = torch.randn(hidden_size)
mock_npu_quant_rms_norm.return_value = (expected_out, residual)
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x, residual
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = False
output, res = forward_oot(rms_norm, x, residual)
mock_npu_quant_rms_norm.assert_called_once()
args, kwargs = mock_npu_quant_rms_norm.call_args
self.assertTrue(torch.equal(args[1], rms_norm.weight))
self.assertTrue(torch.equal(args[2], rms_norm.bias))
self.assertEqual(args[3], rms_norm.input_scale)
self.assertEqual(args[4], rms_norm.input_offset)
self.assertEqual(args[5], rms_norm.variance_epsilon)
self.assertTrue(torch.equal(res, residual))
@patch('torch_npu._npu_quant_rms_norm')
def test_wrapper_rmsnorm_forward_oot_without_residual(
self, mock_npu_quant_rms_norm):
hidden_size = 128
x = torch.randn(hidden_size)
expected_out = torch.randn(hidden_size)
mock_npu_quant_rms_norm.return_value = expected_out
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = False
output = forward_oot(rms_norm, x)
mock_npu_quant_rms_norm.assert_called_once()
args, kwargs = mock_npu_quant_rms_norm.call_args
self.assertTrue(torch.equal(args[0], x))
self.assertTrue(torch.equal(args[1], rms_norm.weight))
self.assertTrue(torch.equal(args[2], rms_norm.bias))
self.assertEqual(args[3], rms_norm.input_scale)
self.assertEqual(args[4], rms_norm.input_offset)
self.assertEqual(args[5], rms_norm.variance_epsilon)
self.assertTrue(torch.equal(output, expected_out))
def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
hidden_size = 128
x = torch.randn(hidden_size)
residual = torch.randn(hidden_size)
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x, residual
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = True
output, res = forward_oot(rms_norm, x, residual)
self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))
self.assertTrue(torch.equal(res, residual))
def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
hidden_size = 128
x = torch.randn(hidden_size)
@wrapper_rmsnorm_forward_oot
def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
return x
rms_norm = MockRMSNorm(hidden_size)
rms_norm.ignore_anti = True
output = forward_oot(rms_norm, x)
self.assertTrue(torch.equal(output, x.add_(rms_norm.bias)))

View File

@@ -0,0 +1,232 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from tests.ut.base import TestBase
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
AscendQuantConfig)
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
class TestAscendQuantConfig(TestBase):
def setUp(self):
self.sample_config = {
"weight": "INT8",
"fa_quant_type": "C8",
"kv_quant_type": "C8",
"layer1.weight": "INT8",
"layer2.weight": "FLOAT",
"fused_layer.weight": "FLOAT",
"fused_layer.shard1.weight": "FLOAT",
"fused_layer.shard2.weight": "FLOAT",
"shard1.weight": "FLOAT",
"shard2.weight": "FLOAT",
}
self.ascend_config = AscendQuantConfig(self.sample_config)
self.ascend_config.packed_modules_mapping = None
def test_init(self):
self.assertEqual(self.ascend_config.quant_description,
self.sample_config)
def test_repr(self):
repr_str = repr(self.ascend_config)
self.assertTrue(repr_str.startswith("AscendQuantConfig:\n"))
def test_get_name(self):
self.assertEqual(AscendQuantConfig.get_name(),
ASCEND_QUANTIZATION_METHOD)
def test_get_supported_act_dtypes(self):
supported_dtypes = AscendQuantConfig.get_supported_act_dtypes()
self.assertEqual(len(supported_dtypes), 3)
def test_get_min_capability(self):
with self.assertRaises(NotImplementedError):
AscendQuantConfig.get_min_capability()
def test_get_config_filenames(self):
filenames = AscendQuantConfig.get_config_filenames()
self.assertEqual(filenames, ["quant_model_description.json"])
def test_from_config(self):
config = AscendQuantConfig.from_config(self.sample_config)
self.assertIsInstance(config, AscendQuantConfig)
self.assertEqual(config.quant_description, self.sample_config)
@patch('torch.npu.is_available')
def test_override_quantization_method(self, mock_is_available):
# Test when NPU is available
mock_is_available.return_value = True
result = AscendQuantConfig.override_quantization_method(None, None)
self.assertEqual(result, ASCEND_QUANTIZATION_METHOD)
# Test when NPU is not available
mock_is_available.return_value = False
result = AscendQuantConfig.override_quantization_method(None, None)
self.assertIsNone(result)
def test_get_quant_method_for_linear(self):
linear_layer = MagicMock(spec=LinearBase)
# Test skipped layer
with patch.object(self.ascend_config,
'is_layer_skipped_ascend',
return_value=True):
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIsInstance(method, UnquantizedLinearMethod)
# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
patch('vllm_ascend.quantization.quant_config.AscendLinearMethod', return_value=MagicMock()) as mock_ascend_linear:
method = self.ascend_config.get_quant_method(linear_layer, ".attn")
self.assertIs(method, mock_ascend_linear.return_value)
mock_ascend_linear.assert_called_once_with(
self.ascend_config, ".attn",
self.ascend_config.packed_modules_mapping)
def test_get_quant_method_for_attention(self):
attention_layer = MagicMock(spec=Attention)
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
return_value=MagicMock()) as mock_ascend_kvcache:
# Test with fa_quant_type
method = self.ascend_config.get_quant_method(
attention_layer, ".attn")
self.assertIs(method, mock_ascend_kvcache.return_value)
with patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod',
return_value=MagicMock()) as mock_ascend_kvcache:
# Test with kv_quant_type
modified_config = {"kv_quant_type": "C8"}
config = AscendQuantConfig(modified_config)
config.packed_modules_mapping = None
method = config.get_quant_method(attention_layer, "attn")
self.assertIs(method, mock_ascend_kvcache.return_value)
def test_get_quant_method_for_fused_moe(self):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)
# Test skipped layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=True), \
patch('vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
method = self.ascend_config.get_quant_method(
fused_moe_layer, "moe_layer")
self.assertIs(method, mock_ascend_moe.return_value)
# Test quantized layer
with patch.object(self.ascend_config, 'is_layer_skipped_ascend', return_value=False), \
patch('vllm_ascend.quantization.quant_config.AscendFusedMoEMethod', return_value=MagicMock()) as mock_ascend_moe:
method = self.ascend_config.get_quant_method(
fused_moe_layer, "moe_layer")
self.assertIs(method, mock_ascend_moe.return_value)
def test_is_layer_skipped_ascend(self):
# Test non-fused layer that should be quantized
self.assertFalse(self.ascend_config.is_layer_skipped_ascend("layer1"))
# Test non-fused layer that should be skipped
self.assertTrue(self.ascend_config.is_layer_skipped_ascend("layer2"))
# Test fused layer
fused_mapping = {"fused_layer": ["shard1", "shard2"]}
self.assertTrue(
self.ascend_config.is_layer_skipped_ascend("fused_layer",
fused_mapping))
# Test inconsistent fused layer shards
bad_config = {"shard1.weight": "FLOAT", "shard2.weight": "INT8"}
config = AscendQuantConfig(bad_config)
with self.assertRaises(ValueError):
config.is_layer_skipped_ascend("fused_layer", fused_mapping)
def test_get_scaled_act_names(self):
self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
class TestAscendKVCacheMethod(TestBase):
def setUp(self):
# Setup common test fixtures
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
self.mock_quant_config.quant_description = {"some_config": "value"}
self.prefix = "attention_layer"
# Mock the quantizer and quant_method
self.mock_quantizer = MagicMock()
self.mock_quant_method = MagicMock()
# Patch the AscendQuantizer
self.quantizer_patcher = patch(
'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
return_value=self.mock_quantizer)
self.mock_get_quantizer = self.quantizer_patcher.start()
self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method
# Create instance
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
self.prefix)
def tearDown(self):
self.quantizer_patcher.stop()
def test_init(self):
"""Test initialization with proper quantizer setup."""
self.mock_get_quantizer.assert_called_once_with(
self.mock_quant_config.quant_description, self.prefix)
self.mock_quantizer.build_attention_method.assert_called_once()
def test_create_weights(self):
"""Test create_weights delegates to quant_method."""
mock_layer = MagicMock()
self.kv_cache_method.create_weights(mock_layer)
self.mock_quant_method.create_weights.assert_called_once_with(
mock_layer)
def test_process_weights_after_loading_with_method(self):
"""Test process_weights when quant_method has the method."""
mock_layer = MagicMock()
self.kv_cache_method.process_weights_after_loading(mock_layer)
self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
mock_layer)
def test_process_weights_after_loading_without_method(self):
"""Test process_weights when quant_method lacks the method."""
# Reset mock to remove the method
del self.mock_quant_method.process_weights_after_loading
mock_layer = MagicMock()
# Should not raise exception
self.kv_cache_method.process_weights_after_loading(mock_layer)
def test_apply_delegation(self):
"""Test apply properly delegates to quant_method."""
mock_layer = MagicMock()
mock_query = torch.randn(1, 32, 128)
mock_key = torch.randn(1, 32, 128)
mock_value = torch.randn(1, 32, 128)
mock_kv_cache = MagicMock()
mock_attn_metadata = MagicMock()
mock_scale = 1.0
mock_output = torch.zeros(1, 32, 128)
mock_attn_type = MagicMock()
expected_result = torch.randn(1, 32, 128)
self.mock_quant_method.apply.return_value = expected_result
result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
mock_value, mock_kv_cache,
mock_attn_metadata, mock_attn_type,
mock_scale, mock_output)
self.mock_quant_method.apply.assert_called_once_with(
mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
self.assertTrue(torch.equal(result, expected_result))

View File

@@ -0,0 +1,145 @@
from unittest.mock import MagicMock, patch
from tests.ut.base import TestBase
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
W4A8DYNAMICQuantizer,
W8A8Quantizer)
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
class TestGetQuantizer(TestBase):
def setUp(self):
# Setup common test fixtures
self.supported_types = {
'INT8': MagicMock(_instance=None),
'FP16': MagicMock(_instance=None),
'C8': MagicMock(_instance=None)
}
self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
self.mock_quant_config.quant_description = {"some_config": "value"}
def tearDown(self):
# Restore original supported types
SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)
def test_get_quantizer_fa(self):
"""Test successful quantizer retrieval for different cases."""
# Setup
quant_description = {'fa_quant_type': 'C8'}
prefix = '.attn'
expected_type = 'C8'
with patch.dict(
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
SUPPORT_ASCEND_QUANTIZER_TYPE):
result = VLLMAscendQuantizer.get_quantizer(
quant_description,
prefix,
packed_modules_mapping={"some": "mapping"})
# Verify
self.assertIsNotNone(result)
self.assertEqual(result,
self.supported_types[expected_type]._instance)
self.supported_types[expected_type].assert_called_once_with(
quant_description)
def test_get_quantizer_kv(self):
"""Test successful quantizer retrieval for different cases."""
# Setup
quant_description = {'kv_quant_type': 'C8'}
prefix = '.attn'
expected_type = 'C8'
with patch.dict(
'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
SUPPORT_ASCEND_QUANTIZER_TYPE):
result = VLLMAscendQuantizer.get_quantizer(
quant_description,
prefix,
packed_modules_mapping={"some": "mapping"})
# Verify
self.assertIsNotNone(result)
self.assertEqual(result,
self.supported_types[expected_type]._instance)
self.supported_types[expected_type].assert_called_once_with(
quant_description)
def test_get_quantizer_linear(self):
"""Test successful quantizer retrieval for different cases."""
# Setup
quant_description = {'linear_type': 'INT8'}
prefix = 'nothing'
expected_type = 'INT8'
with patch('vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
return_value=expected_type), \
patch.dict('vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE', SUPPORT_ASCEND_QUANTIZER_TYPE):
result = VLLMAscendQuantizer.get_quantizer(
quant_description,
prefix,
packed_modules_mapping={"some": "mapping"})
# Verify
self.assertIsNotNone(result)
self.assertEqual(result,
self.supported_types[expected_type]._instance)
self.supported_types[expected_type].assert_called_once_with(
quant_description)
class TestW8A8Quantizer(TestBase):
def setUp(self):
self.quantizer = W8A8Quantizer(quant_description={})
def test_build_linear_method(self):
with patch('vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod',
return_value=MagicMock()) as mock_linear:
result = self.quantizer.build_linear_method()
mock_linear.assert_called_once_with()
self.assertIsInstance(result, MagicMock)
def test_build_moe_method(self):
with patch(
'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod',
return_value=MagicMock()) as mock_linear:
result = self.quantizer.build_moe_method()
mock_linear.assert_called_once_with()
self.assertIsInstance(result, MagicMock)
def test_build_attention_method(self):
with patch('vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod',
return_value=MagicMock()) as mock_linear:
result = self.quantizer.build_attention_method()
mock_linear.assert_called_once_with()
self.assertIsInstance(result, MagicMock)
class TestW4A8DYNAMICQuantizer(TestBase):
def setUp(self):
self.quantizer = W4A8DYNAMICQuantizer(quant_description={})
def test_build_linear_method(self):
with patch(
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod',
return_value=MagicMock()) as mock_linear:
result = self.quantizer.build_linear_method()
mock_linear.assert_called_once_with()
self.assertIsInstance(result, MagicMock)
def test_build_moe_method(self):
with patch(
'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod',
return_value=MagicMock()) as mock_fused_moe:
result = self.quantizer.build_moe_method()
mock_fused_moe.assert_called_once_with()
self.assertIsInstance(result, MagicMock)

View File

@@ -0,0 +1,166 @@
import copy
from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.w4a8_dynamic import (
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
def setUp(self):
self.method = AscendW4A8DynamicLinearMethod()
self.method.group_size = 8
def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (32, 8))
def test_get_pergroup_param(self):
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale"].shape, (32, 1))
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset"].shape, (32, 1))
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
experts = 8
input_size = 16
output_size = 56
group_size = 2
@patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
@patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
get_current_vllm_config):
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = AscendW4A8DynamicFusedMoEMethod()
def test_get_weight(self):
# old quant version w4a8 weight
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, 2 * self.input_size, self.output_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, self.input_size, self.output_size))
def test_get_dynamic_quant_param(self):
# old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size,
self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size,
self.input_size // self.group_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
self.assertEqual(
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
# old quant version weight
layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
new_layer = copy.deepcopy(layer)
mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
self.assertTrue(hasattr(layer, "w2_scale_bias"))
self.assertEqual(layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.quant_method.new_quant_version = True
new_layer.w13_weight.data = torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8)
new_layer.w2_weight.data = torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8)
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros(
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
dtype=torch.float32)
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))

View File

@@ -0,0 +1,930 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
select_experts)
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod,
fused_experts, fused_experts_310p,
quant_per_tensor)
class TestQuantPerTensor(TestBase):
@patch("torch_npu.npu_quantize")
def test_quant_per_tensor(self, mock_npu_quantize):
in_tensor = torch.randn(32, 128)
input_scale = torch.tensor(0.1)
input_offset = torch.tensor(0)
expected_output = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
mock_npu_quantize.return_value = expected_output
output = quant_per_tensor(in_tensor, input_scale, input_offset)
mock_npu_quantize.assert_called_once_with(
in_tensor,
input_scale,
input_offset,
torch.qint8,
-1,
False,
)
self.assertTrue(torch.equal(output, expected_output))
class TestAscendW8A8LinearMethod(TestBase):
def setUp(self):
self.method = AscendW8A8LinearMethod()
def test_get_weight(self):
weight = self.method.get_weight(10, 20)
self.assertEqual(weight['weight'].dtype, torch.int8)
self.assertEqual(weight['weight'].shape, (20, 10))
def test_get_pertensor_param(self):
params = self.method.get_pertensor_param(torch.bfloat16)
self.assertEqual(params['input_scale'].dtype, torch.bfloat16)
self.assertEqual(params['input_offset'].dtype, torch.int8)
self.assertEqual(params['input_scale'].shape, (1, ))
self.assertEqual(params['input_offset'].shape, (1, ))
def test_get_perchannel_param(self):
params = self.method.get_perchannel_param(10, torch.bfloat16)
self.assertEqual(params['quant_bias'].dtype, torch.int32)
self.assertEqual(params['deq_scale'].dtype, torch.float32)
self.assertEqual(params['weight_scale'].dtype, torch.bfloat16)
self.assertEqual(params['weight_offset'].dtype, torch.bfloat16)
self.assertEqual(params['quant_bias'].shape, (10, ))
self.assertEqual(params['deq_scale'].shape, (10, ))
self.assertEqual(params['weight_scale'].shape, (10, 1))
self.assertEqual(params['weight_offset'].shape, (10, 1))
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
mock_quant_per_tensor):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
x = torch.randn(32, 128)
bias = torch.randn(256)
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
x.shape,
dtype=torch.int8)
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, bias)
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_is_int8(self, mock_npu_quant_matmul):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
bias = torch.randn(256)
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, bias)
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
bias = torch.randn(256)
expected_y_output = torch.randn(32, 256)
mock_npu_quant_matmul.return_value = expected_y_output
output = self.method.apply(layer, x, bias)
expected_y_output += bias
self.assertTrue(torch.equal(output, expected_y_output))
@patch('torch_npu.npu_format_cast')
def test_process_weights_after_loading(self, mock_npu_format_cast):
layer = MagicMock()
layer.weight.data = torch.randn(128, 256)
layer.input_scale.data = torch.tensor([0.1])
layer.input_offset.data = torch.tensor([0])
layer.deq_scale = torch.tensor([0.5])
layer.weight_scale.data = torch.randn(128, 1)
layer.weight_offset.data = torch.randn(128, 1)
mock_npu_format_cast.return_value = MagicMock
self.method.process_weights_after_loading(layer)
expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
self.assertTrue(
torch.equal(layer.aclnn_input_offset.data, expected_offset))
self.assertFalse(layer.aclnn_input_offset.requires_grad)
self.assertFalse(layer.deq_scale.requires_grad)
self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.moe_method = AscendW8A8FusedMoEMethod()
self.num_experts = 4
self.intermediate_size = 64
self.hidden_size = 128
self.dtype = torch.float32
def test_init(self):
self.assertTrue(self.moe_method.transpose_weight)
def test_get_weight(self):
weights = self.moe_method.get_weight(
num_experts=self.num_experts,
intermediate_size_per_partition=self.intermediate_size,
hidden_sizes=self.hidden_size,
params_dtype=self.dtype)
assert "w13_weight" in weights, f"w13_weight not in {weights}"
assert "w2_weight" in weights, f"w2_weight not in {weights}"
self.assertEqual(
weights["w13_weight"].shape,
(self.num_experts, 2 * self.intermediate_size, self.hidden_size))
self.assertEqual(
weights["w2_weight"].shape,
(self.num_experts, self.hidden_size, self.intermediate_size))
self.assertEqual(weights["w13_weight"].dtype, torch.int8)
self.assertEqual(weights["w2_weight"].dtype, torch.int8)
self.assertFalse(weights["w13_weight"].requires_grad)
self.assertFalse(weights["w2_weight"].requires_grad)
def test_get_dynamic_quant_param(self):
quant_params = self.moe_method.get_dynamic_quant_param(
num_experts=self.num_experts,
intermediate_size_per_partition=self.intermediate_size,
hidden_sizes=self.hidden_size,
params_dtype=self.dtype)
expected_params = [
"w13_weight_scale", "w13_weight_offset", "w2_weight_scale",
"w2_weight_offset", "w2_deq_scale", "w13_deq_scale",
"w2_input_scale", "w13_input_scale", "w2_input_offset",
"w13_input_offset", "quant_bias"
]
for param in expected_params:
assert param in quant_params, f"{param} not in {quant_params}"
# Check some sample shapes
self.assertEqual(quant_params["w13_weight_scale"].shape,
(self.num_experts, 2 * self.intermediate_size, 1))
self.assertEqual(quant_params["w2_input_offset"].shape,
(self.num_experts, 1))
self.assertEqual(quant_params["quant_bias"].shape,
(self.num_experts, self.hidden_size))
@patch('vllm_ascend.quantization.w8a8.select_experts')
@patch('vllm_ascend.quantization.w8a8.fused_experts')
def test_apply_with_other_expert_count(self, mock_fused_experts,
mock_select_experts):
# Setup
mock_layer = MagicMock()
x = torch.randn(32, self.hidden_size)
router_logits = torch.randn(32, 128) # 128 experts
top_k = 2
# Mock return values
mock_select_experts.return_value = (torch.randn(32, top_k),
torch.randint(0, 128, (32, top_k)))
mock_fused_experts.return_value = torch.randn(32, self.hidden_size)
# Test
result = self.moe_method.apply(layer=mock_layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=True,
global_num_experts=128)
# Assertions
mock_select_experts.assert_called_once()
mock_fused_experts.assert_called_once()
self.assertEqual(result.shape, (32, self.hidden_size))
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
@patch('vllm_ascend.quantization.w8a8.select_experts')
@patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
mock_is_310p):
# Setup
mock_layer = MagicMock()
x = torch.randn(32, self.hidden_size)
router_logits = torch.randn(32, 128) # 128 experts
top_k = 2
# Mock return values
mock_select_experts.return_value = (torch.randn(32, top_k),
torch.randint(0, 128, (32, top_k)))
mock_fused_experts_310p.return_value = torch.randn(
32, self.hidden_size)
# Test
result = self.moe_method.apply(layer=mock_layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=True,
global_num_experts=128)
# Assertions
mock_select_experts.assert_called_once()
mock_fused_experts_310p.assert_called_once()
self.assertEqual(result.shape, (32, self.hidden_size))
class TestAscendC8KVCacheMethod(TestBase):
def setUp(self):
self.layer = MagicMock()
self.layer.num_kv_heads = 4
self.layer.head_size = 64
self.layer.num_heads = 8
self.layer._k_scale_float = 1.0
self.layer._v_scale_float = 1.0
self.method = AscendC8KVCacheMethod()
self.attention_type = MagicMock()
self.attention_type.DECODER = "decoder"
self.attention_type.ENCODER = "encoder"
def test_create_weights(self):
"""测试 create_weights 是否正确注册参数"""
AscendC8KVCacheMethod.create_weights(self.layer)
self.layer.register_parameter.assert_any_call("key_antiquant_scale",
unittest.mock.ANY)
self.layer.register_parameter.assert_any_call("value_antiquant_scale",
unittest.mock.ANY)
calls = self.layer.register_parameter.call_args_list
for call in calls:
args, kwargs = call
param = kwargs.get('parameter', args[1] if len(args) > 1 else None)
expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
self.assertEqual(param.shape, expected_shape)
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False)
def test_process_weights_after_loading_not_310p(self, mock_is_310p):
key_data = torch.ones(4 * 64)
value_data = torch.ones(4 * 64) * 2
self.layer.key_antiquant_scale.data = key_data
self.layer.value_antiquant_scale.data = value_data
self.method.process_weights_after_loading(self.layer)
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
@patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
def test_process_weights_after_loading_is_310p(self, mock_is_310p):
key_data = torch.ones(4 * 64)
value_data = torch.ones(4 * 64) * 2
self.layer.key_antiquant_scale.data = key_data
self.layer.value_antiquant_scale.data = value_data
self.method.process_weights_after_loading(self.layer)
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
@patch('torch_npu.npu_scatter_nd_update_')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_decode_only(self, mock_quant, mock_scatter):
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
attn_metadata.seq_lens = [10, 10]
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
attn_metadata.slot_mapping = torch.tensor([0, 1])
attn_metadata.attn_mask = None
block_size = 16
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
kv_cache = (key_cache, value_cache)
mock_quant.side_effect = [key, value]
self.layer.key_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.layer.value_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.method.process_weights_after_loading(self.layer)
expected_output = torch.randn(
num_tokens, self.layer.num_heads * self.layer.head_size)
with patch('torch_npu.npu_incre_flash_attention',
return_value=expected_output):
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata,
self.attention_type.DECODER, 1.0,
output)
self.assertEqual(mock_quant.call_count, 2)
self.assertEqual(mock_scatter.call_count, 2)
self.assertTrue(torch.equal(result, expected_output))
@patch('torch_npu.npu_scatter_nd_update_')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_attn_metadata_without_decode(self, mock_quant,
mock_scatter):
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock(spec=[
'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
'attn_mask'
])
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
attn_metadata.seq_lens = [10, 10]
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
attn_metadata.slot_mapping = torch.tensor([0, 1])
attn_metadata.attn_mask = None
block_size = 16
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
kv_cache = (key_cache, value_cache)
mock_quant.side_effect = [key, value]
self.layer.key_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.layer.value_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.method.process_weights_after_loading(self.layer)
expected_output = torch.randn(
num_tokens, self.layer.num_heads * self.layer.head_size)
with patch('torch_npu.npu_incre_flash_attention',
return_value=expected_output):
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata,
self.attention_type.DECODER, 1.0,
output)
self.assertEqual(mock_quant.call_count, 2)
self.assertEqual(mock_scatter.call_count, 2)
self.assertTrue(torch.equal(result, expected_output))
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('torch_npu._npu_flash_attention')
def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
"""Test apply method in prefill no-cache mode"""
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
attn_metadata.seq_lens = [10, 10]
attn_metadata.attn_mask = torch.ones(2, 2)
kv_cache = (torch.tensor([]), torch.tensor([]))
mock_quant.return_value = key
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata, self.attention_type.DECODER,
1.0, output)
# Check that flash attention was called
mock_flash.assert_called_once()
# Check output shape
self.assertEqual(
result.shape,
(num_tokens, self.layer.num_heads * self.layer.head_size))
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_unsupported_attention_type(self, mock_quant):
query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
mock_quant.return_value = key
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
with self.assertRaises(NotImplementedError) as cm:
self.method.apply(self.layer, query, key, value, (None, None),
attn_metadata, self.attention_type.ENCODER, 1.0,
output)
assert "Encoder self-attention" in str(
cm.exception), f"Encoder self-attention not in {str(cm.exception)}"
assert "not implemented" in str(
cm.exception), f"not implemented not in{str(cm.exception)}"
mock_quant.assert_not_called()
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_unsupported_attention_state(self, mock_quant):
"""Test apply with unsupported attention state"""
query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit
mock_quant.return_value = key
kv_cache = (torch.tensor([]), torch.tensor([]))
with self.assertRaises(NotImplementedError):
self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata, self.attention_type.DECODER, 1.0,
output)
class TestFusedExperts(TestBase):
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_moe_init_routing_v2')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_moe_finalize_routing')
def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu,
mock_group_matmul,
mock_init_routing,
mock_get_ep_group,
mock_quant_per_tensor):
num_tokens = 32
hidden_size = 128
intermediate_size = 256
num_experts = 4
top_k = 2
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w1_scale = torch.tensor([0.1])
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
w1_input_offset = torch.tensor([0])
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
w2_scale = torch.tensor([0.1])
w2_input_scale = torch.tensor([0.2])
w2_input_offset = torch.tensor([0])
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
expert_map = torch.arange(num_experts)
mock_get_ep_group.return_value.world_size = 8
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_init_routing.return_value = (torch.randn(num_tokens * top_k,
hidden_size),
torch.arange(num_tokens * top_k),
torch.tensor([num_tokens // 2] * 2),
torch.tensor(1.0))
mock_group_matmul.side_effect = [[
torch.randn(num_tokens * top_k, intermediate_size * 2)
], [torch.randn(num_tokens * top_k, hidden_size)]]
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
expected_output = torch.randn(num_tokens, hidden_size)
mock_finalize.return_value = expected_output
output = fused_experts(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w1_input_scale=w1_input_scale,
w1_input_offset=w1_input_offset,
w2=w2,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
w2_input_offset=w2_input_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=expert_map,
)
mock_init_routing.assert_called_once()
self.assertEqual(mock_group_matmul.call_count, 2)
self.assertEqual(output.shape, (num_tokens, hidden_size))
mock_finalize.assert_called_once()
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
def test_fused_experts_without_expert_map(self, mock_swiglu,
mock_group_matmul,
mock_get_ep_group,
mock_quant_per_tensor):
num_tokens = 16
hidden_size = 64
intermediate_size = 128
num_experts = 8
top_k = 1
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
mock_get_ep_group.return_value.world_size = 8
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_group_matmul.side_effect = [[
torch.randn(num_tokens * top_k, intermediate_size * 2)
], [torch.randn(num_tokens * top_k, hidden_size)]]
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
with self.assertRaises(NotImplementedError):
fused_experts(
hidden_states=hidden_states,
w1=w1,
w1_scale=torch.tensor([0.1]),
w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]),
w1_input_offset=torch.tensor([0]),
w2=w2,
w2_scale=torch.tensor([0.1]),
w2_input_scale=torch.tensor([0.1]),
w2_input_offset=torch.tensor([0]),
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=None,
)
class TestFusedExperts310(TestBase):
@patch('torch_npu.npu_quant_grouped_matmul_dequant')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_swiglu')
def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
mock_get_ep_group,
mock_quant_per_tensor,
mock_matmul_dequant):
num_tokens = 32
hidden_size = 128
intermediate_size = 256
num_experts = 4
top_k = 1
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w1_scale = torch.tensor([0.1])
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
w2_scale = torch.tensor([0.1])
w2_input_scale = torch.tensor([0.2])
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
expert_map = torch.arange(num_experts)
mock_get_ep_group.return_value.world_size = 1
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
mock_matmul_dequant.return_value = hidden_states
output = fused_experts_310p(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w1_input_scale=w1_input_scale,
w2=w2,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=expert_map,
)
self.assertEqual(output.shape, (num_tokens, hidden_size))
self.assertEqual(mock_matmul_dequant.call_count, 2)
class TestSelectExperts(TestBase):
def setUp(self):
# Common test data
self.num_tokens = 10
self.hidden_size = 32
self.num_experts = 8
self.top_k = 2
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
def test_invalid_scoring_func(self):
"""Test invalid scoring function raises ValueError"""
with self.assertRaises(ValueError):
select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid_func")
@patch('torch.topk')
def test_grouped_topk(self, mock_topk):
"""Test grouped topk functionality"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long))
weights, ids, _ = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
"""Test grouped topk with expert score correction bias"""
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
self.num_experts)
e_score_correction_bias = torch.randn(self.num_experts)
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
e_score_correction_bias=e_score_correction_bias)
mock_grouped_topk.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
def test_custom_routing_function(self):
"""Test custom routing function"""
mock_custom_routing = MagicMock()
mock_custom_routing.return_value = (torch.ones(self.num_tokens,
self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32))
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
custom_routing_function=mock_custom_routing)
mock_custom_routing.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_renormalize(self, mock_topk):
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=True,
)
# Check if weights are normalized (sum to 1 for each token)
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
@patch('torch_npu.npu_moe_gating_top_k_softmax')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids, _ = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
)
self.assertEqual(weights.dtype, self.hidden_states.dtype)
self.assertEqual(ids.dtype, torch.int32)
class TestNativeGroupedTopkPartialMock(TestBase):
def test_basic_group_selection(self):
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1],
[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
[0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]],
dtype=torch.float32)
expected_topk_indices = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]])
with patch('torch.topk',
return_value=(None, expected_topk_indices)) as mock_topk:
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
mock_topk.assert_called_once()
expected_result = topk_weights
self.assertTrue(torch.allclose(result, expected_result))
def test_partial_group_selection(self):
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]])
expected_topk_indices = torch.tensor([[0], [1]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
expected_result = torch.tensor(
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]])
self.assertTrue(torch.allclose(result, expected_result))
def test_single_group(self):
topk_weights = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]])
expected_topk_indices = torch.tensor([[0], [0]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
self.assertTrue(result.numel() > 0)

View File

@@ -0,0 +1,203 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from unittest.mock import patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.sample.rejection_sampler import (
expand_batch_to_tokens, expand_pytorch, rejection_greedy_sample_pytorch,
rejection_random_sample_pytorch, sample_recovered_tokens_pytorch)
# Global constants
PLACEHOLDER_TOKEN_ID = -1
GREEDY_TEMPERATURE = 0.0
MAX_SPEC_LEN = 8 # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
class TestAscendRejectionSampler(TestBase):
def test_rejection_greedy_sample_pytorch(self):
"""Test greedy rejection sampling: stop when draft doesn't match, otherwise append bonus token"""
batch_size = 2
max_spec_len = 2
output_token_ids = torch.full((batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID)
cu_num_draft_tokens = torch.tensor([2, 4])
num_draft_tokens = [2, 2]
draft_token_ids = torch.tensor([10, 11, 20, 21])
target_argmax = torch.tensor([10, 99, 20, 22])
bonus_token_ids = torch.tensor([[100], [200]])
is_greedy = torch.tensor([True, True])
rejection_greedy_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
target_argmax,
bonus_token_ids,
num_draft_tokens,
max_spec_len,
is_greedy,
)
assert output_token_ids[0, 0].item() == 10
assert output_token_ids[0, 1].item() == 99
assert output_token_ids[1, 0].item() == 20
assert output_token_ids[1, 2].item() == PLACEHOLDER_TOKEN_ID
def test_rejection_random_sample_pytorch(self):
"""Test random rejection sampling: accept based on uniform probability"""
batch_size = 2
max_spec_len = 3
output_token_ids = torch.full((batch_size, max_spec_len + 1),
PLACEHOLDER_TOKEN_ID)
cu_num_draft_tokens = torch.tensor([2, 1])
draft_token_ids = torch.tensor([1, 0, 2])
draft_probs = torch.tensor([
[0.0, 0.6, 0.0, 0.4], # vocab_size=4
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.5, 0.0, 0.0],
])
target_probs = torch.tensor([
[0.0, 0.8, 0.0, 0.2],
[0.2, 0.1, 0.3, 0.4],
[0.9, 0.1, 0.0, 0.0],
])
bonus_token_ids = torch.tensor([[100], [200]])
recovered_token_ids = torch.tensor([1, 2, 3])
uniform_probs = torch.tensor([0.7, 0.6, 0.5])
is_greedy = torch.tensor([False, False])
vocab_size = 4
rejection_random_sample_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
bonus_token_ids,
recovered_token_ids,
uniform_probs,
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=False,
)
assert output_token_ids[0, 0].item() == 1
assert output_token_ids[0, 1].item() == 0
assert output_token_ids[0, 2].item() == 100
def test_expand_pytorch(self):
"""Test expand_pytorch functionality"""
input_ptr = torch.tensor([10, 20, 30], dtype=torch.int32)
cu_num_tokens_ptr = torch.tensor([2, 5, 7])
output_ptr = torch.empty(7, dtype=torch.int32)
expand_pytorch(
output_ptr,
input_ptr,
cu_num_tokens_ptr,
replace_from=0,
replace_to=0,
MAX_NUM_TOKENS=MAX_SPEC_LEN,
)
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
assert torch.equal(output_ptr, expected)
def test_expand_batch_to_tokens(self):
"""Test expand_batch_to_tokens wrapper"""
x = torch.tensor([10, 20, 30])
cu_num_tokens = torch.tensor([2, 5, 7])
num_tokens = 7
with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
) as mock_kernel:
expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
mock_kernel.assert_called_once()
args = mock_kernel.call_args[0]
assert (args[1] == x).all()
assert (args[2] == cu_num_tokens).all()
# Run actual function
result = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
expected = torch.tensor([10, 10, 20, 20, 20, 30, 30])
assert torch.equal(result, expected)
def test_sample_recovered_tokens_pytorch_ngram(self):
"""Test recovered token sampling under n-gram mode"""
output_token_ids = torch.empty(2, dtype=torch.int32)
cu_num_draft_tokens = torch.tensor([1, 2])
draft_token_ids = torch.tensor([1, 2])
draft_probs = None
target_probs = torch.tensor([
[0.1, 0.2, 0.7],
[0.3, 0.3, 0.4],
])
q = torch.tensor([
[0.1, 0.2, 0.7],
[0.5, 0.4, 0.1],
])
vocab_size = 3
sample_recovered_tokens_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
q,
vocab_size,
IS_NGRAM=True,
)
assert output_token_ids[0].item() == 0
assert output_token_ids[1].item() == 1
def test_sample_recovered_tokens_pytorch_autoregressive(self):
"""Test recovered token sampling for autoregressive models"""
output_token_ids = torch.empty(2, dtype=torch.int32)
cu_num_draft_tokens = torch.tensor([1, 1])
draft_token_ids = torch.tensor([0, 1])
draft_probs = torch.tensor([
[0.6, 0.1, 0.3],
[0.2, 0.7, 0.1],
])
target_probs = torch.tensor([
[0.8, 0.1, 0.1],
[0.3, 0.6, 0.1],
])
q = torch.tensor([
[0.5, 0.3, 0.2],
[0.1, 0.8, 0.1],
])
vocab_size = 3
sample_recovered_tokens_pytorch(
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
q,
vocab_size,
IS_NGRAM=False,
)
assert output_token_ids[0].item() == 0

View File

@@ -0,0 +1,32 @@
from unittest import mock
import torch
from tests.ut.base import TestBase
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
class TestAscendSampler(TestBase):
def test_init_with_raw_logprobs(self):
sampler = AscendSampler(logprobs_mode="raw_logprobs")
self.assertEqual(sampler.logprobs_mode, "raw_logprobs")
self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
self.assertIsInstance(sampler.topk_topp_sampler, AscendTopKTopPSampler)
class TestAscendTopKTopPSampler(TestBase):
@mock.patch("torch_npu.npu_top_k_top_p")
def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
mock_npu_op.return_value = (torch.randn(1, 3))
sampler = AscendTopKTopPSampler()
logits = torch.tensor([[1.0, 2.0, 3.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
generators = {0: torch.Generator()}
generators[0].manual_seed(42)
sampler.forward_native(logits, generators, k, p)
mock_npu_op.assert_called_once_with(logits, p, k)

View File

@@ -0,0 +1,361 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import os
from transformers import PretrainedConfig
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from tests.ut.base import TestBase
from vllm_ascend.ascend_config import (_check_torchair_supported,
check_ascend_config,
clear_ascend_config, get_ascend_config,
init_ascend_config)
class TestAscendConfig(TestBase):
@staticmethod
def _clean_up_ascend_config(func):
def wrapper(*args, **kwargs):
clear_ascend_config()
func(*args, **kwargs)
clear_ascend_config()
return wrapper
@_clean_up_ascend_config
def test_init_ascend_config_without_additional_config(self):
test_vllm_config = VllmConfig()
# No additional config given, check the default value here.
ascend_config = init_ascend_config(test_vllm_config)
self.assertIsNone(ascend_config.expert_map_path)
torchair_graph_config = ascend_config.torchair_graph_config
self.assertFalse(torchair_graph_config.enabled)
self.assertEqual(torchair_graph_config.mode, '')
self.assertFalse(torchair_graph_config.use_cached_graph)
self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertFalse(torchair_graph_config.enable_multistream_mla)
self.assertFalse(torchair_graph_config.enable_multistream_moe)
self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertFalse(torchair_graph_config.enable_kv_nz)
ascend_scheduler_config = ascend_config.ascend_scheduler_config
self.assertFalse(ascend_scheduler_config.enabled)
@_clean_up_ascend_config
def test_init_ascend_config_with_additional_config(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4],
"graph_batch_sizes_init": False,
"enable_multistream_mla": True,
"enable_multistream_moe": True,
"enable_view_optimize": True,
"enable_kv_nz": True
},
"ascend_scheduler_config": {
"enabled": True
},
"expert_map_path": "test_expert_map_path",
"refresh": True,
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(ascend_config.expert_map_path, "test_expert_map_path")
torchair_graph_config = ascend_config.torchair_graph_config
self.assertTrue(torchair_graph_config.enabled)
self.assertTrue(torchair_graph_config.use_cached_graph)
self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
self.assertTrue(torchair_graph_config.enable_multistream_mla)
self.assertTrue(torchair_graph_config.enable_multistream_moe)
self.assertTrue(torchair_graph_config.enable_view_optimize)
self.assertTrue(torchair_graph_config.enable_kv_nz)
ascend_scheduler_config = ascend_config.ascend_scheduler_config
self.assertTrue(ascend_scheduler_config.enabled)
@_clean_up_ascend_config
def test_init_ascend_config_with_refresh(self):
test_vllm_config = VllmConfig()
ascend_config = init_ascend_config(test_vllm_config)
self.assertFalse(ascend_config.torchair_graph_config.enabled)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertFalse(ascend_config.torchair_graph_config.enabled)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True,
}
ascend_config = init_ascend_config(test_vllm_config)
self.assertTrue(ascend_config.torchair_graph_config.enabled)
@_clean_up_ascend_config
def test_init_ascend_config_with_wrong_input(self):
test_vllm_config = VllmConfig()
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"graph_batch_sizes": "fake_size",
},
"refresh": True,
}
with self.assertRaises(TypeError):
init_ascend_config(test_vllm_config)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": True,
},
"refresh": True,
}
with self.assertRaises(ValueError):
init_ascend_config(test_vllm_config)
@_clean_up_ascend_config
def test_get_ascend_config(self):
test_vllm_config = VllmConfig()
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(get_ascend_config(), ascend_config)
@_clean_up_ascend_config
def test_get_ascend_config_without_init(self):
with self.assertRaises(RuntimeError):
get_ascend_config()
@_clean_up_ascend_config
def test_clear_ascend_config(self):
test_vllm_config = VllmConfig()
ascend_config = init_ascend_config(test_vllm_config)
self.assertEqual(get_ascend_config(), ascend_config)
clear_ascend_config()
with self.assertRaises(RuntimeError):
get_ascend_config()
@_clean_up_ascend_config
def test_check_ascend_config_pass(self):
test_vllm_config = VllmConfig()
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
@_clean_up_ascend_config
def test_check_ascend_config_wrong_case(self):
test_vllm_config = VllmConfig()
# torchair + eager mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
enforce_eager = True
check_ascend_config(test_vllm_config, enforce_eager)
# torchair + non deepseek model
with self.assertRaises(NotImplementedError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
},
"refresh": True
}
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
fake_model_config = ModelConfig(model=model_path)
fake_model_config.hf_config = PretrainedConfig()
fake_model_config.hf_config.model_type = "llama"
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
# aclgraph + deepseek model
with self.assertRaises(NotImplementedError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
},
"refresh": True
}
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
fake_model_config = ModelConfig(model=model_path)
fake_model_config.hf_config = PretrainedConfig()
fake_model_config.hf_config.model_type = "deepseek"
test_vllm_config.model_config = fake_model_config
init_ascend_config(test_vllm_config)
check_ascend_config(test_vllm_config, False)
def test_check_torchair_supported(self):
test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
('qwen', True), ('llama', False)]
for model_type, expected_output in test_cases:
self.assertEqual(_check_torchair_supported(model_type),
expected_output)
@_clean_up_ascend_config
def test_ascend_config_load_error(self):
test_vllm_config = VllmConfig()
# graph_batch_sizes should be list.
with self.assertRaises(TypeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"graph_batch_sizes": "fake_size",
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_graph should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_graph": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_kv_cache_bytes should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# graph_batch_sizes should not be set without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes": [1, 2, 4],
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# use_cached_kv_cache_bytes is valid only when torchair graph mode and use_cached_graph are enabled
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"use_cached_kv_cache_bytes": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# graph_batch_sizes_init should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"graph_batch_sizes_init": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_multistream_mla should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_multistream_mla": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_multistream_moe should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_multistream_moe": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# mode should not be configured without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"mode": 'max-autotune',
},
"refresh": True
}
init_ascend_config(test_vllm_config)
# enable_kv_nz should not be enabled without torchair graph mode
with self.assertRaises(RuntimeError):
test_vllm_config.additional_config = {
"torchair_graph_config": {
"enabled": False,
"enable_kv_nz": True,
},
"refresh": True
}
init_ascend_config(test_vllm_config)
with self.assertRaises(AssertionError):
test_vllm_config.additional_config = {
"lmhead_tensor_parallel_size": 2,
"refresh": True
}
test_vllm_config.parallel_config = ParallelConfig(
data_parallel_size=4, tensor_parallel_size=2)
init_ascend_config(test_vllm_config)

62
tests/ut/test_envs.py Normal file
View File

@@ -0,0 +1,62 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
import inspect
import os
import vllm_ascend.envs as envs_ascend
from tests.ut.base import TestBase
class TestEnvVariables(TestBase):
def setUp(self):
self.env_vars = list(envs_ascend.env_variables.keys())
def test_env_vars_behavior(self):
for var_name in self.env_vars:
with self.subTest(var=var_name):
original_val = os.environ.get(var_name)
var_handler = envs_ascend.env_variables[var_name]
try:
if var_name in os.environ:
del os.environ[var_name]
self.assertEqual(getattr(envs_ascend, var_name),
var_handler())
handler_source = inspect.getsource(var_handler)
if 'int(' in handler_source:
test_vals = ["123", "456"]
elif 'bool(int(' in handler_source:
test_vals = ["0", "1"]
else:
test_vals = [f"test_{var_name}", f"custom_{var_name}"]
for test_val in test_vals:
os.environ[var_name] = test_val
self.assertEqual(getattr(envs_ascend, var_name),
var_handler())
finally:
if original_val is None:
os.environ.pop(var_name, None)
else:
os.environ[var_name] = original_val
def test_dir_and_getattr(self):
self.assertEqual(sorted(envs_ascend.__dir__()), sorted(self.env_vars))
for var_name in self.env_vars:
with self.subTest(var=var_name):
getattr(envs_ascend, var_name)

714
tests/ut/test_platform.py Normal file
View File

@@ -0,0 +1,714 @@
import importlib
import unittest
from datetime import timedelta
from unittest.mock import MagicMock, patch
import pytest
import torch
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import PrefixStore
from vllm.config import CompilationLevel
from vllm.config.compilation import CUDAGraphMode
from vllm.platforms import PlatformEnum
from tests.ut.base import TestBase
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
class TestNPUPlatform(TestBase):
@staticmethod
def mock_vllm_config():
mock_vllm_config = MagicMock()
mock_vllm_config.compilation_config = MagicMock()
mock_vllm_config.model_config = MagicMock()
mock_vllm_config.parallel_config = MagicMock()
mock_vllm_config.cache_config = MagicMock()
mock_vllm_config.scheduler_config = MagicMock()
mock_vllm_config.speculative_config = None
mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
mock_vllm_config.compilation_config.cudagraph_mode = None
return mock_vllm_config
@staticmethod
def mock_vllm_ascend_config():
mock_ascend_config = MagicMock()
mock_ascend_config.torchair_graph_config.enabled = False
mock_ascend_config.ascend_scheduler_config.enabled = False
return mock_ascend_config
def setUp(self):
self.platform = NPUPlatform()
def test_class_variables(self):
self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT)
self.assertEqual(NPUPlatform.device_name, "npu")
self.assertEqual(NPUPlatform.device_type, "npu")
self.assertEqual(NPUPlatform.simple_compile_backend, "eager")
self.assertEqual(NPUPlatform.ray_device_key, "NPU")
self.assertEqual(NPUPlatform.device_control_env_var,
"ASCEND_RT_VISIBLE_DEVICES")
self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
self.assertEqual(NPUPlatform.supported_quantization,
[ASCEND_QUANTIZATION_METHOD])
def test_is_sleep_mode_available(self):
self.assertTrue(self.platform.is_sleep_mode_available())
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_parser(self, mock_quant_config,
mock_adapt_patch):
mock_parser = MagicMock()
mock_action = MagicMock()
mock_action.choices = ["awq", "gptq"]
mock_parser._option_string_actions = {"--quantization": mock_action}
self.platform.pre_register_and_update(mock_parser)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
self.assertTrue(ASCEND_QUANTIZATION_METHOD in mock_action.choices)
self.assertEqual(len(mock_action.choices), 3) # original 2 + ascend
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_without_parser(self, mock_quant_config,
mock_adapt_patch):
self.platform.pre_register_and_update(None)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_parser_no_quant_action(
self, mock_quant_config, mock_adapt_patch):
mock_parser = MagicMock()
mock_parser._option_string_actions = {}
self.platform.pre_register_and_update(mock_parser)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
@patch("vllm_ascend.utils.adapt_patch")
@patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
def test_pre_register_and_update_with_existing_ascend_quant(
self, mock_quant_config, mock_adapt_patch):
mock_parser = MagicMock()
mock_action = MagicMock()
mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD]
mock_parser._option_string_actions = {"--quantization": mock_action}
self.platform.pre_register_and_update(mock_parser)
mock_adapt_patch.assert_called_once_with(is_global_patch=True)
self.assertEqual(len(mock_action.choices), 2)
def test_get_device_capability(self):
self.assertIsNone(self.platform.get_device_capability(device_id=0))
@patch("torch.npu.get_device_name")
def test_get_device_name(self, mock_get_device_name):
device_id = 0
device_name = "Ascend910B2"
mock_get_device_name.return_value = device_name
self.assertEqual(self.platform.get_device_name(device_id), device_name)
mock_get_device_name.assert_called_once_with(0)
def test_is_async_output_supported(self):
self.assertTrue(
self.platform.is_async_output_supported(enforce_eager=None))
self.assertTrue(
self.platform.is_async_output_supported(enforce_eager=True))
self.assertTrue(
self.platform.is_async_output_supported(enforce_eager=False))
@patch("torch.inference_mode")
def test_inference_mode(self, mock_inference_mode):
mock_inference_mode.return_value = None
self.assertIsNone(self.platform.inference_mode())
mock_inference_mode.assert_called_once()
@patch("torch.npu.set_device")
def test_set_device_normal(self, mock_set_device):
device = torch.device("npu:0")
self.platform.set_device(device)
mock_set_device.assert_called_once_with(device)
@patch("torch.npu.set_device",
side_effect=RuntimeError("Device not available"))
def test_set_device_failure(self, mock_set_device):
device = torch.device("npu:0")
with self.assertRaises(RuntimeError):
self.platform.set_device(device)
mock_set_device.assert_called_once_with(device)
@patch("torch.npu.empty_cache")
def test_empty_cache_normal(self, mock_empty_cache):
self.platform.empty_cache()
mock_empty_cache.assert_called_once()
@patch("torch.npu.empty_cache",
side_effect=RuntimeError("Cache clearing failed"))
def test_empty_cache_failure(self, mock_empty_cache):
with self.assertRaises(RuntimeError):
self.platform.empty_cache()
mock_empty_cache.assert_called_once()
@patch("torch.npu.synchronize")
def test_synchronize_normal(self, mock_synchronize):
self.platform.synchronize()
mock_synchronize.assert_called_once()
@patch("torch.npu.synchronize",
side_effect=RuntimeError("Synchronization failed"))
def test_synchronize_failure(self, mock_synchronize):
with self.assertRaises(RuntimeError):
self.platform.synchronize()
mock_synchronize.assert_called_once()
@patch("torch.npu.mem_get_info")
def test_mem_get_info_normal(self, mock_mem_get_info):
free_memory_size = 1024
total_memory_size = 2048
memory_info = (free_memory_size, total_memory_size)
mock_mem_get_info.return_value = memory_info
result = self.platform.mem_get_info()
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 2)
self.assertEqual(result, memory_info)
mock_mem_get_info.assert_called_once()
@patch("torch.npu.mem_get_info",
side_effect=RuntimeError("NPU not available"))
def test_mem_get_info_failure(self, mock_mem_get_info):
with self.assertRaises(RuntimeError):
self.platform.mem_get_info()
mock_mem_get_info.assert_called_once()
@patch("gc.collect")
@patch("torch.npu.empty_cache")
@patch("torch.npu.reset_peak_memory_stats")
def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache,
mock_gc_collect):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_called_once()
mock_reset_stats.assert_called_once()
@patch("gc.collect", side_effect=Exception("GC failed"))
@patch("torch.npu.empty_cache")
@patch("torch.npu.reset_peak_memory_stats")
def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats,
mock_empty_cache,
mock_gc_collect):
with self.assertRaises(Exception):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_not_called()
mock_reset_stats.assert_not_called()
@patch("gc.collect")
@patch("torch.npu.empty_cache",
side_effect=RuntimeError("Cache clear failed"))
@patch("torch.npu.reset_peak_memory_stats")
def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats,
mock_empty_cache,
mock_gc_collect):
with self.assertRaises(RuntimeError):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_called_once()
mock_reset_stats.assert_not_called()
@patch("gc.collect")
@patch("torch.npu.empty_cache")
@patch("torch.npu.reset_peak_memory_stats",
side_effect=RuntimeError("Reset failed"))
def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats,
mock_empty_cache,
mock_gc_collect):
with self.assertRaises(RuntimeError):
self.platform.clear_npu_memory()
mock_gc_collect.assert_called_once()
mock_empty_cache.assert_called_once()
mock_reset_stats.assert_called_once()
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch("vllm_ascend.utils.update_aclgraph_sizes")
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("os.environ", {})
def test_check_and_update_config_basic_config_update(
self, mock_is_310p, mock_update_acl, mock_init_ascend,
mock_check_ascend):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.enable_expert_parallel = False
# Use importlib.reload to reload the platform module, ensuring the mocked init_ascend_config method is used.
# Without this reload, when calling self.platform.check_and_update_config,
# it would execute the original unmocked init_ascend_config method, causing the unit test to fail.
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
mock_init_ascend.assert_called_once_with(vllm_config)
mock_check_ascend.assert_called_once()
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_no_model_config_warning(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config = None
with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("Model config is missing" in cm.output[0])
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_enforce_eager_mode(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = True
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("Compilation disabled, using eager mode by default" in
cm.output[0])
self.assertEqual(
vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_unsupported_compilation_level(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE
with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("NPU does not support" in cm.output[0])
self.assertEqual(
vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@pytest.mark.skip(
"Revert me when vllm support setting cudagraph_mode on oot platform")
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_unsupported_cudagraph_mode(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue(
"cudagraph_mode is not support on NPU. falling back to NONE" in
cm.output[0])
self.assertEqual(
vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_disable_aclgraph_when_ray_enabled(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
vllm_config.parallel_config.distributed_executor_backend = "ray"
with self.assertLogs(logger="vllm", level="WARNING") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
print(30 * "=", f"cm.output: {cm.output}")
self.assertTrue(
"Ray distributed executor backend is not compatible with ACL Graph mode"
in cm.output[0])
self.assertEqual(
vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_torchair_enabled_compilation(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.model_config.enforce_eager = False
vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
with self.assertLogs(logger="vllm", level="INFO") as cm:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertTrue("Torchair compilation enabled" in cm.output[0])
self.assertEqual(
vllm_config.compilation_config.level,
CompilationLevel.NO_COMPILATION,
)
self.assertEqual(
vllm_config.compilation_config.cudagraph_mode,
CUDAGraphMode.NONE,
)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_cache_config_block_size(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.cache_config.block_size = None
vllm_config.cache_config.enable_prefix_caching = True
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertEqual(vllm_config.cache_config.block_size, 128)
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_v1_worker_class_selection(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.parallel_config.worker_cls = "auto"
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertEqual(
vllm_config.parallel_config.worker_cls,
"vllm_ascend.worker.worker_v1.NPUWorker",
)
test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
test_ascend_config.torchair_graph_config.enabled = True
mock_init_ascend.return_value = test_ascend_config
vllm_config.parallel_config.worker_cls = "auto"
self.platform.check_and_update_config(vllm_config)
self.assertEqual(
vllm_config.parallel_config.worker_cls,
"vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
@patch("vllm_ascend.utils.is_310p", return_value=True)
def test_check_and_update_config_310p_no_custom_ops(
self, mock_is_310p, mock_init_ascend, mock_check_ascend):
mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
)
vllm_config = TestNPUPlatform.mock_vllm_config()
vllm_config.compilation_config.custom_ops = []
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
self.assertEqual(vllm_config.compilation_config.custom_ops, [])
@patch("vllm_ascend.utils.is_310p", return_value=False)
@patch("vllm_ascend.ascend_config.check_ascend_config")
@patch("vllm_ascend.ascend_config.init_ascend_config")
def test_check_and_update_config_ascend_scheduler_config(
self, mock_init_ascend, mock_check_ascend, mock_is_310p):
mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
mock_ascend_config.ascend_scheduler_config.enabled = True
mock_init_ascend.return_value = mock_ascend_config
vllm_config = TestNPUPlatform.mock_vllm_config()
with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
) as mock_scheduler:
from vllm_ascend import platform
importlib.reload(platform)
self.platform.check_and_update_config(vllm_config)
mock_scheduler.initialize_from_config.assert_called_once()
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
use_v1=True,
use_mla=True,
)
self.assertEqual(result,
"vllm_ascend.attention.mla_v1.AscendMLABackend")
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_mla_and_torchair(
self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
use_v1=True,
use_mla=True,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_and_torchair(self,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
use_v1=True,
use_mla=False,
)
self.assertEqual(
result,
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
)
@patch('vllm_ascend.platform.get_ascend_config')
def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
result = self.platform.get_attn_backend_cls(
selected_backend="ascend",
head_size=64,
dtype="float16",
kv_cache_dtype="float16",
block_size=64,
use_v1=True,
use_mla=False,
)
self.assertEqual(
result,
"vllm_ascend.attention.attention_v1.AscendAttentionBackend")
def test_get_punica_wrapper(self):
result = self.platform.get_punica_wrapper()
self.assertEqual(
result,
"vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")
@patch("torch.npu.reset_peak_memory_stats")
@patch("torch.npu.max_memory_allocated")
def test_get_current_memory_usage_with_specific_device(
self, mock_max_memory, mock_reset_stats):
max_memory_allocated_result = 1024.0
mock_max_memory.return_value = max_memory_allocated_result
test_device = torch.device("npu:0")
result = self.platform.get_current_memory_usage(device=test_device)
mock_reset_stats.assert_called_once_with(test_device)
mock_max_memory.assert_called_once_with(test_device)
self.assertEqual(result, max_memory_allocated_result)
@patch("torch.npu.reset_peak_memory_stats")
@patch("torch.npu.max_memory_allocated")
def test_get_current_memory_usage_with_default_device(
self, mock_max_memory, mock_reset_stats):
max_memory_allocated_result = 1024.0
mock_max_memory.return_value = max_memory_allocated_result
result = self.platform.get_current_memory_usage()
mock_reset_stats.assert_called_once_with(None)
mock_max_memory.assert_called_once_with(None)
self.assertEqual(result, max_memory_allocated_result)
@patch("torch.npu.reset_peak_memory_stats",
side_effect=RuntimeError("Device error"))
@patch("torch.npu.max_memory_allocated")
def test_get_current_memory_usage_when_reset_stats_fails(
self, mock_max_memory, mock_reset_stats):
with self.assertRaises(RuntimeError):
self.platform.get_current_memory_usage()
mock_reset_stats.assert_called_once()
mock_max_memory.assert_not_called()
@patch("torch.npu.reset_peak_memory_stats")
@patch(
"torch.npu.max_memory_allocated",
side_effect=RuntimeError("Memory query failed"),
)
def test_get_current_memory_usage_when_query_fails(self, mock_max_memory,
mock_reset_stats):
with self.assertRaises(RuntimeError):
self.platform.get_current_memory_usage()
mock_reset_stats.assert_called_once()
mock_max_memory.assert_called_once()
def test_get_device_communicator_cls_returns_correct_value(self):
self.assertEqual(
self.platform.get_device_communicator_cls(),
"vllm_ascend.distributed.communicator.NPUCommunicator",
)
def test_is_pin_memory_available_returns_true(self):
self.assertTrue(self.platform.is_pin_memory_available())
def test_supports_v1(self):
from vllm.config import ModelConfig
mock_config = MagicMock(spec=ModelConfig)
self.assertTrue(self.platform.supports_v1(mock_config))
def test_get_static_graph_wrapper_cls_returns_correct_value(self):
self.assertEqual(
self.platform.get_static_graph_wrapper_cls(),
"vllm_ascend.compilation.acl_graph.ACLGraphWrapper",
)
@patch("torch.distributed.is_hccl_available", return_value=True)
@patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL")
@patch("torch.distributed.ProcessGroup")
def test_successful_initialization(self, mock_pg, mock_pg_hccl, _):
mock_prefix = MagicMock(spec=PrefixStore)
mock_backend = MagicMock()
mock_pg_hccl.return_value = mock_backend
group_rank = 0
group_size = 4
mock_pg_instance = MagicMock(spec=ProcessGroup)
mock_pg.return_value = mock_pg_instance
# Use importlib.reload() to force-reload the platform module and ensure the mocked ProcessGroup is used.
# Without this reload, when executing self.platform.stateless_init_device_torch_dist_pg(),
# it would invoke the original unmocked ProcessGroup implementation instead of our test mock,
# which would cause the unit test to fail.
from vllm_ascend import platform
importlib.reload(platform)
result = self.platform.stateless_init_device_torch_dist_pg(
backend="hccl",
prefix_store=mock_prefix,
group_rank=group_rank,
group_size=group_size,
timeout=timedelta(seconds=30),
)
mock_pg.assert_called_once_with(mock_prefix, group_rank, group_size)
mock_pg_hccl.assert_called_once_with(mock_prefix, group_rank,
group_size, unittest.mock.ANY)
mock_backend._set_sequence_number_for_group.assert_called_once()
mock_pg_instance._register_backend.assert_called_once_with(
torch.device("npu"), unittest.mock.ANY, mock_backend)
self.assertEqual(result, mock_pg_instance)
@patch("torch.distributed.is_hccl_available", return_value=False)
def test_hccl_unavailable(self, _):
with self.assertRaises(AssertionError):
from vllm_ascend import platform
importlib.reload(platform)
self.platform.stateless_init_device_torch_dist_pg(
backend="hccl",
prefix_store=MagicMock(),
group_rank=0,
group_size=4,
timeout=timedelta(seconds=30),
)

351
tests/ut/test_utils.py Normal file
View File

@@ -0,0 +1,351 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import math
import os
from threading import Lock
from unittest import mock
import torch
from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig,
VllmConfig)
from tests.ut.base import TestBase
from vllm_ascend import utils
class TestUtils(TestBase):
def test_is_310p(self):
utils._IS_310P = None
with mock.patch("vllm_ascend._build_info.__soc_version__",
"Ascend310P3"):
self.assertTrue(utils.is_310p())
utils._IS_310P = None
with mock.patch("vllm_ascend._build_info.__soc_version__",
"Ascend910P1"):
self.assertFalse(utils.is_310p())
def test_sleep_mode_enabled(self):
utils._SLEEP_MODE_ENABLED = None
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
True):
self.assertTrue(utils.sleep_mode_enabled())
utils._SLEEP_MODE_ENABLED = None
with mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
False):
self.assertFalse(utils.sleep_mode_enabled())
def test_nd_to_nz_2d(self):
# can be divided by 16
input_tensor = torch.randn(32, 64)
output = utils.nd_to_nz_2d(input_tensor)
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], 64 // 16)
self.assertEqual(output.shape[2], 32)
self.assertEqual(output.shape[3], 16)
# cannot be divided by 16
input_tensor = torch.randn(30, 62)
output = utils.nd_to_nz_2d(input_tensor)
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], math.ceil(62 / 16))
self.assertEqual(output.shape[2], 32)
self.assertEqual(output.shape[3], 16)
# pad to 16
input_tensor = torch.randn(8, 12)
output = utils.nd_to_nz_2d(input_tensor)
self.assertEqual(output.shape[0], 1)
self.assertEqual(output.shape[1], 1) # 12->16, 16//16=1
self.assertEqual(output.shape[2], 16) # 8->16
self.assertEqual(output.shape[3], 16)
# check if the output is contiguous
input_tensor = torch.randn(32, 64)
output = utils.nd_to_nz_2d(input_tensor)
self.assertTrue(output.is_contiguous())
# check if the output values are preserved
input_tensor = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
output = utils.nd_to_nz_2d(input_tensor)
expected = torch.tensor(
[[[[1, 2, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[5, 6, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]])
self.assertTrue(torch.allclose(output, expected))
def test_aligned_16(self):
# align to 16
input_tensor = torch.randn(15, 64)
output_tensor = utils.aligned_16(input_tensor)
self.assertEqual(output_tensor.shape[0], 16)
# align to 16
input_tensor = torch.randn(16, 64)
output_tensor = utils.aligned_16(input_tensor)
self.assertEqual(output_tensor.shape[0], 16)
self.assertTrue(torch.equal(input_tensor, output_tensor))
# align to 32
input_tensor = torch.randn(17, 64)
output_tensor = utils.aligned_16(input_tensor)
self.assertEqual(output_tensor.shape[0], 32)
@mock.patch('importlib.util.find_spec')
@mock.patch('importlib.import_module')
def test_try_register_lib(self, mock_import_module, mock_find_spec):
# import OK
mock_find_spec.return_value = mock.MagicMock()
mock_import_module.return_value = mock.MagicMock()
lib_name = "existing_lib"
lib_info = "Library found and imported successfully"
utils.try_register_lib(lib_name, lib_info)
# Can't find lib
mock_find_spec.return_value = None
lib_name = "non_existing_lib"
utils.try_register_lib(lib_name)
# import error
mock_find_spec.return_value = mock.MagicMock()
mock_import_module.side_effect = ImportError("import error")
lib_name = "error_lib"
utils.try_register_lib(lib_name)
def test_enable_custom_op(self):
result = utils.enable_custom_op()
self.assertTrue(result)
utils._CUSTOM_OP_ENABLED = None
with mock.patch('builtins.__import__') as mock_import_module:
mock_import_module.side_effect = ImportError("import error")
self.assertFalse(utils.enable_custom_op())
def test_find_hccl_library(self):
with mock.patch.dict(os.environ,
{"HCCL_SO_PATH": "/path/to/hccl/libhccl.so"}):
self.assertEqual(utils.find_hccl_library(),
"/path/to/hccl/libhccl.so")
with mock.patch("torch.version.cann", None):
self.assertRaises(ValueError, utils.find_hccl_library)
with mock.patch("torch.version.cann", "Ascend910"):
self.assertEqual(utils.find_hccl_library(), "libhccl.so")
def test_current_stream(self):
with mock.patch("torch.npu.current_stream") as mock_current_stream:
self.assertEqual(utils.current_stream(), mock_current_stream())
def test_vllm_version_is(self):
with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
with mock.patch("vllm.__version__", "1.0.0"):
self.assertTrue(utils.vllm_version_is.__wrapped__("1.0.0"))
self.assertFalse(utils.vllm_version_is.__wrapped__("2.0.0"))
with mock.patch("vllm.__version__", "2.0.0"):
self.assertTrue(utils.vllm_version_is.__wrapped__("1.0.0"))
self.assertFalse(utils.vllm_version_is.__wrapped__("2.0.0"))
with mock.patch("vllm.__version__", "1.0.0"):
self.assertTrue(utils.vllm_version_is.__wrapped__("1.0.0"))
self.assertFalse(utils.vllm_version_is.__wrapped__("2.0.0"))
with mock.patch("vllm.__version__", "2.0.0"):
self.assertTrue(utils.vllm_version_is.__wrapped__("2.0.0"))
self.assertFalse(utils.vllm_version_is.__wrapped__("1.0.0"))
# Test caching takes effect
utils.vllm_version_is.cache_clear()
utils.vllm_version_is("1.0.0")
misses = utils.vllm_version_is.cache_info().misses
hits = utils.vllm_version_is.cache_info().hits
self.assertEqual(misses, 1)
self.assertEqual(hits, 0)
utils.vllm_version_is("1.0.0")
hits = utils.vllm_version_is.cache_info().hits
self.assertEqual(hits, 1)
def test_get_max_hidden_layers(self):
from transformers import PretrainedConfig
class SimpleConfig(PretrainedConfig):
def __init__(self, num_hidden_layers=12):
self.num_hidden_layers = num_hidden_layers
def to_dict(self):
return {"num_hidden_layers": self.num_hidden_layers}
self.assertEqual(utils.get_max_hidden_layers(SimpleConfig()), 12)
self.assertEqual(utils.get_max_hidden_layers(SimpleConfig(24)), 24)
class NestedConfig(PretrainedConfig):
def to_dict(self):
return {
"model": {
"encoder": {
"num_hidden_layers": 8
},
"decoder": {
"num_hidden_layers": 12
}
},
"other_setting": True
}
self.assertEqual(utils.get_max_hidden_layers(NestedConfig()), 12)
class MultiValueConfig(PretrainedConfig):
def to_dict(self):
return {
"num_hidden_layers": 6,
"submodule": {
"num_hidden_layers": 18,
"subsub": {
"num_hidden_layers": 9
}
}
}
self.assertEqual(utils.get_max_hidden_layers(MultiValueConfig()), 18)
class NoLayerConfig(PretrainedConfig):
def to_dict(self):
return {"attention_heads": 8}
with self.assertRaises(ValueError) as context:
utils.get_max_hidden_layers(NoLayerConfig())
self.assertIn("num_hidden_layers", str(context.exception))
def test_update_aclgraph_sizes(self):
# max_num_batch_sizes < len(original_sizes)
test_compilation_config = CompilationConfig(
cudagraph_capture_sizes=[i for i in range(150)])
model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
test_model_config = ModelConfig(model=model_path, enforce_eager=True)
test_parallel_config = ParallelConfig()
test_vllm_config = VllmConfig(
model_config=test_model_config,
compilation_config=test_compilation_config,
parallel_config=test_parallel_config,
)
utils.update_aclgraph_sizes(test_vllm_config)
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
utils.update_aclgraph_sizes(test_vllm_config)
del os.environ['HCCL_OP_EXPANSION_MODE']
self.assertEqual(
147,
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
# max_num_batch_sizes >= len(original_sizes)
test_compilation_config = CompilationConfig(
cudagraph_capture_sizes=[1, 2, 3])
test_vllm_config = VllmConfig(
model_config=test_model_config,
compilation_config=test_compilation_config,
parallel_config=test_parallel_config,
)
utils.update_aclgraph_sizes(test_vllm_config)
os.environ['HCCL_OP_EXPANSION_MODE'] = 'AIV'
utils.update_aclgraph_sizes(test_vllm_config)
del os.environ['HCCL_OP_EXPANSION_MODE']
self.assertEqual(
3,
len(test_vllm_config.compilation_config.cudagraph_capture_sizes))
@mock.patch("vllm.model_executor.custom_op.CustomOp")
@mock.patch("vllm_ascend.ops.activation.AscendQuickGELU")
@mock.patch("vllm_ascend.ops.activation.AscendSiluAndMul")
@mock.patch("vllm_ascend.ops.layernorm.AscendRMSNorm")
def test_register_ascend_customop(self, mock_ascend_rmsnorm,
mock_ascend_silu_and_mul,
mock_ascend_quick_gelu, mock_customop):
utils._ASCEND_CUSTOMOP_IS_REIGISTERED = False
# ascend custom op is not registered
utils.register_ascend_customop()
# should call register_oot three
self.assertEqual(mock_customop.register_oot.call_count, 12)
self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)
# ascend custom op is already registered
utils.register_ascend_customop()
# should not register_oot again, thus only called three in this ut
self.assertEqual(mock_customop.register_oot.call_count, 12)
class TestProfileExecuteDuration(TestBase):
def setUp(self):
utils.ProfileExecuteDuration._instance = None
utils.ProfileExecuteDuration._observations = []
utils.ProfileExecuteDuration._lock = Lock()
def test_singleton_creation(self):
instance1 = utils.ProfileExecuteDuration()
self.assertIsNotNone(instance1)
self.assertIs(instance1, utils.ProfileExecuteDuration._instance)
instance2 = utils.ProfileExecuteDuration()
self.assertIs(instance1, instance2)
def test_thread_safety(self):
from threading import Thread
instances = []
def create_instance():
instances.append(utils.ProfileExecuteDuration())
threads = [Thread(target=create_instance) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
first_instance = instances[0]
for instance in instances[1:]:
self.assertIs(first_instance, instance)
def test_atexit_registration(self):
with mock.patch('atexit.register') as mock_register:
instance = utils.ProfileExecuteDuration()
mock_register.assert_called_once_with(instance.destroy)
def test_lock_usage(self):
original_lock = utils.ProfileExecuteDuration._lock
with mock.patch.object(utils.ProfileExecuteDuration,
'_lock',
wraps=original_lock) as mock_lock:
utils.ProfileExecuteDuration()
mock_lock.__enter__.assert_called()
mock_lock.__exit__.assert_called()
def test_observations_initialization(self):
instance = utils.ProfileExecuteDuration()
self.assertEqual(instance._observations, [])

View File

View File

@@ -0,0 +1,195 @@
import pytest
import torch
from pytest_mock import MockerFixture
from transformers import PretrainedConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from tests.ut.base import PytestBase
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
TorchairDeepSeekMultiTokenPredictorLayer)
class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
@pytest.fixture
def setup_mtp_layer(self, mocker: MockerFixture):
config = PretrainedConfig(vocab_size=1000,
hidden_size=768,
rms_norm_eps=1e-5)
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
return_value=None)
mocker.patch(
"vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
return_value=None)
mocker_deepseek_v2_decode_layer = mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None)
mocker_deepseek_v2_decode_layer.assert_called_once()
return mtp_layer
def test_init(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer)
def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
mtp_layer = setup_mtp_layer
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch.object(mtp_layer,
'eh_proj',
return_value=torch.randn(2, 3, 768))
mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
torch.randn(2, 3, 768))
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
kv_cache = torch.randn(2, 3, 768)
previous_hidden_states = torch.randn(2, 3, 768)
inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
output = mtp_layer(input_ids, positions, kv_cache, None,
previous_hidden_states, inputs_embeds, 0)
assert output.shape == (2, 3, 768)
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
@pytest.fixture
def setup_predictor(self, mocker: MockerFixture):
mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
mock_model_config = mocker.MagicMock(spec=ModelConfig)
mock_hf_config = mocker.MagicMock()
mock_hf_config.num_hidden_layers = 12
mock_hf_config.num_nextn_predict_layers = 3
mock_hf_config.vocab_size = 30000
mock_model_config.hf_config = mock_hf_config
mock_vllm_config.model_config = mock_model_config
mock_vllm_config.cache_config = CacheConfig()
mock_vllm_config.quant_config = mocker.MagicMock()
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
predictor = TorchairDeepSeekMultiTokenPredictor(
vllm_config=mock_vllm_config)
return predictor
def test_init(self, mocker: MockerFixture, setup_predictor):
predictor = setup_predictor
assert predictor.num_mtp_layers == 3
assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)
@pytest.mark.parametrize(
'kv_caches, inputs_embeds',
[(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
inputs_embeds):
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
# todo: need or not?
# predictor.num_mtp_layers = 1
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
return_value=torch.tensor([[1.0, 2.0, 3.0]]))
output = predictor.forward(input_ids, positions, kv_caches, None, None,
inputs_embeds, 0)
mock_layer.assert_called_once()
assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
predictor = setup_predictor
mock_layer = mocker.MagicMock()
mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
predictor.layers_list = [mock_layer]
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
return_value=None)
predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
result_logits = predictor.compute_logits(hidden_states=hidden_states,
sampling_metadata=None)
predictor.logits_processor.assert_called_once()
assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
class TestTorchairDeepSeekMTP(PytestBase):
@pytest.fixture
def setup_mtp(self, mocker: MockerFixture):
vllm_config = mocker.MagicMock()
vllm_config.model_config.hf_config.num_hidden_layers = 12
vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
vllm_config.cache_config = mocker.MagicMock()
vllm_config.quant_config = mocker.MagicMock()
mocker.patch("torch.nn.Module.__setattr__")
mocker.patch("torch.nn.Module.__getattr__")
mocker.patch("torch.nn.Module.__delattr__")
mocker.patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
return_value=None)
mocker.patch(
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
return_value=None)
mocker.patch("vllm.model_executor.layers.sampler.get_sampler",
return_value=None)
mocker.patch(
"vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
return_value=None)
mocker.patch("vllm_ascend.utils.get_ascend_config",
return_value=mocker.Mock())
mtp = TorchairDeepSeekMTP(vllm_config=vllm_config)
return mtp
def test_init(self, mocker: MockerFixture, setup_mtp):
mtp = setup_mtp
assert isinstance(mtp, TorchairDeepSeekMTP)
def test_forward(self, mocker: MockerFixture, setup_mtp):
input_ids = torch.tensor([[1, 2, 3]])
positions = torch.tensor([[0, 1, 2]])
kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
spec_step_idx = 0
setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
output = setup_mtp.forward(input_ids, positions, kv_caches, None,
previous_hidden_states, inputs_embeds,
spec_step_idx)
assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))

View File

@@ -0,0 +1,325 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from types import SimpleNamespace
from unittest.mock import Mock, patch
import pytest
import torch
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed.parallel_state import GroupCoordinator
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
TorchairDeepseekV2RowParallelLinear,
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
TorchairDeepseekV2SiluAndMul)
@pytest.fixture
def base_config():
config = PretrainedConfig(
hidden_size=128,
num_attention_heads=8,
num_hidden_layers=2,
intermediate_size=256,
hidden_act="silu",
rms_norm_eps=1e-6,
rope_theta=10000.0,
max_position_embeddings=2048,
n_routed_experts=4,
n_shared_experts=1,
moe_intermediate_size=256,
num_experts_per_tok=2,
routed_scaling_factor=1.0,
first_k_dense_replace=0,
moe_layer_freq=1,
kv_lora_rank=16,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
topk_method="noaux_tc",
scoring_func="softmax",
norm_topk_prob=True,
n_group=1,
topk_group=1,
vocab_size=10000,
)
return config
@pytest.fixture
def vllm_config(base_config):
model_config = SimpleNamespace(
hf_config=base_config,
tensor_parallel_size=1,
dtype=torch.float32,
use_mla=False,
quant_config=None,
max_model_len=2048,
)
cache_config = CacheConfig()
vllm_config = Mock()
vllm_config.model_config = model_config
vllm_config.cache_config = cache_config
vllm_config.quant_config = None
return vllm_config
@pytest.fixture
def mock_distributed():
tp_group = Mock(spec=GroupCoordinator)
tp_group.rank_in_group = 0
tp_group.world_size = 1
tp_group.device_group = Mock()
dp_group = Mock(spec=GroupCoordinator)
dp_group.rank_in_group = 0
dp_group.world_size = 1
ep_group = Mock(spec=GroupCoordinator)
ep_group.rank_in_group = 0
ep_group.world_size = 1
pp_group = Mock(spec=GroupCoordinator)
pp_group.rank_in_group = 0
pp_group.world_size = 1
mock_vllm_config = Mock()
mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)
with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \
patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group",
return_value=Mock(is_first_rank=False, is_last_rank=False)), \
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
_PP=pp_group), \
patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group):
yield
@pytest.fixture
def mock_forward_context():
forward_context = Mock(in_profile_run=False, with_prefill=False)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context",
return_value=forward_context):
yield
def test_torchair_deepseek_v2_silu_and_mul():
torch.set_default_device("cpu")
silu = TorchairDeepseekV2SiluAndMul()
assert silu.weight_scale is None
x = torch.randn(2, 4)
output = silu.forward_oot(x)
assert output.shape == (2, 2)
weight_scale = Mock(return_value=torch.tensor(0.1))
silu = TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale)
quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
dynamic_scale = torch.randn(2, 1)
with patch("torch_npu.npu_dequant_swiglu_quant",
return_value=torch.randn(2, 4)):
output = silu.forward_oot((quant_x, dynamic_scale))
assert output.shape == (2, 4)
def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128,
output_sizes=[64, 64],
bias=False,
quant_config=None)
assert linear.output_sizes == [64, 64]
param = Mock()
param.data = torch.zeros(128, 128)
param.output_dim = 1
param.is_gguf_weight = False
param.is_gguf_weight_type = False
loaded_weight = torch.randn(128, 64)
linear.weight_loader(param, loaded_weight, loaded_shard_id=0)
with pytest.raises(AssertionError):
linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0)
@pytest.mark.parametrize("cls", [
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
TorchairDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed):
linear = cls(input_size=128, output_size=64, bias=False, quant_config=None)
linear.quant_method = Mock()
linear.quant_method.apply.return_value = torch.randn(2, 4, 64)
input_ = torch.randn(2, 4, 128)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim",
return_value=[torch.randn(2, 4, 64)]):
linear.input_is_parallel = False
output = linear(input_, is_prefill=True)
assert output[0].shape == (2, 4, 64)
linear.input_is_parallel = True
output = linear(input_, is_prefill=False)
assert output[0].shape == (2, 4, 64)
def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
mlp = TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=None)
assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)
x = torch.randn(2, 4, 128)
output = mlp(x)
assert output.shape == (2, 4, 128)
with patch(
"vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
) as mock_quant_config:
mock_quant_config.name = "w8a8dynamic"
with pytest.raises(NotImplementedError):
TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="silu",
quant_config=mock_quant_config,
force_replicate=False)
with pytest.raises(ValueError):
TorchairDeepseekV2MLP(hidden_size=128,
intermediate_size=256,
hidden_act="relu",
quant_config=None)
def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
mock_forward_context):
base_config.n_shared_experts = 1
moe = TorchairDeepseekV2MoE(config=base_config,
quant_config=None,
prefix="mlp")
assert moe.top_k == 2
x = torch.randn(2, 4, 128)
attn_metadata = Mock(num_prefills=1)
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__",
return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))):
output = moe(x, attn_metadata)
assert output.shape == (2, 4, 128)
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=16,
kv_lora_rank=16,
cache_config=CacheConfig(),
quant_config=None,
prefix="layers.0.self_attn")
assert attn.debug_layer_idx == 0
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(attn.mla_attn,
"__call__",
return_value=torch.randn(2, 4, 128)):
with pytest.raises(AssertionError):
attn(positions, x)
attn = TorchairDeepseekV2MLAAttention(config=base_config,
hidden_size=128,
num_heads=8,
qk_nope_head_dim=16,
qk_rope_head_dim=16,
v_head_dim=32,
q_lora_rank=None,
kv_lora_rank=16,
prefix="layers.1.self_attn")
assert hasattr(attn, "q_proj")
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
mock_distributed, base_config,
vllm_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
torch.randn(2, 128))
base_config.n_routed_experts = 4
layer = TorchairDeepseekV2DecoderLayer(
config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
cache_config=CacheConfig(),
quant_config=None)
assert isinstance(layer.mlp, TorchairDeepseekV2MoE)
x = torch.randn(2, 4, 128)
positions = torch.arange(4).repeat(2, 1)
with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \
patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))):
hidden_states, residual = layer(positions, x, None)
assert hidden_states.shape == (2, 4, 128)
base_config.n_routed_experts = None
layer = TorchairDeepseekV2DecoderLayer(
config=base_config,
prefix="layers.0",
model_config=vllm_config.model_config,
quant_config=None)
assert isinstance(layer.mlp, TorchairDeepseekV2MLP)
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)
input_ids = torch.randint(0, 10000, (2, 4))
positions = torch.arange(4).repeat(2, 1)
with patch.object(model.model,
"forward",
return_value=torch.randn(2, 4, 128)):
output = model(input_ids, positions)
assert output.shape == (2, 4, 128)
weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
with patch(
"vllm.model_executor.model_loader.weight_utils.default_weight_loader"
):
loaded = model.load_weights(weights)
assert loaded is not None

View File

@@ -0,0 +1,410 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch
import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.quantization.quantizer import W8A8Quantizer
from vllm_ascend.torchair.ops.torchair_fused_moe import (
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
adapt_patch(True)
def mock_ep_and_mc2_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.rank = 0
mock_group.world_size = 4
mock_group.device_group = "mock_group_ep"
mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
return mock_group
def mock_dp_and_tp_group(mocker):
mock_group = mocker.MagicMock()
mock_group.rank_in_group = 0
mock_group.world_size = 2
mock_group.device_group = "mock_group"
mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
return mock_group
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
# init dist env patch
with patch('torch.distributed.get_rank', return_value=0), \
patch('torch.distributed.get_world_size', return_value=4), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
return_value=torch.randn(5, 32)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.data_parallel_reduce_scatter',
return_value=torch.randn(5, 32)), \
patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
return_value=MagicMock(
torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
expert_map_path=None
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
return_value=MagicMock(
max_tokens_across_dp=10,
dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10])
)), \
patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
return_value=MagicMock(
parallel_config=MagicMock(tensor_parallel_size=2),
scheduler_config=MagicMock(max_num_seqs=4),
model_config=MagicMock(max_model_len=2048)
)):
yield
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
# init moe env patch
with patch('torch_npu.npu_moe_gating_top_k', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
None
)), \
patch('torch_npu.npu_moe_init_routing', return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_compute_expert_tokens", return_value=(
torch.randn(8, 2)
)), \
patch("torch_npu.npu_moe_distribute_dispatch", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_distribute_combine", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_grouped_matmul", return_value=(
[torch.randn(16, 2)]
)), \
patch("torch_npu.npu_swiglu", return_value=(
torch.randn(16, 2)
)), \
patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
torch.randn(8, 2),
torch.randint(0, 8, (8, 2)),
torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
)), \
patch("torch_npu.npu_moe_finalize_routing", return_value=(
torch.randn(16, 2)
)):
if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(
torch.randn(16, 2))), \
patch("torch_npu.npu_moe_distribute_combine_v2", return_value=(
torch.randn(16, 2))):
yield
else:
yield
@pytest.fixture
def default_moe_config():
"""default moe config"""
return {
'num_experts': 8,
'top_k': 2,
'hidden_size': 512,
'intermediate_size': 1024
}
@pytest.fixture
def moe_method(mock_dist_env):
moe = MagicMock()
moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
return TorchairAscendUnquantizedFusedMoEMethod(moe)
class Device(TypedDict):
device_id: int
device_expert: List[int]
class Layer(TypedDict):
layer_id: int
device_count: int
device_list: List[Device]
class MockData(TypedDict):
moe_layer_count: int
layer_list: List[Layer]
class MockQuantMethod(nn.Module):
def __init__(self, shared_experts, num_tokens):
super().__init__()
if shared_experts:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
torch.randn(num_tokens, 10)))
else:
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))
class MockFusedMoEMethod(FusedMoEMethodBase):
moe = MagicMock()
def __init__(self):
super().__init__(self.moe)
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass
def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass
class TestTorchairAscendFusedMoe:
def test_init_no_quant(self, mock_dist_env, default_moe_config):
layer = TorchairAscendFusedMoE(**default_moe_config)
layer.w13_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['intermediate_size'] * 2,
default_moe_config['hidden_size']))
layer.w2_weight = nn.Parameter(
torch.randn(default_moe_config['num_experts'],
default_moe_config['hidden_size'],
default_moe_config['intermediate_size']))
assert layer.num_experts == default_moe_config['num_experts']
assert layer.top_k == default_moe_config['top_k']
assert hasattr(layer, 'w13_weight')
assert hasattr(layer, 'w2_weight')
# check group_topk
with pytest.raises(AssertionError):
error_config = default_moe_config.copy()
error_config['use_grouped_topk'] = True
layer = TorchairAscendFusedMoE(**error_config)
# check scoring_func
with pytest.raises(ValueError):
error_config = default_moe_config.copy()
error_config['scoring_func'] = "random"
layer = TorchairAscendFusedMoE(**error_config)
def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = False
with patch(
'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer',
return_value=W8A8Quantizer):
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = True
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method,
TorchairAscendUnquantizedFusedMoEMethod)
@pytest.mark.parametrize(
"others_param",
[[None,
MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
[2, None, False, 5, None], [None, None, True, 5, None],
[None, None, False, 1, None], [None, None, True, 5, 1],
[None, None, False, 5, 1]])
def test_forward(self, mock_dist_env, default_moe_config, others_param):
"""
1 test has shared_experts
2 test has top_k
3 test is_prefill is true
4 test single num_tokens(decode)
5 test ep_size is 1 and is_prefill is true
6 test ep_size is 1 and is_prefill is False
"""
top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
inputs = torch.randn(num_tokens, 32)
router_logits = torch.randn(num_tokens, 8)
moe = TorchairAscendFusedMoE(**default_moe_config)
if ep_size == 1:
moe.moe_parallel_config.ep_size = 1
moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
forward_context = MagicMock(mc2_mask=torch.zeros(num_tokens,
dtype=torch.bool),
padded_num_tokens=num_tokens)
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
return_value=forward_context):
output = moe.forward(inputs,
router_logits,
is_prefill=is_prefill,
top_k=top_k,
shared_experts=shared_experts)
moe.quant_method.apply.assert_called_once()
if shared_experts:
assert output[0].shape == (num_tokens, 32)
assert output[1].shape == (num_tokens, 10)
else:
assert output.shape == (num_tokens, 32)
def test_forward_ms_fused_moe_comp(self, mock_dist_env,
default_moe_config):
inputs = torch.randn(5, 32)
router_logits = torch.randn(5, 8)
moe = TorchairAscendFusedMoE(**default_moe_config)
moe.quant_method = MockQuantMethod(None, 5)
output = moe._forward_ms_fused_moe_comp(inputs,
router_logits,
is_prefill=False,
real_top_k=1)
moe.quant_method.apply.assert_called_once()
assert output.shape == (5, 32)
class TestTorchairAscendUnquantizedFusedMoEMethod:
def test_process_weights_after_loading(self, moe_method, mock_dist_env):
layer = MagicMock()
layer.w13_weight.data = torch.randn(16, 32)
layer.w2_weight.data = torch.randn(16, 32)
moe_method.process_weights_after_loading(layer)
assert isinstance(layer.w13_weight, torch.nn.Parameter)
assert isinstance(layer.w2_weight, torch.nn.Parameter)
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all
2 test use_select_experts and fused_experts
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
forward_context = MagicMock(fused_moe_state=_get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch(
"vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=global_num_experts,
is_prefill=is_prefill)
if ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape
@pytest.mark.parametrize("others_param", [16, 1, 4])
def test_apply_with_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
1 test use_select_experts and use fused_expters_with_mc2
2 test use_select_experts and fused_experts_with_all2all_buffer
3 test use_select_experts and fused_experts_with_all2all
4 test use_select_experts and fused_experts
"""
ep_size = others_param
is_prefill = False
forward_context = MagicMock(
fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context", return_value=forward_context), \
patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3):
expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
if ep_size == 1:
x = x.view(-1, 2)
router_logits = torch.randn(8, 8)
layer = MagicMock()
layer.w13_weight = torch.randn(8, 16, 1)
layer.w2_weight = torch.randn(16, 8, 1)
result = moe_method.apply(layer=layer,
x=x,
router_logits=router_logits,
top_k=2,
renormalize=True,
global_num_experts=128,
expert_map=expert_map,
is_prefill=is_prefill)
if ep_size == 16 or ep_size == 1:
assert result.shape == (16, 2)
else:
assert result.shape == x.shape

View File

@@ -0,0 +1,332 @@
import math
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
custom_rotary_embedding_enabled, native_rope_deepseek_forward,
rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)
class TestCustomRotaryEmbeddingEnabled(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
def test_custom_rotary_embedding_enabled(self):
# Test when all conditions are True
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertTrue(result)
# Test when dtype is not float16
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
query = self.query.to(torch.float32)
result = custom_rotary_embedding_enabled(query, True,
self.head_size)
self.assertFalse(result)
# Test when neox_style is False
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, False,
self.head_size)
self.assertFalse(result)
# Test when head_size is not divisible by 32
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=True):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size + 1)
self.assertFalse(result)
# Test when custom op is disabled
with patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
return_value=False):
result = custom_rotary_embedding_enabled(self.query, True,
self.head_size)
self.assertFalse(result)
class TestRopeForwardOot(TestBase):
def setUp(self):
# Common setup for tests
self.positions = torch.tensor([1, 2, 3])
self.query = torch.randn(3, 4, dtype=torch.float16)
self.key = torch.randn(3, 4, dtype=torch.float16)
self.head_size = 32
self.cos_sin_cache = torch.randn(3, 4)
# Mock self object for rope_forward_oot
self.mock_self = MagicMock()
self.mock_self.head_size = self.head_size
self.mock_self.cos_sin_cache = self.cos_sin_cache
self.mock_self.is_neox_style = True
self.mock_self.forward_native.return_value = (self.query, self.key)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
def test_rope_forward_oot_torchair_enabled_base(self,
mock_get_ascend_config):
# Setup mock for torchair enabled
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = True
mock_get_ascend_config.return_value = mock_config
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.mock_self.forward_native.assert_called_once_with(
self.positions, self.query, self.key, None)
self.assertTrue(torch.equal(result_q, self.query))
self.assertTrue(torch.equal(result_k, self.key))
@patch('torch.ops._C')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
return_value=False)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=True)
@patch('torch.ops._npu_rotary_embedding')
def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
mock_custom_enabled, mock_is_310p,
mock_get_ascend_config, mock__c):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Setup mock for custom kernel path
mock__c.rotary_embedding.return_value = self.query, self.key
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
self.query, self.key)
self.assertEqual(result_q.shape, self.query.shape)
self.assertEqual(result_k.shape, self.key.shape)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test contiguous path when custom is disabled
non_contig_query = self.query.transpose(0, 1)
non_contig_key = self.key.transpose(0, 1)
result_q, result_k = rope_forward_oot(self.mock_self, self.positions,
non_contig_query, non_contig_key)
mock_npu_rotary.assert_called_once()
self.assertEqual(result_q.shape, non_contig_query.shape)
self.assertEqual(result_k.shape, non_contig_key.shape)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test that NotImplementedError is raised when offsets is provided
offsets = torch.tensor([1, 2, 3])
with self.assertRaises(NotImplementedError):
rope_forward_oot(self.mock_self, self.positions, self.query,
self.key, offsets)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
return_value=False)
@patch('torch_npu._npu_rotary_embedding')
def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
mock_custom_enabled,
mock_get_ascend_config):
mock_config = MagicMock()
mock_config.torchair_graph_config.enabled = False
mock_get_ascend_config.return_value = mock_config
# Test neox_style override
result_q, result_k = rope_forward_oot(self.mock_self,
self.positions,
self.query,
self.key,
is_neox_style_override=False)
# Check that neox_style=False was passed to the NPU function
args, kwargs = mock_npu_rotary.call_args
self.assertFalse(args[-1])
class MockRopeModule:
def __init__(self, max_seq_len=2048, is_neox_style=True):
self.max_seq_len = max_seq_len
self.is_neox_style = is_neox_style
self.cos_cached = None
self.sin_cached = None
self.rotary_dim = 1
self.base = 1
class TestNativeRopeDeepseekForward(TestBase):
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_cache_handling(
self, mock_rope_forward_oot, mock_set_cache):
# Test cache situation is true
module = MockRopeModule(max_seq_len=1024)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module,
positions,
query,
key,
max_seq_len=2048)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_key_reshaping(
self, mock_rope_forward_oot):
module = MockRopeModule()
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == (1, 128)
@patch(
'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
def test_native_rope_deepseek_forward_non_neox_style(
self, mock_rope_forward_oot):
module = MockRopeModule(is_neox_style=False)
positions = torch.tensor([1, 2, 3])
query = torch.randn(1, 8, 128)
key = torch.randn(1, 8, 128)
mock_rope_forward_oot.return_value = (query, key)
q_pe, k_pe = native_rope_deepseek_forward(module, positions, query,
key)
assert q_pe.shape == query.shape
assert k_pe.shape == key.shape
class TestRotateHalf(TestBase):
def test_rotate_half_even_dim(self):
# Test with even dimension
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
expected = torch.tensor([-3.0, -4.0, 1.0, 2.0])
result = rotate_half(x)
self.assertTrue(torch.allclose(result, expected))
class TestYarnFindCorrectionDim(TestBase):
def test_basic_case(self):
# Test with standard values
num_rotations = 100
dim = 512
base = 10000
max_position_embeddings = 2048
result = yarn_find_correction_dim(num_rotations, dim, base,
max_position_embeddings)
# Calculate expected value manually
expected = (dim * torch.log(
torch.tensor(max_position_embeddings) /
(num_rotations * 2 * torch.pi))) / (2 *
torch.log(torch.tensor(base)))
self.assertTrue(torch.allclose(result, expected))
class TestYarnGetMscale(TestBase):
def test_scale_less_than_or_equal_1(self):
self.assertEqual(yarn_get_mscale(scale=0.5), 1.0)
self.assertEqual(yarn_get_mscale(scale=1.0), 1.0)
self.assertEqual(yarn_get_mscale(scale=0.999), 1.0)
def test_scale_greater_than_1(self):
test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
(10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
(5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
(math.e, 1.0, 1.0 + 0.1)]
for scale, mscale, expected in test_cases:
result = yarn_get_mscale(scale, mscale)
self.assertAlmostEqual(
result,
expected,
places=6,
msg=f"Failed for scale={scale}, mscale={mscale}")

View File

@@ -0,0 +1,176 @@
import copy
from unittest.mock import Mock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
TorchairAscendW4A8DynamicFusedMoEMethod,
TorchairAscendW4A8DynamicLinearMethod)
class TestAscendW4A8DynamicLinearMethod(TestBase):
def setUp(self):
self.method = TorchairAscendW4A8DynamicLinearMethod()
self.method.group_size = 8
def test_get_weight(self):
weight = self.method.get_weight(8, 32, torch.bfloat16)
self.assertEqual(weight["weight"].dtype, torch.int8)
self.assertEqual(weight["weight"].shape, (32, 8))
def test_get_pergroup_param(self):
params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
self.assertEqual(params["weight_scale"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale"].shape, (32, 1))
self.assertEqual(params["weight_offset"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset"].shape, (32, 1))
self.assertEqual(params["weight_scale_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_scale_second"].shape, (32, 1))
self.assertEqual(params["weight_offset_second"].dtype, torch.bfloat16)
self.assertEqual(params["weight_offset_second"].shape, (32, 1))
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
experts = 8
input_size = 16
output_size = 56
group_size = 2
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
)
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
@patch("vllm_ascend.ascend_config.get_ascend_config")
@patch(
'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
)
@patch('torch.distributed.get_rank', return_value=0)
def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
mock_get_ep_group, get_current_vllm_config):
mock_ascend_config = Mock()
mock_ascend_config.torchair_graph_config = Mock(enabled=False)
mock_get_ascend_config.return_value = mock_ascend_config
mock_vllm_config = Mock()
mock_vllm_config.quant_config = Mock(quant_description={
"group_size": self.group_size,
"version": "0.0.0"
})
mock_vllm_config.parallel_config = Mock(enable_expert_parallel=True)
get_current_vllm_config.return_value = mock_vllm_config
self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()
def test_get_weight(self):
# old quant version w4a8 weight
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, 2 * self.input_size, self.output_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_weight(self.experts,
self.input_size,
self.output_size,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight"].dtype, torch.int8)
self.assertEqual(param_dict["w13_weight"].shape,
(self.experts, self.input_size, self.output_size))
def test_get_dynamic_quant_param(self):
# old quant version weight
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale"].shape,
(self.experts, 2 * self.input_size, 1))
self.assertEqual(param_dict["w13_weight_scale_second"].dtype,
torch.bfloat16)
self.assertEqual(param_dict["w13_weight_scale_second"].shape,
(self.experts, 2 * self.input_size,
self.output_size // self.group_size))
self.assertEqual(param_dict["w2_weight_scale"].dtype, torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale"].shape,
(self.experts, self.output_size, 1))
self.assertEqual(param_dict["w2_weight_scale_second"].dtype,
torch.bfloat16)
self.assertEqual(param_dict["w2_weight_scale_second"].shape,
(self.experts, self.output_size,
self.input_size // self.group_size))
# new quant version weight
self.quant_method.new_quant_version = True
param_dict = self.quant_method.get_dynamic_quant_param(
self.experts, self.input_size, self.output_size, torch.bfloat16)
self.assertEqual(param_dict["w2_scale_bias"].dtype, torch.float32)
self.assertEqual(
param_dict["w2_scale_bias"].shape,
(self.experts, self.output_size, 16 // self.quant_method.tp_size))
@patch('torch_npu.npu_quantize')
@patch('torch.Tensor.npu')
def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
# old quant version weight
layer = torch.nn.Module()
layer.w13_weight = torch.nn.Parameter(torch.zeros(
(self.experts, 2 * self.input_size, self.output_size),
dtype=torch.int8),
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(torch.zeros(
(self.experts, self.output_size, self.input_size),
dtype=torch.int8),
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w13_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, 2 * self.input_size,
self.output_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size, 1), dtype=torch.bfloat16),
requires_grad=False)
layer.w2_weight_scale_second = torch.nn.Parameter(torch.ones(
(self.experts, self.output_size,
self.input_size // self.group_size),
dtype=torch.bfloat16),
requires_grad=False)
new_layer = copy.deepcopy(layer)
mock_npu.return_value = torch.Tensor()
mock_npu_quantize.return_value = torch.Tensor()
self.quant_method.process_weights_after_loading(layer)
self.assertTrue(hasattr(layer, "w13_scale_bias"))
self.assertEqual(layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
self.assertTrue(hasattr(layer, "w2_scale_bias"))
self.assertEqual(layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))
self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)
# new quant version weight
self.quant_method.new_quant_version = True
new_layer.w13_weight.data = torch.zeros(
(self.experts, self.input_size, self.output_size),
dtype=torch.int8)
new_layer.w2_weight.data = torch.zeros(
(self.experts, self.output_size // 2, self.input_size),
dtype=torch.int8)
w13_scale_bias = torch.zeros((self.experts, 2 * self.input_size, 1),
dtype=torch.float32)
new_layer.w13_scale_bias = torch.nn.Parameter(w13_scale_bias,
requires_grad=False)
w2_scale_bias = torch.zeros(
(self.experts, self.output_size, 16 // self.quant_method.tp_size),
dtype=torch.float32)
new_layer.w2_scale_bias = torch.nn.Parameter(w2_scale_bias,
requires_grad=False)
self.quant_method.process_weights_after_loading(new_layer)
self.assertEqual(new_layer.w13_scale_bias.data.shape,
(self.experts, 2 * self.input_size))
self.assertEqual(new_layer.w2_scale_bias.data.shape,
(self.experts, self.output_size))

View File

@@ -0,0 +1,75 @@
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
torchair_fused_experts_with_all2all
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.hidden_size = 128
self.num_tokens = 128
self.placeholder = torch.randn(self.num_tokens,
self.hidden_size,
dtype=torch.bfloat16)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")
@patch("torch_npu.npu_swiglu")
@patch("torch_npu.npu_dynamic_quant")
@patch("torch_npu.npu_moe_finalize_routing")
@patch("torch_npu.npu_moe_init_routing")
def test_torchair_fused_experts_with_all2all(
self, mock_moe_init_routing, mock_moe_finalize_routing,
mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
mock_moe_re_routing, mock_all_to_all_single):
expert_map = MagicMock()
ep_group = MagicMock()
placeholder_int8 = torch.randint(0,
100,
(self.num_tokens, self.hidden_size),
dtype=torch.int8)
placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32)
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)
mock_moe_init_routing.return_value = (
placeholder_int8,
placeholder_ones,
placeholder_ones,
)
mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder,
torch.randint(0,
100,
(self.num_tokens, ),
dtype=torch.int32),
self.placeholder)
mock_grouped_matmul.return_value = self.placeholder
mock_swiglu.return_value = self.placeholder
mock_dynamic_quant.return_value = (
placeholder_int8,
torch.randn(self.num_tokens),
)
mock_moe_finalize_routing.return_value = self.placeholder
result = torchair_fused_experts_with_all2all(
hidden_states=self.placeholder,
w1=self.placeholder,
w1_scale=self.placeholder,
w2=self.placeholder,
w2_scale=self.placeholder,
topk_weights=self.placeholder,
topk_ids=self.placeholder,
top_k=8,
expert_map=expert_map,
ep_group=ep_group,
log2phy=None,
global_redundant_expert_num=256,
)
self.assertIsNotNone(result)
self.assertEqual(result.dtype, torch.bfloat16)
self.assertEqual(result.shape, (128, 128))

View File

@@ -0,0 +1,817 @@
from unittest.mock import MagicMock, patch
import torch
from torch import nn
from vllm.distributed.parallel_state import GroupCoordinator
from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.torchair.torchair_mla import (
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
AscendMLATorchairImpl, AscendMLATorchairMetadata,
AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata)
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
class TestAscendMLATorchairBackend(TestBase):
def test_get_name(self):
self.assertEqual(AscendMLATorchairBackend.get_name(),
"ASCEND_MLA_TORCHAIR")
def test_get_metadata_cls(self):
self.assertEqual(AscendMLATorchairBackend.get_metadata_cls(),
AscendMLATorchairMetadata)
def test_get_builder_cls(self):
self.assertEqual(AscendMLATorchairBackend.get_builder_cls(),
AscendMLATorchairMetadataBuilder)
def test_get_kv_cache_shape(self):
result = AscendMLATorchairBackend.get_kv_cache_shape(2, 4, 8, 128)
self.assertEqual(result, (2, 4, 8, 128))
def test_get_impl_cls(self):
result = AscendMLATorchairBackend.get_impl_cls()
self.assertEqual(result, AscendMLATorchairImpl)
class TestAscendMLATorchairPrefillMetadata(TestBase):
def test_ascend_mla_prefill_metadata_default(self):
attn_mask = torch.tensor([[1, 0], [1, 1]], dtype=torch.bool)
query_lens = [1, 2]
seq_lens = [2, 2]
context_lens = torch.tensor([1, 2])
input_positions = torch.tensor([0, 1, 0, 1])
query_start_loc = torch.tensor([0, 1, 3])
block_table = torch.tensor([[0, 1], [2, 3]])
max_query_len = 2
max_seq_lens = 2
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=attn_mask,
query_lens=query_lens,
seq_lens=seq_lens,
context_lens=context_lens,
input_positions=input_positions,
query_start_loc=query_start_loc,
block_table=block_table,
max_query_len=max_query_len,
max_seq_lens=max_seq_lens)
self.assertIs(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.context_lens, context_lens)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertIs(metadata.block_table, block_table)
self.assertEqual(metadata.max_query_len, max_query_len)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertIsNone(metadata.chunked_context)
def test_ascend_mla_prefill_metadata_with_chunked_context(self):
cu_seq_lens = torch.tensor([0, 2, 4])
starts = torch.tensor([0, 2])
seq_tot = [2, 2]
max_seq_lens = [2, 2]
workspace = torch.randn(2, 4)
chunk_seq_lens = torch.tensor([2, 2])
chunked_context = AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata(
cu_seq_lens=cu_seq_lens,
starts=starts,
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
chunk_seq_lens=chunk_seq_lens)
metadata = AscendMLATorchairPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
query_lens=[1, 2],
seq_lens=[2, 2],
context_lens=torch.tensor([1, 2]),
input_positions=torch.tensor([0, 1, 0, 1]),
query_start_loc=torch.tensor([0, 1, 3]),
block_table=torch.tensor([[0, 1], [2, 3]]),
max_query_len=2,
max_seq_lens=2,
chunked_context=chunked_context)
self.assertIsNotNone(metadata.chunked_context)
self.assertIs(metadata.chunked_context.cu_seq_lens, cu_seq_lens)
self.assertIs(metadata.chunked_context.starts, starts)
self.assertEqual(metadata.chunked_context.seq_tot, seq_tot)
self.assertEqual(metadata.chunked_context.max_seq_lens, max_seq_lens)
self.assertIs(metadata.chunked_context.workspace, workspace)
self.assertIs(metadata.chunked_context.chunk_seq_lens, chunk_seq_lens)
class TestAscendMLATorchairDecodeMetadata(TestBase):
def test_ascend_mla_decode_metadata_default(self):
input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
seq_lens = torch.tensor([[2], [3]])
max_seq_lens = 4
seq_lens_list = [2, 3]
attn_mask = None
metadata = AscendMLATorchairDecodeMetadata(input_positions,
block_table, seq_lens,
max_seq_lens, seq_lens_list,
attn_mask)
self.assertIs(metadata.input_positions, input_positions)
self.assertIs(metadata.block_table, block_table)
self.assertIs(metadata.seq_lens, seq_lens)
self.assertEqual(metadata.max_seq_lens, max_seq_lens)
self.assertEqual(metadata.seq_lens_list, seq_lens_list)
self.assertIsNone(attn_mask)
class TestAscendMLATorchairMetadata(TestBase):
def test_ascend_mla_metadata_default(self):
num_actual_tokens = 100
slot_mapping = torch.randn(100, 4, 1024)
query_start_loc = torch.tensor([1, 2, 3, 4])
seq_lens = [30, 50]
block_tables = torch.randint(0, 100, (100, 4))
num_decodes = 4
num_decode_tokens = 8
num_prefills = 8
num_input_tokens = 2
query_lens = None
head_dim = None
attn_mask = None
attn_state = AscendAttentionState.ChunkedPrefill
decode = None
prefill = None
metadata = AscendMLATorchairMetadata(
num_actual_tokens, slot_mapping, query_start_loc, seq_lens,
block_tables, num_decodes, num_decode_tokens, num_prefills,
num_input_tokens, query_lens, head_dim, attn_mask, attn_state,
decode, prefill)
self.assertEqual(metadata.num_actual_tokens, num_actual_tokens)
self.assertIs(metadata.slot_mapping, slot_mapping)
self.assertIs(metadata.query_start_loc, query_start_loc)
self.assertEqual(metadata.seq_lens, seq_lens)
self.assertIs(metadata.block_tables, block_tables)
self.assertEqual(metadata.num_decodes, num_decodes)
self.assertEqual(metadata.num_decode_tokens, num_decode_tokens)
self.assertEqual(metadata.num_prefills, num_prefills)
self.assertEqual(metadata.num_input_tokens, num_input_tokens)
self.assertEqual(metadata.query_lens, query_lens)
self.assertEqual(metadata.head_dim, head_dim)
self.assertEqual(metadata.attn_mask, attn_mask)
self.assertEqual(metadata.attn_state, attn_state)
self.assertEqual(metadata.decode, decode)
self.assertEqual(metadata.prefill, prefill)
class TestAscendMLATorchairMetadataBuilder(TestBase):
def test_ascend_mla_metadata_builder_default(self):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.model_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_device)
self.assertEqual(builder.block_size,
mock_vllm_config.cache_config.block_size)
self.assertEqual(
builder.chunked_prefill_enabled,
mock_vllm_config.scheduler_config.chunked_prefill_enabled)
self.assertEqual(builder.torchair_graph_enabled, True)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_reorder_batch_with_torchair_graph(self, ascend_config):
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = True
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
scheduler_output.scheduled_spec_decode_tokens = {
0: [1],
1: [],
2: [1, 1],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertFalse(modified)
input_batch.swap_states.assert_not_called()
def test_reorder_batch_without_torchair_graph(self):
ascend_config = MagicMock()
ascend_config.torchair_graph_config = MagicMock()
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.max_num_seqs = 4
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
return_value=ascend_config):
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_device)
input_batch = MagicMock()
input_batch.req_ids = [0, 1, 2, 3]
scheduler_output = MagicMock()
scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
scheduler_output.scheduled_spec_decode_tokens = {
0: [],
1: [1],
2: [],
3: []
}
input_batch.swap_states = MagicMock()
modified = builder.reorder_batch(input_batch, scheduler_output)
self.assertTrue(modified)
input_batch.swap_states.assert_called_once_with(1, 2)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 64
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 4)
self.assertTrue(torch.equal(result, block_tables[:, :4]))
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_get_graph_runner_block_tables_from_numpy(self,
mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_device = 'cpu'
builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
mock_device)
block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)
result = builder._get_graph_runner_block_tables(3, block_tables)
self.assertEqual(result.shape[0], 3)
self.assertEqual(result.shape[1], 64)
self.assertTrue(torch.equal(result[:, :10], block_tables))
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_dummy(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'
builder = AscendMLATorchairMetadataBuilder(
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = TorchairCommonAttentionMetadata(
num_reqs=3,
num_actual_tokens=3,
decode_token_per_req=1,
actual_seq_lengths_q=[0, 1, 2],
attn_mask=torch.zeros((1, 1), dtype=torch.bool),
spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
)
metadata = builder.build_torchair_graph_dummy(common_attn_metadata)
sin_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
cos_golden = torch.ones(3,
1,
1,
64,
dtype=torch.float16,
device=mock_device)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 3)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 1)
self.assertEqual(metadata.num_decode_tokens, 1)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 64)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 3)
assert torch.equal(sin_golden, metadata.decode.sin)
assert torch.equal(cos_golden, metadata.decode.cos)
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def test_build_decode(self, mock_ascend_config):
ascend_config = MagicMock()
mock_ascend_config.return_value = ascend_config
ascend_config.torchair_graph_config.enabled = False
mock_vllm_config = MagicMock()
mock_vllm_config.model_config.max_model_len = 1024
mock_vllm_config.cache_config.block_size = 16
mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
mock_vllm_config.get_head_size.return_value = 64
mock_vllm_config.model_config.dtype = torch.float16
mock_device = 'cpu'
model = MagicMock(spec=nn.Module)
model.model = MagicMock(spec=nn.Module)
builder = AscendMLATorchairMetadataBuilder(
mock_vllm_config,
mock_device,
metadata_cls=AscendMLATorchairMetadata)
builder.rope_dim = 64
builder.sin_cache = torch.tensor([10, 10])
builder.cos_cache = torch.tensor([10, 10])
with patch.object(builder,
"_get_graph_runner_block_tables",
side_effect=lambda x, y: y):
common_attn_metadata = AscendCommonAttentionMetadata(
query_start_loc=torch.tensor([0, 1, 2, 3]),
query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
seq_lens_cpu=torch.tensor([1, 1, 1]),
num_reqs=3,
num_actual_tokens=3,
max_query_len=1,
decode_token_per_req=torch.tensor([1, 1, 1]),
block_table_tensor=torch.zeros((10, 10)),
slot_mapping_cpu=torch.tensor(range(20)),
actual_seq_lengths_q=torch.tensor([0, 1, 2]),
positions=torch.tensor([1, 1]),
attn_mask=torch.ones((15, 15)),
spec_attn_mask=None,
attn_state=AscendAttentionState.ChunkedPrefill)
metadata = builder.build(common_attn_metadata, model)
self.assertIsInstance(metadata, AscendMLATorchairMetadata)
self.assertEqual(metadata.num_input_tokens, 0)
self.assertEqual(metadata.num_actual_tokens, 3)
self.assertEqual(metadata.num_decodes, 3)
self.assertEqual(metadata.num_decode_tokens, 3)
self.assertEqual(metadata.num_prefills, 0)
self.assertEqual(metadata.attn_state,
AscendAttentionState.ChunkedPrefill)
self.assertIsNone(metadata.prefill)
self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
self.assertEqual(metadata.block_tables.shape[0], 3)
self.assertEqual(metadata.block_tables.shape[1], 10)
self.assertEqual(metadata.seq_lens.shape[0], 3)
self.assertEqual(metadata.slot_mapping.shape[0], 3)
self.assertEqual(metadata.query_start_loc.shape[0], 4)
class TestAscendMLATorchairImpl(TestBase):
@patch('vllm.distributed.parallel_state._TP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
return_value=2)
@patch("vllm.config.get_current_vllm_config")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
mock_tp.world_size = 2
ascend_config.torchair_graph_config.enabled = True
ascend_config.torchair_graph_config.enable_kv_nz = False
speculative_config = MagicMock()
speculative_config.num_speculative_tokens = 4
vllm_config.speculative_config = speculative_config
num_heads = 256
head_size = 1024
scale = 0.1
num_kv_heads = 8
kv_cache_dtype = "auto"
kv_a_layernorm = MagicMock()
kv_a_layernorm.weight = torch.randn(96)
kv_a_layernorm.variance_epsilon = 1e-6
kwargs = {
"q_lora_rank": 64,
"kv_lora_rank": 32,
"qk_nope_head_dim": 64,
"qk_rope_head_dim": 32,
"qk_head_dim": 96,
"v_head_dim": 128,
"rotary_emb": MagicMock(),
"q_proj": MagicMock(),
"kv_b_proj": MagicMock(),
"o_proj": MagicMock(),
"kv_a_proj_with_mqa": MagicMock(),
"kv_a_layernorm": kv_a_layernorm,
}
self.impl = AscendMLATorchairImpl(num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=None,
sliding_window=None,
kv_cache_dtype=kv_cache_dtype,
blocksparse_params=None,
logits_soft_cap=None,
attn_type=None,
kv_sharing_target_layer_name=None,
**kwargs)
def test_init(self):
self.assertEqual(self.impl.num_heads, 256)
self.assertEqual(self.impl.head_size, 1024)
self.assertEqual(self.impl.scale, 0.1)
self.assertEqual(self.impl.num_kv_heads, 8)
self.assertEqual(self.impl.kv_cache_dtype, "auto")
self.assertEqual(self.impl.q_lora_rank, 64)
self.assertEqual(self.impl.kv_lora_rank, 32)
self.assertEqual(self.impl.qk_nope_head_dim, 64)
self.assertEqual(self.impl.qk_rope_head_dim, 32)
self.assertEqual(self.impl.qk_head_dim, 96)
self.assertEqual(self.impl.v_head_dim, 128)
self.assertIsNotNone(self.impl.rotary_emb)
self.assertIsNotNone(self.impl.q_proj)
self.assertIsNotNone(self.impl.kv_b_proj)
self.assertIsNotNone(self.impl.o_proj)
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
self.assertIsNotNone(self.impl.kv_a_layernorm)
self.assertEqual(self.impl.num_queries_per_kv, 32)
self.assertEqual(self.impl.tp_size, 2)
self.assertTrue(self.impl.torchair_graph_enabled)
def test_v_up_proj_and_o_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads,
self.impl.kv_lora_rank)
self.impl.o_proj.return_value = (torch.randn(
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
self.impl.W_UV = torch.randn(self.impl.num_heads,
self.impl.kv_lora_rank,
self.impl.v_head_dim)
result = self.impl._v_up_proj_and_o_proj(x)
self.assertEqual(result.shape[0], batch_size)
self.assertEqual(result.shape[1],
self.impl.num_heads * self.impl.v_head_dim)
def test_q_proj_and_k_up_proj(self):
batch_size = 4
x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
q_proj_output = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
self.impl.q_proj.return_value = (q_proj_output, )
if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
self.impl.W_UK_T = torch.randn(self.impl.num_heads,
self.impl.qk_nope_head_dim,
self.impl.kv_lora_rank)
result = self.impl._q_proj_and_k_up_proj(x)
ql_nope, q_pe = result
self.assertEqual(ql_nope.shape[0], batch_size)
self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
self.assertEqual(q_pe.shape[0], batch_size)
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
def test_process_weights_after_loading(self):
layer = MagicMock(spec=LinearBase)
layer.input_size_per_partition = 10
quant_method = MagicMock()
apply = MagicMock()
quant_method.apply = apply
layer.quant_method = quant_method
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
self.impl.v_head_dim)
shape_1 = self.impl.kv_lora_rank
layer.weight = torch.randn(shape_0, shape_1)
self.impl.kv_b_proj = layer
apply.return_value = layer.weight.T
self.impl.process_weights_after_loading(torch.bfloat16)
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
def test_compute_prefill_context_none(self):
batch_size = 4
kv_cache = torch.randn(10, 1, 1, 192)
query = torch.randn(batch_size, self.impl.num_heads,
self.impl.qk_head_dim)
metadata = MagicMock()
metadata.prefill = None
prefix_out = torch.randn(2, 16, 128)
prefix_lse = torch.randn(2, 16, 8)
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
metadata, prefix_out,
prefix_lse)
self.assertTrue(torch.equal(prefix_out, out))
self.assertTrue(torch.equal(prefix_lse, lse))
@patch("torch_npu.atb.npu_paged_cache_load")
@patch("torch_npu.atb.npu_ring_mla")
def test_compute_prefill_context(self, mock_ring, mock_load):
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
latent_kv_dim = self.impl.kv_lora_rank
num_blocks, block_size = 100, 20
query = torch.randn(S, N, D)
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
kv_cache = [kv_cache_0, kv_cache_1]
prefix_out = torch.randn(S, N, 128)
prefix_lse = torch.randn(S, N)
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
chunk_ctx = MagicMock()
chunk_ctx.seq_tot = [8]
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
chunk_ctx.starts = [torch.tensor([0])]
prefill_meta = MagicMock()
prefill_meta.chunked_context = chunk_ctx
prefill_meta.query_lens = [8]
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
meta = MagicMock()
meta.prefill = prefill_meta
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
meta, prefix_out,
prefix_lse)
mock_load.assert_called_once()
mock_ring.assert_called_once()
self.assertEqual(out.shape, prefix_out.shape)
self.assertEqual(lse.shape, prefix_lse.shape)
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv(self, mock_kv_cache):
batch_size = 2
hidden = torch.randn(batch_size, 128)
cos = torch.randn(batch_size, 32)
sin = torch.randn(batch_size, 32)
kv_cache = (torch.randn(
4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(
4, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
slots = torch.arange(batch_size, dtype=torch.long)
proj_out = torch.randn(
batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
mock_kv_cache.return_value = (torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim),
torch.randn(batch_size,
self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank),
None, None)
k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
mock_kv_cache.assert_called_once()
self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
self.assertEqual(kv.shape,
(batch_size, self.impl.num_kv_heads, 1,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
def test_exec_kv_prefill(self, mock_kv):
B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
hidden_states = torch.randn(B, N, S, H)
cos = torch.randn(B, S, 32)
sin = torch.randn(B, S, 32)
kv_cache = (
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
torch.randn(100, 8,
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
)
slots = torch.arange(B * S, dtype=torch.long)
proj_out = torch.randn(
B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
mock_kv.return_value = (None, None,
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.qk_rope_head_dim),
torch.randn(B, self.impl.num_kv_heads, S,
self.impl.kv_lora_rank))
k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
kv_cache, slots)
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
mock_kv.assert_called_once()
self.assertEqual(
k_pe.shape,
(B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
self.assertEqual(
k_nope.shape,
(B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))
@patch("torch_npu.npu_interleave_rope")
def test_rope_single(self, mock_rope):
B, N, D = 2, 16, 1024
x = torch.randn(B, N, D)
cos = torch.randn(B, N, 1, D)
sin = torch.randn(B, N, 1, D)
mock_rope.return_value = x.view(B, N, 1, D)
result = self.impl.rope_single(x, cos, sin)
self.assertEqual(result.shape[0], B)
self.assertEqual(result.shape[1], N)
self.assertEqual(result.shape[2], D)
mock_rope.assert_called_once()
@patch(
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj"
)
@patch("torch_npu._npu_paged_attention_mla")
def test_forward_decode_without_graph(self, mock_page_attention_mla,
mock_up_proj):
self.impl.running_in_graph = False
self.impl.running_chunkprefilll_with_torchair = False
num_tokens = 100
num_blocks = 256
block_size = 4
q_nope = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_nope_head_dim)
q_pe = torch.randn(num_tokens, self.impl.num_heads,
self.impl.qk_rope_head_dim)
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
self.impl.num_heads,
self.impl.kv_lora_rank)
metadata = MagicMock()
metadata.decode = MagicMock()
metadata.decode.block_table = MagicMock()
metadata.decode.seq_lens = 10
mock_page_attention_mla.return_value = torch.randn(
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
mock_up_proj.return_value = torch.randn(num_tokens,
self.impl.num_heads,
self.impl.v_head_dim)
result = self.impl._forward_decode(q_nope, q_pe, None, None,
kv_c_and_k_pe_cache, metadata)
self.assertEqual(result.shape[0], num_tokens)
self.assertEqual(result.shape[1], self.impl.num_heads)
self.assertEqual(result.shape[2], self.impl.v_head_dim)
mock_up_proj.assert_called_once()
mock_page_attention_mla.assert_called_once()
@patch(
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill"
)
@patch("torch_npu._npu_reshape_and_cache")
def test_forward_without_graph(self, _, mock_forward_prefill):
self.impl.running_in_graph = False
self.impl.torchair_graph_enabled = False
num_tokens = 100
num_blocks = 256
block_size = 4
rotary_emb_return_value = (torch.randn(num_tokens, 16,
self.impl.kv_lora_rank),
torch.randn(0, 1, self.impl.kv_lora_rank))
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
1, num_blocks, 128)
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
self.impl.kv_lora_rank)
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.kv_lora_rank),
torch.randn(num_blocks, block_size, self.impl.num_heads,
self.impl.qk_rope_head_dim))
output = torch.randn(num_tokens, self.impl.num_heads,
self.impl.v_head_dim)
metadata = MagicMock()
metadata.num_decodes = 0
metadata.num_prefills = num_tokens
mock_forward_prefill.return_value = torch.randn(
0, self.impl.num_heads * self.impl.v_head_dim)
result = self.impl.forward(None, hidden_states_or_q_c,
hidden_states_or_kv_c_normed, k_pe,
kv_cache, metadata, output, False)
self.assertEqual(result.shape[0], num_tokens)

View File

@@ -0,0 +1,149 @@
import os
from concurrent.futures import ThreadPoolExecutor
from unittest import mock
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
from vllm_ascend.torchair import utils
class TestTorchairUtils(TestBase):
def test_get_torchair_current_work_dir(self):
cache_dir = utils.TORCHAIR_CACHE_DIR
work_dir = utils._get_torchair_current_work_dir()
self.assertEqual(cache_dir, work_dir)
work_dir = utils._get_torchair_current_work_dir("test")
self.assertEqual(os.path.join(cache_dir, "test"), work_dir)
def test_torchair_cache_dir(self):
utils.write_kv_cache_bytes_to_file(0, 100)
self.assertTrue(utils.check_torchair_cache_exist(),
"Create torchair cache dir failed")
self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
"Create kv cache bytes cache dir failed")
kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
self.assertEqual(100, kv_cache_bytes)
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
def test_torchair_cache_dir_multiple_ranks(self):
ranks = [0, 1, 2, 3]
values = [100, 200, 300, 400]
with ThreadPoolExecutor() as executor:
executor.map(utils.write_kv_cache_bytes_to_file, ranks, values)
for rank, expected in zip(ranks, values):
self.assertEqual(expected,
utils.read_kv_cache_bytes_from_file(rank))
utils.delete_torchair_cache_file()
self.assertFalse(utils.check_torchair_cache_exist(),
"Delete torchair cache dir failed")
self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
"Delete kv cache bytes cache dir failed")
def test_delete_torchair_cache_file_multiple_times(self):
utils.write_kv_cache_bytes_to_file(0, 100)
utils.delete_torchair_cache_file()
for i in range(5):
try:
utils.delete_torchair_cache_file()
except FileNotFoundError:
self.fail(
f"Unexpected FileNotFoundError on delete call #{i+2}")
@patch('vllm.ModelRegistry')
def test_register_torchair_model(self, mock_model_registry):
mock_registry = MagicMock()
mock_model_registry.return_value = mock_registry
utils.register_torchair_model()
self.assertEqual(mock_model_registry.register_model.call_count, 6)
call_args_list = mock_model_registry.register_model.call_args_list
expected_registrations = [
("DeepSeekMTPModel",
"vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
),
("DeepseekV2ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
),
("DeepseekV3ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
),
("Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
("Qwen3MoeForCausalLM",
"vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
),
("PanguProMoEForCausalLM",
"vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
)
]
for i, (expected_name,
expected_path) in enumerate(expected_registrations):
args, kwargs = call_args_list[i]
self.assertEqual(args[0], expected_name)
self.assertEqual(args[1], expected_path)
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format(self, mock_npu_cast,
mock_get_format):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = 1
mock_npu_cast.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
self.assertEqual(fused_moe.w13_weight.data, 1)
@mock.patch('torch_npu.get_npu_format')
@mock.patch('torch_npu.npu_format_cast')
@mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
new=mock.MagicMock)
def test_converting_weight_acl_format_format_true(self, mock_npu_cast,
mock_get_format):
ACL_FORMAT_FRACTAL_NZ = 29
mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
mock_npu_cast.return_value = 1
fused_moe = mock.MagicMock()
fused_moe.w13_weight = mock.MagicMock()
fused_moe.w2_weight = mock.MagicMock()
fused_moe.w13_weight.data = torch.randn(128, 256)
fused_moe.w2_weight.data = torch.randn(256, 128)
model = mock.MagicMock()
model.modules.return_value = [fused_moe]
utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
mock_npu_cast.assert_not_called()
def test_torchair_quant_method_register(self):
TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
"W8A8_DYNAMIC"]
TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
"W4A8_DYNAMIC"]
utils.torchair_quant_method_register()
self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])

View File

@@ -0,0 +1,372 @@
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import inspect
from collections.abc import Sequence
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.sampling_params import SamplingParams
from vllm.utils import make_tensor_with_pad
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
MAX_PROMPT_SIZE = 100
MAX_NUM_PROMPT_TOKENS = 64
def _compare_objs(obj1,
obj2,
skip: Sequence = ("logitsprocs", "batch_update_builder")):
attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
attr_names = set([
a[0] for a in attrs
if not (a[0].startswith('__') and a[0].endswith('__'))
])
for attr_name in attr_names:
if attr_name in skip:
continue
a = getattr(obj1, attr_name)
b = getattr(obj2, attr_name)
is_same = False
if isinstance(a, torch.Tensor):
if (a.numel() == 0 or b.numel() == 0):
is_same = (a.numel() == 0 and b.numel() == 0)
elif torch.allclose(a, b):
is_same = True
elif isinstance(a, np.ndarray):
if np.allclose(a, b):
is_same = True
elif isinstance(a, MultiGroupBlockTable):
for a_i, b_i in zip(a.block_tables, b.block_tables):
_compare_objs(a_i, b_i)
is_same = True
elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
_compare_objs(a, b)
is_same = True # if we make it here must be same
elif a == b:
is_same = True
assert is_same, f"Attribute {attr_name} is different"\
f" in {obj1} and {obj2}: {a} != {b}"
def _remove_requests(input_batch: InputBatch, batch_size: int,
reqs: list[CachedRequestState]) -> set[str]:
"""
Remove some requests randomly from the batch and returns
set of request removed
"""
num_reqs_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove: set[int] = set()
for _ in range(num_reqs_to_remove):
req_index_to_remove = np.random.randint(0, batch_size)
req_indices_to_remove.add(req_index_to_remove)
req_ids_to_remove: set[str] = set()
for index in req_indices_to_remove:
input_batch.remove_request(reqs[index].req_id)
req_ids_to_remove.add(reqs[index].req_id)
return req_ids_to_remove
def _construct_expected_sampling_metadata(
reqs: list[CachedRequestState],
req_ids_retained: set[int],
req_id_index_in_input_batch: dict[str, int],
device: torch.device,
) -> SamplingMetadata:
"""
Constructs and returns the expected SamplingMetadata for this
batch.
"""
num_reqs = len(req_ids_retained)
output_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
prompt_token_ids: list[list[int]] = [list() for _ in range(num_reqs)]
presence_penalties = [0.0 for _ in range(num_reqs)]
frequency_penalties = [0.0 for _ in range(num_reqs)]
repetition_penalties = [1.0 for _ in range(num_reqs)]
top_k = [0 for _ in range(num_reqs)]
top_p = [0.0 for _ in range(num_reqs)]
temperature = [0.0 for _ in range(num_reqs)]
min_tokens = {}
logit_bias = [None] * num_reqs
allowed_token_ids_mask = torch.zeros(num_reqs,
VOCAB_SIZE,
dtype=torch.bool,
device=device)
bad_words_token_ids = {}
for req in reqs:
if req.req_id not in req_ids_retained:
continue
index_in_input_batch = req_id_index_in_input_batch[req.req_id]
output_token_ids[index_in_input_batch] = req.output_token_ids
prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
presence_penalties[
index_in_input_batch] = req.sampling_params.presence_penalty
frequency_penalties[index_in_input_batch] = (
req.sampling_params.frequency_penalty)
repetition_penalties[index_in_input_batch] = (
req.sampling_params.repetition_penalty)
top_k[index_in_input_batch] = req.sampling_params.top_k
top_p[index_in_input_batch] = req.sampling_params.top_p
temperature[index_in_input_batch] = req.sampling_params.temperature
min_tokens[index_in_input_batch] = (
req.sampling_params.min_tokens,
req.sampling_params.all_stop_token_ids)
logit_bias[index_in_input_batch] = req.sampling_params.logit_bias
if req.sampling_params.allowed_token_ids:
allowed_token_ids_mask[index_in_input_batch][
req.sampling_params.allowed_token_ids] = True
if req.sampling_params.bad_words_token_ids:
bad_words_token_ids[
index_in_input_batch] = req.sampling_params.bad_words_token_ids
return SamplingMetadata(
temperature=torch.tensor(temperature, dtype=torch.float,
device=device),
all_greedy=False,
all_random=True,
top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
top_p, dtype=torch.float, device=device),
top_k=None if all(x == 0 for x in top_k) else torch.tensor(
top_k, dtype=torch.int, device=device),
generators={},
max_num_logprobs=0,
prompt_token_ids=make_tensor_with_pad(
prompt_token_ids,
pad=VOCAB_SIZE,
device=torch.device(device),
dtype=torch.int64,
),
frequency_penalties=torch.tensor(frequency_penalties,
dtype=torch.float,
device=device),
presence_penalties=torch.tensor(presence_penalties,
dtype=torch.float,
device=device),
repetition_penalties=torch.tensor(repetition_penalties,
dtype=torch.float,
device=device),
output_token_ids=output_token_ids,
no_penalties=(all(x == 0 for x in presence_penalties)
and all(x == 0 for x in frequency_penalties)
and all(x == 1 for x in repetition_penalties)),
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=bad_words_token_ids,
logitsprocs=LogitsProcessors(),
)
def _create_sampling_params():
return SamplingParams(
top_k=np.random.randint(1, 10),
top_p=np.random.uniform(0.0, 1.0),
presence_penalty=np.random.uniform(-2.0, 2.0),
repetition_penalty=np.random.uniform(0.0, 2.0),
frequency_penalty=np.random.uniform(-2.0, 2.0),
min_tokens=np.random.randint(1, 10),
stop_token_ids=[
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(10))
],
logit_bias={0: np.random.uniform(-3.0, 3.0)},
)
def _construct_cached_request_state(req_id_suffix: int):
prompt_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, MAX_PROMPT_SIZE))
]
output_token_ids = [
np.random.randint(0, VOCAB_SIZE)
for _ in range(np.random.randint(0, NUM_OUTPUT_TOKENS))
]
return CachedRequestState(
req_id=f"req_id_{req_id_suffix}",
prompt_token_ids=prompt_token_ids,
sampling_params=_create_sampling_params(),
pooling_params=None,
mm_kwargs=[],
mm_positions=[],
block_ids=([], ),
generator=None,
num_computed_tokens=len(output_token_ids),
output_token_ids=output_token_ids,
mm_hashes=None,
)
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=False,
vocab_size=1024,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
assigned_req_index = input_batch.add_request(req)
assert req_index == assigned_req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
# Remove some requests
req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove
# Compact the input batch
input_batch.condense()
# Generate the sampling metadata
sampling_metadata = input_batch._make_sampling_metadata()
# Create expected output.
expected_sampling_metadata = _construct_expected_sampling_metadata(
reqs,
req_ids_retained,
input_batch.req_id_to_index,
device=torch.device(device))
def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
return (t1 is None
and t2 is None) or (t1 is not None and t2 is not None
and torch.allclose(t1, t2))
# Assert the actual and expected output.
assert torch.allclose(expected_sampling_metadata.temperature,
sampling_metadata.temperature)
assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
assert torch.allclose(
expected_sampling_metadata.frequency_penalties,
sampling_metadata.frequency_penalties,
)
assert torch.allclose(
expected_sampling_metadata.presence_penalties,
sampling_metadata.presence_penalties,
)
assert torch.allclose(
expected_sampling_metadata.repetition_penalties,
sampling_metadata.repetition_penalties,
)
assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
sampling_metadata.prompt_token_ids)
assert (expected_sampling_metadata.output_token_ids ==
sampling_metadata.output_token_ids)
assert expected_sampling_metadata.no_penalties == \
sampling_metadata.no_penalties
if sampling_metadata.allowed_token_ids_mask:
assert torch.allclose(
expected_sampling_metadata.allowed_token_ids_mask,
sampling_metadata.allowed_token_ids_mask)
assert expected_sampling_metadata.bad_words_token_ids == \
sampling_metadata.bad_words_token_ids
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
swap_list: list):
"""
Tests the logic for managing sampling metadata in the InputBatch.
This test involves adding a set of requests to the InputBatch,
followed by removing a subset of them. Afterward, the batch is compacted,
and the `make_sampling_metadata` method is invoked on the batch. The
output of `make_sampling_metadata` is then compared against the expected
results to ensure correctness.
Note: Ignore logits processor logic, which is tested separately
"""
input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=False,
vocab_size=1024,
block_sizes=[1],
)
ref_input_batch: InputBatch = InputBatch(
max_num_reqs=batch_size,
max_model_len=1024,
max_num_batched_tokens=1024,
device=torch.device(device),
pin_memory=False,
vocab_size=1024,
block_sizes=[1],
)
reqs: list[CachedRequestState] = []
req_id_reqs = {}
req_id_output_token_ids = {}
# Add requests
for req_index in range(batch_size):
req: CachedRequestState = _construct_cached_request_state(req_index)
assigned_req_index = input_batch.add_request(req)
assert assigned_req_index == req_index
reqs.append(req)
req_id_reqs[req.req_id] = req
req_id_output_token_ids[req.req_id] = req.output_token_ids
reordered_reqs = reqs.copy()
for swap_pair in swap_list:
reordered_reqs[swap_pair[0]], reordered_reqs[swap_pair[1]] = \
reordered_reqs[swap_pair[1]], reordered_reqs[swap_pair[0]]
input_batch.swap_states(swap_pair[0], swap_pair[1])
for req_index in range(batch_size):
req = reordered_reqs[req_index]
assigned_req_index = ref_input_batch.add_request(req)
assert assigned_req_index == req_index
input_batch.refresh_metadata()
ref_input_batch.refresh_metadata()
_compare_objs(input_batch, ref_input_batch)

File diff suppressed because it is too large Load Diff