v0.10.1rc1
This commit is contained in:
0
tests/ut/__init__.py
Normal file
0
tests/ut/__init__.py
Normal file
133
tests/ut/attention/test_attention_mask.py
Normal file
133
tests/ut/attention/test_attention_mask.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
||||
|
||||
|
||||
class TestAttentionMaskBuilder(TestBase):
    """Unit tests for ``AttentionMaskBuilder`` (mask cache and mask getters)."""

    def test_init_attention_mask_builder(self):
        """The constructor fills a square cache of the requested dtype."""
        # float16: the top-right cache entry is expected to be -inf.
        fp16_builder = AttentionMaskBuilder(max_seq_len=1024,
                                            dtype=torch.float16)
        expected_fp16 = torch.tensor(float("-inf"), dtype=torch.float16)
        self.assertEqual(fp16_builder._seq_len_cached, 1024)
        self.assertEqual(fp16_builder.attn_mask_cache.dtype, torch.float16)
        self.assertEqual(fp16_builder.attn_mask_cache.shape, (1024, 1024))
        self.assertEqual(fp16_builder.attn_mask_cache[0][-1], expected_fp16)

        # bfloat16: the top-right cache entry is expected to be 1.
        bf16_builder = AttentionMaskBuilder(max_seq_len=2048,
                                            dtype=torch.bfloat16)
        expected_bf16 = torch.tensor(1, dtype=torch.bfloat16)
        self.assertEqual(bf16_builder._seq_len_cached, 2048)
        self.assertEqual(bf16_builder.attn_mask_cache.dtype, torch.bfloat16)
        self.assertEqual(bf16_builder.attn_mask_cache.shape, (2048, 2048))
        self.assertEqual(bf16_builder.attn_mask_cache[0][-1], expected_bf16)

    def test_get_mask_scale_factor(self):
        """Scale factor is 1 for float16, -10000 for bfloat16, else an error."""
        scale_of = AttentionMaskBuilder.get_mask_scale_factor

        # Supported data types.
        self.assertEqual(scale_of(torch.float16), 1)
        self.assertEqual(scale_of(torch.bfloat16), -10000)

        # Only torch.float16 and torch.bfloat16 are supported; any other
        # dtype must raise ValueError.
        with self.assertRaises(ValueError):
            scale_of(torch.int8)

    def test_get_attn_mask(self):
        """``get_attn_mask`` grows the cache only for longer sequences."""
        builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)
        cpu = torch.device("cpu")
        neg_inf = torch.tensor(float("-inf"), dtype=torch.float16)

        # Requested length below the cached length: cache stays untouched.
        small_mask = builder.get_attn_mask(max_seq_len=512,
                                           dtype=torch.float16,
                                           device=cpu)
        self.assertEqual(small_mask.shape, (512, 512))
        self.assertEqual(small_mask[0][-1], neg_inf)
        self.assertEqual(builder._seq_len_cached, 1024)
        self.assertEqual(builder.attn_mask_cache.shape, (1024, 1024))
        self.assertEqual(builder.attn_mask_cache[0][-1], neg_inf)

        # Requested length above the cached length: cache is rebuilt.
        large_mask = builder.get_attn_mask(max_seq_len=2048,
                                           dtype=torch.float16,
                                           device=cpu)
        self.assertEqual(large_mask.shape, (2048, 2048))
        self.assertEqual(large_mask[0][-1], neg_inf)
        self.assertEqual(builder._seq_len_cached, 2048)
        self.assertEqual(builder.attn_mask_cache.shape, (2048, 2048))
        self.assertEqual(builder.attn_mask_cache[0][-1], neg_inf)

    def test_get_splitfuse_attn_mask(self):
        """Split-fuse masks are (num positions, max seq len) shaped."""
        builder = AttentionMaskBuilder(max_seq_len=1024, dtype=torch.float16)
        cpu = torch.device("cpu")

        # Longest sequence (100) fits in the cache, so it is not resized.
        mask = builder.get_splitfuse_attn_mask(
            seq_lens=torch.tensor([10, 20, 100]),
            position=torch.tensor([7, 8, 9, 18, 19, 99]),
            dtype=torch.float16,
            device=cpu,
        )
        self.assertEqual(mask.shape, (6, 100))
        self.assertEqual(builder._seq_len_cached, 1024)

        # Longest sequence (3000) exceeds the cache, so it grows to 3000.
        mask = builder.get_splitfuse_attn_mask(
            seq_lens=torch.tensor([10, 3000, 2000]),
            position=torch.tensor([7, 8, 9, 2999, 1999]),
            dtype=torch.float16,
            device=cpu,
        )
        self.assertEqual(mask.shape, (5, 3000))
        self.assertEqual(builder._seq_len_cached, 3000)

        # Only torch.float16 and torch.bfloat16 are supported; any other
        # dtype must raise ValueError.
        with self.assertRaises(ValueError):
            builder.get_splitfuse_attn_mask(
                seq_lens=torch.tensor([10, 20, 100]),
                position=torch.tensor([7, 8, 9, 18, 19, 99]),
                dtype=torch.int8,
                device=cpu,
            )

    def test_mask_value_cleanliness(self):
        """Producing a scaled bfloat16 mask must not modify the cache."""
        builder = AttentionMaskBuilder(max_seq_len=6, dtype=torch.bfloat16)
        cached_value = torch.tensor(1, dtype=torch.bfloat16)
        self.assertEqual(builder.attn_mask_cache[-2][-1], cached_value)

        mask = builder.get_splitfuse_attn_mask(
            seq_lens=torch.tensor([6]),
            position=torch.tensor([3, 4, 5]),
            dtype=torch.bfloat16,
            device=torch.device("cpu"),
        )
        scaled_value = torch.tensor(-10000,
                                    dtype=torch.bfloat16,
                                    device=mask.device)
        self.assertEqual(mask[-2][-1], scaled_value)
        # The cached entry still holds the unscaled value.
        self.assertEqual(builder.attn_mask_cache[-2][-1], cached_value)
|
||||
578
tests/ut/attention/test_attention_v1.py
Normal file
578
tests/ut/attention/test_attention_v1.py
Normal file
@@ -0,0 +1,578 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import (AscendAttentionBackend,
|
||||
AscendAttentionBackendImpl,
|
||||
AscendAttentionMetadataBuilder,
|
||||
AscendAttentionState,
|
||||
AscendMetadata,
|
||||
CommonAttentionState)
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
|
||||
|
||||
class TestAscendAttentionBackend(TestBase):
    """Tests for the static accessors and cache helpers of the backend."""

    def test_get_name(self):
        """The backend reports the name ``"ASCEND"``."""
        backend_name = AscendAttentionBackend.get_name()
        self.assertEqual(backend_name, "ASCEND")

    def test_get_impl_cls(self):
        """``get_impl_cls`` returns ``AscendAttentionBackendImpl``."""
        impl_cls = AscendAttentionBackend.get_impl_cls()
        self.assertEqual(impl_cls, AscendAttentionBackendImpl)

    def test_get_metadata_cls(self):
        """``get_metadata_cls`` returns ``AscendMetadata``."""
        metadata_cls = AscendAttentionBackend.get_metadata_cls()
        self.assertEqual(metadata_cls, AscendMetadata)

    def test_get_state_cls(self):
        """``get_state_cls`` returns ``CommonAttentionState``."""
        state_cls = AscendAttentionBackend.get_state_cls()
        self.assertEqual(state_cls, CommonAttentionState)

    def test_get_builder_cls(self):
        """``get_builder_cls`` returns ``AscendAttentionMetadataBuilder``."""
        builder_cls = AscendAttentionBackend.get_builder_cls()
        self.assertEqual(builder_cls, AscendAttentionMetadataBuilder)

    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
    def test_get_kv_cache_shape_310p(self, mock_is_310p):
        """With ``is_310p`` true the last two dims are regrouped by 16."""
        shape = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
        expected = (2, 10, 30 * 40 // 16, 20, 16)
        self.assertEqual(shape, expected)

    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
    def test_get_kv_cache_shape_not_310p(self, mock_is_310p):
        """With ``is_310p`` false the arguments are kept as separate dims."""
        shape = AscendAttentionBackend.get_kv_cache_shape(10, 20, 30, 40)
        expected = (2, 10, 20, 30, 40)
        self.assertEqual(shape, expected)

    def test_get_bsh_kv_cache_shape(self):
        """The BSH shape multiplies the last two arguments into one dim."""
        shape = AscendAttentionBackend.get_bsh_kv_cache_shape(10, 20, 30, 40)
        expected = (2, 10, 20, 30 * 40)
        self.assertEqual(shape, expected)

    def test_swap_blocks(self):
        """After ``swap_blocks`` the mapped destination rows equal the source."""
        src_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
        dst_cache = [torch.zeros((10, 20)), torch.zeros((10, 20))]
        # Each row is a (source index, destination index) pair as checked by
        # the assertions below.
        mapping = torch.tensor([[0, 1], [2, 3]])

        AscendAttentionBackend.swap_blocks(src_cache, dst_cache, mapping)

        first_matches = torch.all(dst_cache[0][1] == src_cache[0][0])
        second_matches = torch.all(dst_cache[1][3] == src_cache[1][2])
        self.assertTrue(first_matches)
        self.assertTrue(second_matches)

    def test_copy_blocks(self):
        """After ``copy_blocks`` the mapped rows inside each cache are equal."""
        caches = [torch.zeros((10, 20)), torch.zeros((10, 20))]
        mapping = torch.tensor([[0, 1], [2, 3]])

        AscendAttentionBackend.copy_blocks(caches, mapping)

        first_matches = torch.all(caches[0][1] == caches[0][0])
        second_matches = torch.all(caches[1][3] == caches[1][2])
        self.assertTrue(first_matches)
        self.assertTrue(second_matches)
|
||||
|
||||
|
||||
class TestAscendAttentionMetadataBuilder(TestBase):
    """Smoke tests for ``AscendAttentionMetadataBuilder``.

    The ``build`` tests make no assertions on the result; they only check
    that ``build`` completes without raising for each patched configuration.
    """

    def setUp(self):
        """Create a builder from a mocked vLLM config on a CPU device."""
        self.mock_vllm_config = MagicMock()
        self.mock_vllm_config.model_config.max_model_len = 640
        self.mock_vllm_config.cache_config.block_size = 64
        self.mock_device = 'cpu:0'
        self.builder = AscendAttentionMetadataBuilder(self.mock_vllm_config,
                                                      self.mock_device)

    @staticmethod
    def _chunked_prefill_common_metadata():
        """Return the three-request ChunkedPrefill metadata shared by tests."""
        return AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 2, 5, 9]),
            query_start_loc_cpu=torch.tensor([0, 2, 5, 9]),
            seq_lens_cpu=torch.tensor([4, 5, 6]),
            num_reqs=3,
            num_actual_tokens=15,
            max_query_len=6,
            decode_token_per_req=torch.tensor([1, 1, 1]),
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping_cpu=torch.tensor(range(20)),
            actual_seq_lengths_q=torch.tensor([0, 1, 2]),
            positions=torch.tensor([10, 10]),
            attn_mask=torch.ones((15, 15)),
            spec_attn_mask=None,
            attn_state=AscendAttentionState.ChunkedPrefill)

    def test_reorder_batch(self):
        """``reorder_batch`` returns a falsy value for mocked inputs."""
        mock_input_batch = MagicMock()
        mock_scheduler_output = MagicMock()

        result = self.builder.reorder_batch(mock_input_batch,
                                            mock_scheduler_output)

        self.assertFalse(result)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    @patch('torch_npu.npu_format_cast')
    @patch('vllm_ascend.utils.nd_to_nz_2d')
    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
    def test_build_prefill_no_cache(self, mock_is_310p, mock_nd_to_nz_2d,
                                    mock_npu_format_cast,
                                    mock_ascend_metadata):
        """``build`` runs for PrefillNoCache metadata with ``is_310p`` true."""
        common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=torch.tensor([0, 3, 7]),
            query_start_loc_cpu=torch.tensor([0, 3, 7]),
            seq_lens_cpu=torch.tensor([5, 6]),
            num_reqs=2,
            num_actual_tokens=10,
            max_query_len=5,
            decode_token_per_req=torch.tensor([1, 1]),
            block_table_tensor=torch.zeros((10, 10)),
            slot_mapping_cpu=torch.tensor(range(20)),
            actual_seq_lengths_q=torch.tensor([0, 1]),
            positions=torch.tensor([10, 10]),
            attn_mask=torch.ones((10, 10)),
            spec_attn_mask=None,
            attn_state=AscendAttentionState.PrefillNoCache)

        # Both NPU helpers hand back the same stand-in tensor.
        mock_nz_tensor = MagicMock()
        mock_model = MagicMock()
        mock_nd_to_nz_2d.return_value = mock_nz_tensor
        mock_npu_format_cast.return_value = mock_nz_tensor

        self.builder.build(common_attn_metadata, mock_model)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    @patch('torch_npu.npu_format_cast')
    @patch('vllm_ascend.utils.nd_to_nz_spec')
    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
    @patch('vllm_ascend.attention.attention_v1.AscendAttentionState')
    def test_build_chunked_prefill(self, mock_ascend_attention_state,
                                   mock_is_310p, mock_nd_to_nz_spec,
                                   mock_npu_format_cast, mock_ascend_metadata):
        """``build`` runs for ChunkedPrefill metadata with ``is_310p`` true."""
        # The module-level name imported by this test file is not affected by
        # the patch, so the metadata carries the real ChunkedPrefill member.
        common_attn_metadata = self._chunked_prefill_common_metadata()

        # Configure the mock that ``patch`` actually installed. Rebinding the
        # parameter to a fresh MagicMock would make this setting a no-op.
        mock_ascend_attention_state.PrefillNoCache = 0

        mock_nz_tensor = MagicMock()
        mock_model = MagicMock()
        mock_nd_to_nz_spec.return_value = mock_nz_tensor
        mock_npu_format_cast.return_value = mock_nz_tensor

        self.builder.build(common_attn_metadata, mock_model)

    @patch('vllm_ascend.attention.attention_v1.AscendMetadata')
    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
    def test_build_non_310p(self, mock_is_310p, mock_ascend_metadata):
        """``build`` runs for ChunkedPrefill metadata with ``is_310p`` false."""
        common_attn_metadata = self._chunked_prefill_common_metadata()
        mock_model = MagicMock()

        self.builder.build(common_attn_metadata, mock_model)
|
||||
|
||||
|
||||
class TestAscendAttentionBackendImpl(TestBase):
    """Tests for ``AscendAttentionBackendImpl.forward`` with NPU ops mocked.

    Each test patches the NPU kernel(s) it expects ``forward`` to reach,
    runs ``forward`` on random CPU tensors and checks that the expected
    kernel mock was called and that the output keeps the query's shape.
    """

    def setUp(self):
        """Build mocked layers, metadata and the impl variants under test."""
        # Plain MagicMock: any attribute (e.g. ``quant_method``) exists.
        self.layer = MagicMock()
        self.layer.layer_name = "test_layer"
        self.layer._k_scale_float = 1.0
        self.layer._v_scale_float = 1.0

        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"

        self.attn_metadata = MagicMock()
        self.attn_metadata.return_value = "1"

        # ``spec`` limits the mock to these attributes, so accessing any
        # other name (such as ``quant_method``) raises AttributeError.
        self.layer_no_quant = MagicMock(
            spec=['layer_name', '_k_scale_float', '_v_scale_float'])
        self.layer_no_quant.layer_name = "test_layer"
        self.layer_no_quant._k_scale_float = 1.0
        self.layer_no_quant._v_scale_float = 1.0

        self.impl = self._make_impl()
        self.impl_192 = self._make_impl(head_size=192)
        self.impl_error = self._make_impl(head_size=192, attn_type=None)
        self.impl_swa = self._make_impl(sliding_window=1024)

    def _make_impl(self, **overrides):
        """Create an impl from the default arguments updated by *overrides*."""
        kwargs = dict(num_heads=8,
                      head_size=64,
                      scale=1.0,
                      num_kv_heads=8,
                      alibi_slopes=None,
                      sliding_window=None,
                      kv_cache_dtype="float16",
                      logits_soft_cap=None,
                      attn_type=self.attention_type.DECODER,
                      kv_sharing_target_layer_name=None)
        kwargs.update(overrides)
        return AscendAttentionBackendImpl(**kwargs)

    @staticmethod
    def _make_qkv(head_size=64):
        """Return random query, key and value of shape (10, 8 * head_size)."""
        shape = (10, 8 * head_size)
        return torch.randn(shape), torch.randn(shape), torch.randn(shape)

    @staticmethod
    def _make_kv_cache(head_size=64):
        """Return an uninitialised (2, 5, 128, 8, head_size) KV cache."""
        return torch.empty(2, 5, 128, 8, head_size)

    def _prefill_no_cache_metadata(self):
        """Configure ``self.attn_metadata`` for the PrefillNoCache tests."""
        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.PrefillNoCache
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.seq_lens = torch.tensor([10])
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        return metadata

    def _paged_metadata(self):
        """Configure ``self.attn_metadata`` with mask, lengths and block table.

        ``attn_state`` is deliberately left untouched so callers decide
        whether it stays an auto-created mock attribute or a real state.
        """
        metadata = self.attn_metadata
        metadata.attn_mask = torch.randn(1, 1, 10, 10)
        metadata.query_lens = torch.tensor([10])
        metadata.seq_lens = torch.tensor([10])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        return metadata

    @patch('torch.ops.vllm.unified_ascend_attention_with_output')
    def test_forward_trace_flag_true(self, mock_unified_attention):
        """With ``trace_flag=True`` the unified attention op is called once."""
        query, key, value = self._make_qkv()
        kv_cache = torch.empty(2, 0, 0, 8, 64)

        output = self.impl.forward(self.layer,
                                   query,
                                   key,
                                   value,
                                   kv_cache,
                                   self.attn_metadata,
                                   trace_flag=True)

        mock_unified_attention.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_paged_attention_splitfuse')
    def test_forward_with_quant_method(self, mock_paged_attention):
        """A layer exposing ``quant_method`` has its ``apply`` called once."""
        query, key, value = self._make_qkv()
        k_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
        v_cache = torch.ones(1, 10, 8, 64, dtype=torch.int8)
        kv_cache = [k_cache, v_cache]
        ret_value = torch.ones(1, 1, 10, 8, 64, dtype=torch.int8)

        metadata = MagicMock()
        metadata.num_actual_tokens = torch.randn(10, 8 * 64)
        metadata.block_tables = torch.randn(10, 8 * 64)
        metadata.seq_lens = torch.randn(10, 8 * 64)
        metadata.attn_mask = torch.randn(10, 8 * 64)
        metadata.query_lens = torch.randn(10, 8 * 64)
        layer = self.layer
        layer.quant_method = MagicMock()
        layer.quant_method.apply.return_value = ret_value

        output = self.impl.forward(layer,
                                   query,
                                   key,
                                   value,
                                   kv_cache,
                                   metadata,
                                   trace_flag=False)

        layer.quant_method.apply.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    def test_forward_no_attn_metadata(self):
        """``attn_metadata=None`` still yields a query-shaped output."""
        query, key, value = self._make_qkv()
        kv_cache = torch.empty(2, 0, 0, 8, 64)

        output = self.impl.forward(self.layer_no_quant,
                                   query,
                                   key,
                                   value,
                                   kv_cache,
                                   None,
                                   trace_flag=False)

        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_flash_attention')
    def test_forward_prefill_no_cache(self, mock_flash_attention,
                                      mock_reshape_cache):
        """PrefillNoCache calls reshape-and-cache and flash attention once."""
        query, key, value = self._make_qkv()
        metadata = self._prefill_no_cache_metadata()

        output = self.impl.forward(self.layer_no_quant,
                                   query,
                                   key,
                                   value,
                                   self._make_kv_cache(),
                                   metadata,
                                   trace_flag=False)

        mock_reshape_cache.assert_called_once()
        mock_flash_attention.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_flash_attention')
    def test_forward_prefill_no_cache_swa(self, mock_flash_attention,
                                          mock_reshape_cache):
        """PrefillNoCache on the sliding-window impl uses the same kernels."""
        query, key, value = self._make_qkv()
        metadata = self._prefill_no_cache_metadata()

        output = self.impl_swa.forward(self.layer_no_quant,
                                       query,
                                       key,
                                       value,
                                       self._make_kv_cache(),
                                       metadata,
                                       trace_flag=False)

        mock_reshape_cache.assert_called_once()
        mock_flash_attention.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_flash_attention_qlens')
    def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
                                       mock_npu_reshape_and_cache):
        """PrefillCacheHit calls the qlens flash-attention kernel once."""
        query, key, value = self._make_qkv()
        metadata = self._paged_metadata()
        metadata.attn_state = AscendAttentionState.PrefillCacheHit

        output = self.impl.forward(self.layer_no_quant,
                                   query,
                                   key,
                                   value,
                                   self._make_kv_cache(),
                                   metadata,
                                   trace_flag=False)

        mock_flash_attention_qlens.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_paged_attention')
    def test_forward_decode_only(self, mock_paged_attention,
                                 mock_npu_reshape_and_cache):
        """DecodeOnly calls the paged-attention kernel once."""
        query, key, value = self._make_qkv()
        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.DecodeOnly
        metadata.seq_lens = torch.tensor([10])
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 10
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)

        output = self.impl.forward(self.layer_no_quant,
                                   query,
                                   key,
                                   value,
                                   self._make_kv_cache(),
                                   metadata,
                                   trace_flag=False)

        mock_paged_attention.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu.npu_fused_infer_attention_score')
    def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
                                     mock_npu_reshape_and_cache):
        """DecodeOnly on the sliding-window impl calls the fused kernel once."""
        query, key, value = self._make_qkv()
        metadata = self.attn_metadata
        metadata.attn_state = AscendAttentionState.DecodeOnly
        metadata.seq_lens = torch.tensor([10] * 10)
        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
        metadata.num_actual_tokens = 100
        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
        # The fused kernel mock returns a pair whose first item is the
        # (10, 8, 64) attention result.
        mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                    64), 1)

        output = self.impl_swa.forward(self.layer_no_quant,
                                       query,
                                       key,
                                       value,
                                       self._make_kv_cache(),
                                       metadata,
                                       trace_flag=False)

        mock_fused_infer_attention_score.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=False)
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('vllm_ascend.attention.attention_v1.vanilla_chunked_prefill')
    def test_forward_head_size_192(self, mock_vanilla_prefill,
                                   mock_npu_reshape_and_cache, mock_is_310p):
        """With head_size 192 ``vanilla_chunked_prefill`` is called once."""
        query, key, value = self._make_qkv(head_size=192)
        metadata = self._paged_metadata()
        mock_vanilla_prefill.return_value = MagicMock()

        output = self.impl_192.forward(self.layer_no_quant,
                                       query,
                                       key,
                                       value,
                                       self._make_kv_cache(head_size=192),
                                       metadata,
                                       trace_flag=False)

        mock_vanilla_prefill.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 192))

    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_paged_attention_splitfuse')
    def test_forward_normal_v1_situation(self, mock_paged_attention,
                                         mock_npu_reshape_and_cache):
        """Without a specific ``attn_state`` the split-fuse kernel is used."""
        query, key, value = self._make_qkv()
        metadata = self._paged_metadata()

        output = self.impl.forward(self.layer_no_quant,
                                   query,
                                   key,
                                   value,
                                   self._make_kv_cache(),
                                   metadata,
                                   trace_flag=False)

        mock_paged_attention.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu.npu_format_cast')
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu._npu_paged_attention_splitfuse')
    @patch('vllm_ascend.attention.attention_v1.is_310p', return_value=True)
    def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
                                 mock_npu_reshape_and_cache,
                                 mock_npu_format_cast):
        """With ``is_310p`` true the split-fuse kernel is still called once."""
        query, key, value = self._make_qkv()
        metadata = self._paged_metadata()
        # Format casting is stubbed to hand the mask back unchanged.
        mock_npu_format_cast.return_value = metadata.attn_mask

        output = self.impl.forward(self.layer_no_quant,
                                   query,
                                   key,
                                   value,
                                   self._make_kv_cache(),
                                   metadata,
                                   trace_flag=False)

        mock_paged_attention.assert_called_once()
        self.assertEqual(output.shape, (10, 8 * 64))

    @patch('torch_npu._npu_reshape_and_cache')
    def test_forward_raise_error(self, mock_npu_reshape_and_cache):
        """An impl built with ``attn_type=None`` raises NotImplementedError."""
        query, key, value = self._make_qkv()
        metadata = self._paged_metadata()

        with self.assertRaises(NotImplementedError):
            self.impl_error.forward(self.layer_no_quant,
                                    query,
                                    key,
                                    value,
                                    self._make_kv_cache(),
                                    metadata,
                                    trace_flag=False)
|
||||
631
tests/ut/attention/test_mla_v1.py
Normal file
631
tests/ut/attention/test_mla_v1.py
Normal file
@@ -0,0 +1,631 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.mla_v1 import (AscendMLABackend,
|
||||
AscendMLADecodeMetadata,
|
||||
AscendMLAImpl, AscendMLAMetadata,
|
||||
AscendMLAMetadataBuilder,
|
||||
AscendMLAPrefillMetadata)
|
||||
|
||||
|
||||
class TestAscendMLABackend(TestBase):
    """Tests for the static accessors of ``AscendMLABackend``."""

    def test_get_name(self):
        """The backend reports the name ``"ASCEND_MLA"``."""
        backend_name = AscendMLABackend.get_name()
        self.assertEqual(backend_name, "ASCEND_MLA")

    def test_get_metadata_cls(self):
        """``get_metadata_cls`` returns ``AscendMLAMetadata``."""
        metadata_cls = AscendMLABackend.get_metadata_cls()
        self.assertEqual(metadata_cls, AscendMLAMetadata)

    def test_get_builder_cls(self):
        """``get_builder_cls`` returns ``AscendMLAMetadataBuilder``."""
        builder_cls = AscendMLABackend.get_builder_cls()
        self.assertEqual(builder_cls, AscendMLAMetadataBuilder)

    def test_get_kv_cache_shape(self):
        """The KV cache shape is the four arguments as a tuple, unchanged."""
        shape = AscendMLABackend.get_kv_cache_shape(2, 4, 8, 128)
        self.assertEqual(shape, (2, 4, 8, 128))

    def test_get_impl_cls(self):
        """``get_impl_cls`` returns ``AscendMLAImpl``."""
        impl_cls = AscendMLABackend.get_impl_cls()
        self.assertEqual(impl_cls, AscendMLAImpl)
|
||||
|
||||
|
||||
class TestAscendMLAPrefillMetadata(TestBase):
    """Tests that ``AscendMLAPrefillMetadata`` stores what it is given."""

    def test_ascend_mla_prefill_metadata_default(self):
        """Fields are stored as passed; ``chunked_context`` defaults to None."""
        # Tensor arguments are checked by identity.
        tensor_fields = {
            "attn_mask": torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
            "context_lens": torch.tensor([1, 2]),
            "input_positions": torch.tensor([0, 1, 0, 1]),
            "query_start_loc": torch.tensor([0, 1, 3]),
            "block_table": torch.tensor([[0, 1], [2, 3]]),
        }
        # Plain Python arguments are checked by equality.
        plain_fields = {
            "query_lens": [1, 2],
            "seq_lens": [2, 2],
            "max_query_len": 2,
            "max_seq_lens": 2,
        }

        metadata = AscendMLAPrefillMetadata(**tensor_fields, **plain_fields)

        for field_name, expected in tensor_fields.items():
            self.assertIs(getattr(metadata, field_name), expected)
        for field_name, expected in plain_fields.items():
            self.assertEqual(getattr(metadata, field_name), expected)
        self.assertIsNone(metadata.chunked_context)

    def test_ascend_mla_prefill_metadata_with_chunked_context(self):
        """A supplied ``ChunkedContextMetadata`` is kept with all its fields."""
        tensor_fields = {
            "cu_seq_lens": torch.tensor([0, 2, 4]),
            "starts": torch.tensor([0, 2]),
            "workspace": torch.randn(2, 4),
            "chunk_seq_lens": torch.tensor([2, 2]),
        }
        plain_fields = {
            "seq_tot": [2, 2],
            "max_seq_lens": [2, 2],
        }
        chunked_context = AscendMLAPrefillMetadata.ChunkedContextMetadata(
            **tensor_fields, **plain_fields)

        metadata = AscendMLAPrefillMetadata(
            attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
            query_lens=[1, 2],
            seq_lens=[2, 2],
            context_lens=torch.tensor([1, 2]),
            input_positions=torch.tensor([0, 1, 0, 1]),
            query_start_loc=torch.tensor([0, 1, 3]),
            block_table=torch.tensor([[0, 1], [2, 3]]),
            max_query_len=2,
            max_seq_lens=2,
            chunked_context=chunked_context)

        stored = metadata.chunked_context
        self.assertIsNotNone(stored)
        for field_name, expected in tensor_fields.items():
            self.assertIs(getattr(stored, field_name), expected)
        for field_name, expected in plain_fields.items():
            self.assertEqual(getattr(stored, field_name), expected)
|
||||
|
||||
|
||||
class TestAscendMLADecodeMetadata(TestBase):
    """Checks that AscendMLADecodeMetadata keeps its constructor arguments."""

    def test_ascend_mla_decode_metadata_default(self):
        """Build the metadata positionally and read every field back."""
        input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
        block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
        seq_lens = torch.tensor([[2], [3]])
        max_seq_lens = 4
        seq_lens_list = [2, 3]
        attn_mask = None

        metadata = AscendMLADecodeMetadata(input_positions, block_table,
                                           seq_lens, max_seq_lens,
                                           seq_lens_list, attn_mask)

        # Tensors must be stored as-is (identity), scalars/lists by value.
        self.assertIs(metadata.input_positions, input_positions)
        self.assertIs(metadata.block_table, block_table)
        self.assertIs(metadata.seq_lens, seq_lens)
        self.assertEqual(metadata.max_seq_lens, max_seq_lens)
        self.assertEqual(metadata.seq_lens_list, seq_lens_list)
        # Check the field on the object under test; asserting on the local
        # ``attn_mask`` would be a tautology because it is None by
        # construction.
        # NOTE(review): assumes AscendMLADecodeMetadata exposes an
        # ``attn_mask`` attribute -- confirm against the dataclass.
        self.assertIsNone(metadata.attn_mask)
|
||||
|
||||
|
||||
class TestAscendMLAMetadata(TestBase):
    """Checks that AscendMLAMetadata keeps its constructor arguments."""

    def test_ascend_mla_metadata_default(self):
        """Construct the metadata positionally and read each field back."""
        same_object, same_value = self.assertIs, self.assertEqual

        # One row per constructor argument, in positional order:
        # (attribute name, value passed in, assertion used to verify it).
        spec = [
            ("num_actual_tokens", 100, same_value),
            ("slot_mapping", torch.randn(100, 4, 1024), same_object),
            ("query_start_loc", torch.tensor([1, 2, 3, 4]), same_object),
            ("seq_lens", [30, 50], same_value),
            ("block_tables", torch.randint(0, 100, (100, 4)), same_object),
            ("num_decodes", 4, same_value),
            ("num_decode_tokens", 8, same_value),
            ("num_prefills", 8, same_value),
            ("num_input_tokens", 2, same_value),
            ("query_lens", None, same_value),
            ("head_dim", None, same_value),
            ("attn_mask", None, same_value),
            ("attn_state", AscendAttentionState.ChunkedPrefill, same_value),
            ("decode", None, same_value),
            ("prefill", None, same_value),
        ]

        metadata = AscendMLAMetadata(*[value for _, value, _ in spec])

        for attribute, expected, check in spec:
            check(getattr(metadata, attribute), expected)
|
||||
|
||||
|
||||
class TestAscendMLAMetadataBuilder(TestBase):
    """Tests for AscendMLAMetadataBuilder built from a mocked vLLM config."""

    @staticmethod
    def _base_vllm_config():
        """Return a MagicMock config carrying the settings both tests share."""
        config = MagicMock()
        config.model_config.max_model_len = 1024
        config.cache_config.block_size = 16
        config.scheduler_config.max_num_seqs = 4
        config.scheduler_config.chunked_prefill_enabled = False
        return config

    def test_ascend_mla_metadata_builder_default(self):
        """The builder copies block size and chunked-prefill flag from config."""
        vllm_config = self._base_vllm_config()
        vllm_config.model_config.get_head_size.return_value = 64
        vllm_config.model_config.dtype = torch.float16
        device = 'cpu'

        fake_ascend_config = MagicMock()
        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=fake_ascend_config):
            builder = AscendMLAMetadataBuilder(vllm_config, device)

            self.assertEqual(builder.block_size,
                             vllm_config.cache_config.block_size)
            self.assertEqual(
                builder.chunked_prefill_enabled,
                vllm_config.scheduler_config.chunked_prefill_enabled)

    def test_reorder_batch(self):
        """reorder_batch reports a change and swaps batch slots 1 and 2."""
        fake_ascend_config = MagicMock()
        vllm_config = self._base_vllm_config()
        device = 'cpu'

        with patch("vllm_ascend.attention.mla_v1.get_ascend_config",
                   return_value=fake_ascend_config):
            builder = AscendMLAMetadataBuilder(vllm_config, device)
            builder.decode_threshold = 1

        batch = MagicMock(req_ids=[0, 1, 2, 3], swap_states=MagicMock())

        sched_out = MagicMock()
        sched_out.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
        sched_out.scheduled_spec_decode_tokens = {
            0: [],
            1: [1],
            2: [],
            3: [],
        }

        changed = builder.reorder_batch(batch, sched_out)

        self.assertTrue(changed)
        batch.swap_states.assert_called_once_with(1, 2)
|
||||
|
||||
|
||||
class TestAscendMLAImpl(TestBase):
    """Tests for AscendMLAImpl using mocked projections and NPU kernels."""

    @patch('vllm.distributed.parallel_state._TP',
           new_callable=lambda: MagicMock(spec=GroupCoordinator))
    @patch("vllm.distributed.get_tensor_model_parallel_world_size",
           return_value=2)
    @patch("vllm_ascend.attention.mla_v1.get_current_vllm_config")
    @patch("vllm_ascend.attention.mla_v1.get_ascend_config")
    def setUp(self, ascend_config, get_current_vllm_config, mock_get_tp_size,
              mock_tp):
        """Create ``self.impl`` with MagicMock layers and a mocked TP group."""
        mock_tp.world_size = 2

        vllm_config = MagicMock()
        vllm_config.speculative_config = MagicMock(num_speculative_tokens=4)
        vllm_config.model_config = MagicMock(dtype=torch.float16)
        get_current_vllm_config.return_value = vllm_config

        layernorm = MagicMock(weight=torch.randn(96), variance_epsilon=1e-6)

        impl_kwargs = {
            # Generic attention-impl arguments.
            "num_heads": 256,
            "head_size": 1024,
            "scale": 0.1,
            "num_kv_heads": 8,
            "alibi_slopes": None,
            "sliding_window": None,
            "kv_cache_dtype": "auto",
            "blocksparse_params": None,
            "logits_soft_cap": None,
            "attn_type": None,
            "kv_sharing_target_layer_name": None,
            # MLA-specific dimensions and (mocked) projection layers.
            "q_lora_rank": 64,
            "kv_lora_rank": 32,
            "qk_nope_head_dim": 64,
            "qk_rope_head_dim": 32,
            "qk_head_dim": 96,
            "v_head_dim": 128,
            "rotary_emb": MagicMock(),
            "q_proj": MagicMock(),
            "kv_b_proj": MagicMock(),
            "o_proj": MagicMock(),
            "kv_a_proj_with_mqa": MagicMock(),
            "kv_a_layernorm": layernorm,
        }
        self.impl = AscendMLAImpl(**impl_kwargs)

    def _assert_dims(self, tensor, expected):
        """Assert ``tensor.shape[i] == expected[i]`` for each listed axis."""
        for axis, size in enumerate(expected):
            self.assertEqual(tensor.shape[axis], size)

    def _check_exec_kv(self, method_name, fused_op_mock, *, results_first):
        """Drive ``exec_kv_prefill`` / ``exec_kv_decode`` with a mocked kernel.

        ``results_first`` selects whether the two tensors returned by the
        mocked ``npu_kv_rmsnorm_rope_cache`` occupy the first two or the last
        two slots of its four-element result.
        """
        impl = self.impl
        batch = 2
        kv_heads = impl.num_kv_heads
        rope_dim = impl.qk_rope_head_dim
        latent_dim = impl.kv_lora_rank

        kv_no_split = torch.randn(batch, kv_heads, latent_dim + rope_dim)
        impl.enable_kv_nz = None
        impl.kv_a_layernorm.weight = MagicMock()
        impl.kv_a_layernorm.variance_epsilon = MagicMock()
        cos, sin, slots = MagicMock(), MagicMock(), MagicMock()
        kv_cache = [MagicMock(), MagicMock()]

        tensors = [
            torch.randn(batch, kv_heads, 1, rope_dim),
            torch.randn(batch, kv_heads, 1, latent_dim),
        ]
        blanks = [None, None]
        fused_op_mock.return_value = (tensors + blanks
                                      if results_first else blanks + tensors)

        k_pe, k_nope = getattr(impl, method_name)(kv_no_split, cos, sin,
                                                  kv_cache, slots)

        self.assertEqual(k_pe.shape[-1], rope_dim)
        self.assertEqual(k_nope.shape[-1], latent_dim)

    def test_init(self):
        """Constructor arguments end up on the instance unchanged."""
        impl = self.impl
        expected_values = {
            "num_heads": 256,
            "head_size": 1024,
            "scale": 0.1,
            "num_kv_heads": 8,
            "kv_cache_dtype": "auto",
            "q_lora_rank": 64,
            "kv_lora_rank": 32,
            "qk_nope_head_dim": 64,
            "qk_rope_head_dim": 32,
            "qk_head_dim": 96,
            "v_head_dim": 128,
        }
        for attribute, value in expected_values.items():
            self.assertEqual(getattr(impl, attribute), value)

        for attribute in ("rotary_emb", "q_proj", "kv_b_proj", "o_proj",
                          "kv_a_proj_with_mqa", "kv_a_layernorm"):
            self.assertIsNotNone(getattr(impl, attribute))

        self.assertEqual(impl.num_queries_per_kv, 32)
        self.assertEqual(impl.tp_size, 2)

    def test_v_up_proj(self):
        """_v_up_proj yields (batch, num_heads * v_head_dim) leading dims."""
        impl = self.impl
        batch_size = 4
        latent = torch.randn(batch_size, impl.num_heads, impl.kv_lora_rank)

        # Provide a weight only when the impl does not already carry one.
        if getattr(impl, 'W_UV', None) is None:
            impl.W_UV = torch.randn(impl.num_heads, impl.kv_lora_rank,
                                    impl.v_head_dim)

        projected = impl._v_up_proj(latent)

        self._assert_dims(projected,
                          (batch_size, impl.num_heads * impl.v_head_dim))

    def test_q_proj_and_k_up_proj(self):
        """_q_proj_and_k_up_proj returns correctly shaped nope/rope parts."""
        impl = self.impl
        batch_size = 4
        hidden = torch.randn(batch_size, impl.num_heads, impl.qk_head_dim)
        projected_q = torch.randn(batch_size, impl.num_heads,
                                  impl.qk_head_dim)
        impl.q_proj.return_value = (projected_q, )

        if getattr(impl, 'W_UK_T', None) is None:
            impl.W_UK_T = torch.randn(impl.num_heads, impl.qk_nope_head_dim,
                                      impl.kv_lora_rank)

        ql_nope, q_pe = impl._q_proj_and_k_up_proj(hidden)

        self._assert_dims(ql_nope,
                          (batch_size, impl.num_heads, impl.kv_lora_rank))
        self._assert_dims(q_pe,
                          (batch_size, impl.num_heads, impl.qk_rope_head_dim))

    def test_process_weights_after_loading(self):
        """process_weights_after_loading derives W_UK_T and W_UV shapes."""
        impl = self.impl
        rows = impl.num_heads * (impl.qk_nope_head_dim + impl.v_head_dim)
        cols = impl.kv_lora_rank
        weight = torch.randn(rows, cols)

        apply_mock = MagicMock(return_value=weight.T)
        kv_b_layer = MagicMock(spec=LinearBase)
        kv_b_layer.input_size_per_partition = 10
        kv_b_layer.quant_method = MagicMock(apply=apply_mock)
        kv_b_layer.weight = weight
        impl.kv_b_proj = kv_b_layer

        impl.process_weights_after_loading(torch.bfloat16)

        self._assert_dims(
            impl.W_UK_T,
            (impl.num_heads, impl.qk_nope_head_dim, impl.kv_lora_rank))
        self._assert_dims(
            impl.W_UV, (impl.num_heads, impl.kv_lora_rank, impl.v_head_dim))

    def test_compute_prefill_context_none(self):
        """Without prefill metadata the prefix output and LSE pass through."""
        impl = self.impl
        batch_size = 4
        kv_cache = torch.randn(10, 1, 1, 192)
        query = torch.randn(batch_size, impl.num_heads, impl.qk_head_dim)
        metadata = MagicMock(prefill=None)
        prefix_out = torch.randn(2, 16, 128)
        prefix_lse = torch.randn(2, 16, 8)

        split_at = impl.qk_nope_head_dim
        q_nope, q_pe = query[..., :split_at], query[..., split_at:]

        out, lse = impl._compute_prefill_context(q_nope, q_pe, kv_cache, 32,
                                                 metadata, prefix_out,
                                                 prefix_lse)

        self.assertTrue(torch.equal(prefix_out, out))
        self.assertTrue(torch.equal(prefix_lse, lse))

    @patch("torch_npu.atb.npu_paged_cache_load")
    @patch("torch_npu.atb.npu_ring_mla")
    def test_compute_prefill_context(self, mock_ring, mock_load):
        """With one context chunk each mocked NPU op is called exactly once."""
        impl = self.impl
        num_seqs = 2
        heads = impl.num_heads
        qk_dim = impl.qk_head_dim
        v_dim = impl.v_head_dim
        nope_dim = impl.qk_nope_head_dim
        latent_dim = impl.kv_lora_rank
        num_blocks, block_size = 100, 20

        query = torch.randn(num_seqs, heads, qk_dim)
        q_nope = query[..., :nope_dim]
        q_pe = query[..., nope_dim:]
        kv_cache = [
            torch.randn(num_blocks, block_size, heads, latent_dim),
            torch.randn(num_blocks, block_size, heads, qk_dim),
        ]
        prefix_out = torch.randn(num_seqs, heads, 128)
        prefix_lse = torch.randn(num_seqs, heads)

        impl.kv_b_proj.return_value = (torch.randn(8, heads,
                                                   v_dim + nope_dim), )

        chunk_ctx = MagicMock()
        chunk_ctx.seq_tot = [8]
        chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
        chunk_ctx.starts = [torch.tensor([0])]

        prefill_meta = MagicMock()
        prefill_meta.chunked_context = chunk_ctx
        prefill_meta.query_lens = [8]
        prefill_meta.block_table = torch.randint(0, 100, (num_seqs, 4))

        meta = MagicMock(prefill=prefill_meta)

        all_ones = torch.ones(512,
                              512,
                              device=q_nope.device,
                              dtype=q_nope.dtype)
        impl.prefill_mask = torch.triu(all_ones, 1)

        out, lse = impl._compute_prefill_context(q_nope, q_pe, kv_cache, 32,
                                                 meta, prefix_out, prefix_lse)

        mock_load.assert_called_once()
        mock_ring.assert_called_once()

        self.assertEqual(out.shape, prefix_out.shape)
        self.assertEqual(lse.shape, prefix_lse.shape)

    @patch("vllm_ascend.attention.mla_v1.AscendMLAImpl._v_up_proj")
    @patch("torch_npu.npu_fused_infer_attention_score")
    def test_forward_decode_without_graph(self,
                                          mock_npu_fused_infer_attention_score,
                                          mock_up_proj):
        """_forward_decode returns the (mocked) up-projected attention output."""
        impl = self.impl
        num_tokens = 100
        block_size = 4
        heads = impl.num_heads

        def per_token(last_dim):
            # Random (num_tokens, heads, last_dim) tensor.
            return torch.randn(num_tokens, heads, last_dim)

        q_nope = per_token(impl.qk_nope_head_dim)
        q_pe = per_token(impl.qk_rope_head_dim)
        k_nope = per_token(impl.qk_nope_head_dim)
        k_pe = per_token(impl.qk_rope_head_dim)

        metadata = MagicMock()
        metadata.decode = MagicMock(block_table=MagicMock(), seq_lens=10)

        mock_npu_fused_infer_attention_score.return_value = [
            per_token(impl.kv_lora_rank), None
        ]
        mock_up_proj.return_value = per_token(impl.v_head_dim)

        result = impl._forward_decode(q_nope, q_pe, k_nope, k_pe, block_size,
                                      metadata)

        self._assert_dims(result, (num_tokens, heads, impl.v_head_dim))
        mock_up_proj.assert_called_once()
        mock_npu_fused_infer_attention_score.assert_called_once()

    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
    def test_mla_preprocess(self, magic_npu_fetch):
        """_mla_preprocess yields both results for a mixed batch."""
        impl = self.impl
        magic_npu_fetch.return_value = MagicMock()

        batch_size, seq_len, hidden_size = 4, 8, 1024
        hidden_states = torch.randn(batch_size * seq_len, hidden_size)
        kv_cache = MagicMock()

        attn_metadata = MagicMock()
        attn_metadata.num_decodes = 2
        attn_metadata.num_prefills = 2
        attn_metadata.num_decode_tokens = 2
        attn_metadata.num_actual_tokens = 4
        num_prefill_tokens = 2
        attn_metadata.slot_mapping = torch.arange(4)
        attn_metadata.decode.cos = torch.randn(2, 64)
        attn_metadata.decode.sin = torch.randn(2, 64)
        attn_metadata.prefill.cos = torch.randn(2, 64)
        attn_metadata.prefill.sin = torch.randn(2, 64)

        heads = impl.num_heads

        def prefill_tensor(last_dim):
            # Random (num_prefill_tokens, heads, last_dim) tensor.
            return torch.randn(num_prefill_tokens, heads, last_dim)

        # Replace every layer/helper the preprocess step calls with a mock
        # returning tensors (or mocks) of the shapes used by this test.
        impl.q_a_proj = MagicMock()
        impl.q_a_layernorm = MagicMock(return_value=torch.randn(
            attn_metadata.num_actual_tokens, heads, impl.qk_rope_head_dim))
        impl.kv_a_proj_with_mqa = MagicMock(return_value=[
            prefill_tensor(impl.qk_nope_head_dim + impl.kv_lora_rank)
        ])
        impl.q_proj = MagicMock(
            return_value=[prefill_tensor(impl.qk_head_dim)])
        impl.kv_b_proj = MagicMock(return_value=[
            prefill_tensor(impl.v_head_dim + impl.qk_nope_head_dim)
        ])
        impl.rope_single = MagicMock(side_effect=lambda x, cos, sin: x)
        impl.exec_kv_decode = MagicMock(
            return_value=[MagicMock(), MagicMock()])
        impl.exec_kv_prefill = MagicMock(return_value=[
            prefill_tensor(impl.qk_rope_head_dim),
            prefill_tensor(impl.kv_lora_rank),
        ])
        impl._q_proj_and_k_up_proj = MagicMock(
            return_value=[MagicMock(), MagicMock()])
        impl.num_kv_heads = impl.num_heads

        decode_res, prefill_res = impl._mla_preprocess(hidden_states,
                                                       kv_cache,
                                                       attn_metadata,
                                                       need_gather_q_kv=False)

        self.assertIsNotNone(decode_res)
        self.assertIsNotNone(prefill_res)

    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
    def test_exec_kv_prefill(self, mock_kv_rmsnorm_rope_cache):
        """exec_kv_prefill reads the last two outputs of the fused kernel."""
        self._check_exec_kv("exec_kv_prefill",
                            mock_kv_rmsnorm_rope_cache,
                            results_first=False)

    @patch("torch_npu.npu_kv_rmsnorm_rope_cache")
    def test_exec_kv_decode(self, mock_kv_rmsnorm_rope_cache):
        """exec_kv_decode reads the first two outputs of the fused kernel."""
        self._check_exec_kv("exec_kv_decode",
                            mock_kv_rmsnorm_rope_cache,
                            results_first=True)

    @patch("torch.npu.stream")
    @patch("vllm_ascend.attention.mla_v1.get_multistream_comm_context")
    @patch("torch_npu.npu_fused_infer_attention_score")
    def test_forward_decode(self, mock_npu_fused_infer_attention_score,
                            mock_get_multistream_comm_context,
                            mock_npu_stream):
        """_forward_decode output shape with and without a multistream context."""
        impl = self.impl
        batch = 2
        kv_heads = impl.num_kv_heads
        block_size = 100
        v_dim = impl.v_head_dim

        impl.kv_lora_rank = 256
        impl.spec_token_num = 1
        impl._v_up_proj = MagicMock(
            return_value=torch.randn(batch, kv_heads, v_dim))

        q_nope = torch.randn(batch, kv_heads, impl.qk_nope_head_dim)
        q_pe = torch.randn(batch, kv_heads, impl.qk_rope_head_dim)
        k_nope = torch.randn(block_size, kv_heads, impl.kv_lora_rank)
        k_pe = torch.randn(block_size, kv_heads, impl.qk_rope_head_dim)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.SpecDecoding
        attn_metadata.decode = MagicMock(actual_seq_lengths_q=MagicMock(),
                                         seq_lens_list=MagicMock())

        def decode_and_check():
            output = impl._forward_decode(q_nope, q_pe, k_nope, k_pe,
                                          block_size, attn_metadata)
            self._assert_dims(output, (batch, kv_heads, v_dim))

        # First pass: enable_kv_nz on, spec-decoding state, no comm context.
        impl.enable_kv_nz = True
        mock_npu_fused_infer_attention_score.return_value = [
            torch.randn(batch, kv_heads, impl.kv_lora_rank), None
        ]
        mock_get_multistream_comm_context.return_value = None
        decode_and_check()

        # Second pass: enable_kv_nz off, no attention state, mocked comm
        # context and NPU stream.
        impl.enable_kv_nz = False
        attn_metadata.attn_state = None
        comm_context = MagicMock()
        mock_get_multistream_comm_context.return_value = comm_context
        comm_context.before_comm_event = MagicMock()
        comm_context.comm_stream = MagicMock()
        mock_npu_stream.return_value = MagicMock()
        decode_and_check()
|
||||
44
tests/ut/base.py
Normal file
44
tests/ut/base.py
Normal file
@@ -0,0 +1,44 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm_ascend.utils import adapt_patch, register_ascend_customop
|
||||
|
||||
|
||||
class TestBase(unittest.TestCase):
    """Base class for unittest-style tests in this suite.

    Every instance applies the vllm-ascend patches and registers the Ascend
    custom ops before the regular ``unittest.TestCase`` initialisation runs.
    """

    def __init__(self, *args, **kwargs):
        # adapt patch by default.
        adapt_patch(True)
        adapt_patch()
        register_ascend_customop()
        # ``setUp`` is deliberately not called here: unittest invokes it
        # before each test, and calling the hook before the base class is
        # initialised has no effect.
        super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class PytestBase:
    """Base class for pytest-based tests.

    pytest's ``mocker`` fixture and ``parametrize`` are not compatible with
    ``unittest.TestCase``, so pytest-style tests use this separate base class.
    """

    @pytest.fixture(autouse=True)
    def setup(self):
        # Same two adapt_patch calls, in the same order: first with True,
        # then with no argument.
        for patch_args in ((True, ), ()):
            adapt_patch(*patch_args)
        register_ascend_customop()
|
||||
26
tests/ut/conftest.py
Normal file
26
tests/ut/conftest.py
Normal file
@@ -0,0 +1,26 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||
from vllm_ascend.utils import register_ascend_customop
|
||||
|
||||
# Apply the vllm-ascend patches when this conftest module is imported.
# NOTE(review): presumably the bare call and the ``True`` call select two
# different patch sets in vllm_ascend.utils.adapt_patch -- confirm there.
adapt_patch()
|
||||
adapt_patch(True)
|
||||
|
||||
# Register the Ascend custom ops here because the unit tests rely on them.
|
||||
register_ascend_customop()
|
||||
167
tests/ut/core/test_schedule_config.py
Normal file
167
tests/ut/core/test_schedule_config.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from vllm.config import SchedulerConfig
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.core.schedule_config import AscendSchedulerConfig
|
||||
|
||||
|
||||
class TestAscendSchedulerConfig(TestBase):
    """Tests for ``AscendSchedulerConfig.initialize_from_config``."""

    def setUp(self):
        base_settings = {
            "max_num_batched_tokens": 8192,
            "max_model_len": 8192,
            "is_multimodal_model": False,
            "send_delta_data": False,
            "scheduler_delay_factor": 0,
        }
        self.basic_scheduler_config = SchedulerConfig(**base_settings)

    def _initialize_with(self, **overrides):
        """Initialise from the basic config plus an AscendSchedulerConfig."""
        return AscendSchedulerConfig.initialize_from_config(
            self.basic_scheduler_config, AscendSchedulerConfig(**overrides))

    def _assert_unsupported(self, message, **overrides):
        """Expect NotImplementedError whose text contains ``message``."""
        with self.assertRaises(NotImplementedError) as raised:
            self._initialize_with(**overrides)
        self.assertIn(message, str(raised.exception))

    def test_initialize_from_config_with_default(self):
        # No additional config given, check the default value here.
        result = AscendSchedulerConfig.initialize_from_config(
            self.basic_scheduler_config, {})
        self.assertEqual(result.enable_chunked_prefill, False)
        self.assertEqual(result.policy, "fcfs")
        self.assertEqual(result.num_scheduler_steps, 1)
        self.assertEqual(result.scheduler_cls,
                         "vllm_ascend.core.scheduler.AscendScheduler")
        self.assertEqual(result.max_num_encoder_input_tokens, 8192)
        self.assertEqual(result.encoder_cache_size, 8192)

    def test_initialize_from_config_with_override(self):
        # test override
        result = self._initialize_with(
            enable_chunked_prefill=False,
            policy="fcfs",
            num_scheduler_steps=1,
            scheduler_cls="vllm_ascend.core.scheduler.AscendScheduler",
            max_num_batched_tokens=2048,
            max_model_len=2048,
        )
        self.assertEqual(result.enable_chunked_prefill, False)
        self.assertEqual(result.policy, "fcfs")
        self.assertEqual(result.num_scheduler_steps, 1)
        self.assertEqual(result.scheduler_cls,
                         "vllm_ascend.core.scheduler.AscendScheduler")
        self.assertEqual(result.max_num_batched_tokens, 2048)
        self.assertEqual(result.encoder_cache_size, 2048)

    def test_not_implemented_policy(self):
        self._assert_unsupported(
            "currently AscendScheduler only supports fcfs policy",
            policy="custom_policy",
            max_num_batched_tokens=2048,
            max_model_len=2048,
        )

    def test_not_implemented_multimodal(self):
        with self.assertRaises(NotImplementedError) as raised:
            AscendSchedulerConfig.initialize_from_config(
                SchedulerConfig(is_multimodal_model=True), {})
        self.assertIn("currently AscendScheduler only supports LLM models",
                      str(raised.exception))

    def test_not_implemented_multi_step(self):
        self._assert_unsupported(
            "currently AscendScheduler doesn't support multi-step",
            num_scheduler_steps=2,
            max_num_batched_tokens=2048,
            max_model_len=2048,
        )

    def test_not_implemented_send_delta_data(self):
        self._assert_unsupported(
            "currently AscendScheduler doesn't support send_delta_data",
            send_delta_data=True,
            max_num_batched_tokens=2048,
            max_model_len=2048,
        )

    def test_not_implemented_delay_factor(self):
        self._assert_unsupported(
            "currently AscendScheduler doesn't support scheduler_delay_factor",
            delay_factor=1,
            max_num_batched_tokens=2048,
            max_model_len=2048,
        )

    def test_no_override(self):
        result = AscendSchedulerConfig.initialize_from_config(
            self.basic_scheduler_config, {})
        self.assertEqual(result.max_num_encoder_input_tokens, 8192)
        self.assertEqual(result.encoder_cache_size, 8192)

    def test_valid_config_with_chunked_prefill(self):
        result = self._initialize_with(
            enable_chunked_prefill=True,
            max_num_batched_tokens=2048,
            max_model_len=4096,
        )
        self.assertEqual(result.max_num_batched_tokens, 2048)
        self.assertEqual(result.max_model_len, 4096)
        self.assertTrue(result.enable_chunked_prefill)

    def test_invalid_config_without_chunked_prefill(self):
        with self.assertRaises(ValueError) as raised:
            self._initialize_with(
                enable_chunked_prefill=False,
                max_num_batched_tokens=2048,
                max_model_len=4096,
            )
        error_text = str(raised.exception)
        self.assertIn(
            "Ascend scheduler is enabled without chunked prefill feature",
            error_text,
        )
        self.assertIn("max_num_batched_tokens (2048)", error_text)
        self.assertIn("max_model_len (4096)", error_text)
|
||||
898
tests/ut/core/test_scheduler.py
Normal file
898
tests/ut/core/test_scheduler.py
Normal file
@@ -0,0 +1,898 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
|
||||
SchedulerConfig, SpeculativeConfig, VllmConfig)
|
||||
from vllm.multimodal.inputs import PlaceholderRange
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.core.scheduler import AscendScheduler
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
# DraftTokenIds is imported only when the installed vLLM is neither 0.10.1.1
# nor 0.10.1; on those two versions the name is bound to None instead.
if not any(vllm_version_is(version) for version in ("0.10.1.1", "0.10.1")):
    from vllm.v1.outputs import DraftTokenIds
else:
    DraftTokenIds = None

# Module-level settings used by the scheduler tests in this file.
EOS_TOKEN_ID = 50256
MODEL = "Qwen3-0.6B"
ENABLE_PREFIX_CACHING = None
PROMPT_LOGPROBS = None
ENABLE_CHUNKED_PREFILL = False
MAX_NUM_BATCHED_TOKENS = 10000
LONG_PREFILL_TOKEN_THRESHOLD = 0
NUM_SPECULATIVE_TOKENS = None
MAX_NUM_SEQS = 16
|
||||
|
||||
|
||||
def create_requests(
    num_requests: int,
    num_tokens: int = 10,
    mm_positions: Optional[list[PlaceholderRange]] = None,
    max_tokens: int = 16,
    stop_token_ids: Optional[list[int]] = None,
    block_size: int = 3,
    hash_fn=hash,
):
    """Create ``num_requests`` Request objects for scheduler tests.

    Request ``i`` gets request_id ``str(i)`` and a prompt of ``num_tokens``
    copies of the token ``i``. All requests share one SamplingParams instance
    whose prompt_logprobs comes from the module-level PROMPT_LOGPROBS.

    Args:
        num_requests: Number of requests to build.
        num_tokens: Prompt length of each request.
        mm_positions: Accepted for signature compatibility; not used here.
        max_tokens: Forwarded to SamplingParams.
        stop_token_ids: Forwarded to SamplingParams.
        block_size: Block size given to get_request_block_hasher.
        hash_fn: Hash function for init_none_hash and the block hasher.

    Returns:
        The list of created requests, in id order.
    """
    init_none_hash(hash_fn)
    shared_params = SamplingParams(ignore_eos=False,
                                   max_tokens=max_tokens,
                                   stop_token_ids=stop_token_ids,
                                   prompt_logprobs=PROMPT_LOGPROBS)

    # Only the 0.10.1 / 0.10.1.1 branch passes the multi-modal keywords.
    version_specific = {}
    if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
        version_specific = {
            "multi_modal_kwargs": None,
            "multi_modal_placeholders": None,
            "multi_modal_hashes": None,
        }

    return [
        Request(request_id=f"{idx}",
                prompt_token_ids=[idx] * num_tokens,
                sampling_params=shared_params,
                eos_token_id=EOS_TOKEN_ID,
                pooling_params=None,
                block_hasher=get_request_block_hasher(block_size, hash_fn),
                **version_specific) for idx in range(num_requests)
    ]
|
||||
|
||||
|
||||
def make_output(scheduler):
    """Return a ModelRunnerOutput that samples token 1000 for each running request.

    Request ids and their indices follow the order of ``scheduler.running``.
    On vLLM 0.10.1 / 0.10.1.1 the output additionally carries
    ``spec_token_ids=None``.
    """
    running = scheduler.running
    fields = {
        "req_ids": [req.request_id for req in running],
        "req_id_to_index": {
            req.request_id: position
            for position, req in enumerate(running)
        },
        "sampled_token_ids": [[1000]] * len(running),
        "logprobs": None,
        "prompt_logprobs_dict": {},
        "pooler_output": [],
    }
    if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
        fields["spec_token_ids"] = None
    return ModelRunnerOutput(**fields)
|
||||
|
||||
|
||||
class TestAscendScheduler(TestBase):
|
||||
|
||||
@patch("vllm.config.ModelConfig.__post_init__", MagicMock())
@patch("vllm.config.VllmConfig.__post_init__", MagicMock())
@patch('vllm.v1.core.sched.scheduler.compute_encoder_budget')
def create_scheduler(self, mock_compute_encoder_budget):
    """Build an AscendScheduler configured from the module-level test knobs.

    ``ModelConfig.__post_init__`` and ``VllmConfig.__post_init__`` are
    replaced by mocks while this method runs. ``compute_encoder_budget`` is
    patched and injected as ``mock_compute_encoder_budget``, so tests call
    this method without arguments.

    The knobs MAX_NUM_SEQS, MAX_NUM_BATCHED_TOKENS,
    LONG_PREFILL_TOKEN_THRESHOLD, ENABLE_CHUNKED_PREFILL,
    ENABLE_PREFIX_CACHING, NUM_SPECULATIVE_TOKENS and MODEL are read at call
    time; tests that rebind them must do so before calling this method.

    Returns:
        An AscendScheduler backed by 10000 KV-cache blocks of size 16 and a
        mocked StructuredOutputManager whose ``should_advance`` returns
        False.
    """
    mock_compute_encoder_budget.return_value = [10, 20]
    use_kv_connector = False
    block_size = 16

    scheduler_config = SchedulerConfig(
        # Read the module-level knob (16 by default) so that tests which
        # rebind MAX_NUM_SEQS actually change the scheduler's limit.
        max_num_seqs=MAX_NUM_SEQS,
        max_model_len=MAX_NUM_BATCHED_TOKENS,
        long_prefill_token_threshold=LONG_PREFILL_TOKEN_THRESHOLD,
        disable_chunked_mm_input=False,
        enable_chunked_prefill=ENABLE_CHUNKED_PREFILL,
        max_num_batched_tokens=MAX_NUM_BATCHED_TOKENS,
    )

    # Overridden after construction; chunked_prefill_enabled is forced off
    # regardless of ENABLE_CHUNKED_PREFILL.
    scheduler_config.max_num_encoder_input_tokens = 10000
    scheduler_config.encoder_cache_size = 10000
    scheduler_config.chunked_prefill_enabled = False

    model_config = ModelConfig(
        model=MODEL,
        task="auto",
        tokenizer=MODEL,
        tokenizer_mode="auto",
        trust_remote_code=True,
        dtype="float16",
        seed=42,
        max_model_len=MAX_NUM_BATCHED_TOKENS,
    )
    # __post_init__ is mocked out above, so these attributes are filled in
    # by hand.
    model_config.pooler_config = MagicMock()
    model_config.multimodal_config = MagicMock()
    model_config.hf_config = MagicMock()
    model_config.hf_config.is_encoder_decoder = False

    # Cache config, optionally force APC
    kwargs_cache: Dict[str, Any] = ({} if ENABLE_PREFIX_CACHING is None else {
        'enable_prefix_caching': ENABLE_PREFIX_CACHING
    })
    cache_config = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        **kwargs_cache,
    )

    kv_transfer_config = KVTransferConfig(
        kv_connector="SharedStorageConnector",
        kv_role="kv_both",
        kv_connector_extra_config={"shared_storage_path": "local_storage"},
    ) if use_kv_connector else None

    speculative_config: Optional[SpeculativeConfig] = None
    if NUM_SPECULATIVE_TOKENS is not None:
        speculative_config = SpeculativeConfig(
            model="ngram", num_speculative_tokens=NUM_SPECULATIVE_TOKENS)

    vllm_config = VllmConfig(
        scheduler_config=scheduler_config,
        model_config=model_config,
        cache_config=cache_config,
        kv_transfer_config=kv_transfer_config,
        speculative_config=speculative_config,
    )

    kv_cache_config = KVCacheConfig(
        num_blocks=10000,  # A large number of blocks to hold all requests
        kv_cache_tensors=[],
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1,
                                               torch.float32, False))
        ],
    )
    cache_config.num_gpu_blocks = 10000

    scheduler = AscendScheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=MagicMock(spec=StructuredOutputManager),
    )

    should_advance = MagicMock()
    should_advance.return_value = False
    scheduler.structured_output_manager.should_advance = should_advance

    return scheduler
|
||||
|
||||
def test_add_requests(self):
    """Each add_request call registers the request and adds one waiting entry."""
    scheduler = self.create_scheduler()

    for expected_waiting, request in enumerate(
            create_requests(num_requests=10), start=1):
        scheduler.add_request(request)
        self.assertIn(request.request_id, scheduler.requests)
        self.assertEqual(len(scheduler.waiting), expected_waiting)
|
||||
|
||||
def test_finish_request(self):
    """Aborting a waiting request removes it from the scheduler's bookkeeping."""
    scheduler = self.create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    still_waiting = len(requests)
    for request in requests:
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_ABORTED)
        still_waiting -= 1
        self.assertNotIn(request.request_id, scheduler.requests)
        self.assertEqual(len(scheduler.waiting), still_waiting)
|
||||
|
||||
def test_get_num_unfinished_requests(self):
    """The unfinished-request count drops by one per finished request."""
    scheduler = self.create_scheduler()
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    total = len(requests)
    for finished_so_far, request in enumerate(requests, start=1):
        scheduler.finish_requests(request.request_id,
                                  RequestStatus.FINISHED_STOPPED)
        self.assertEqual(scheduler.get_num_unfinished_requests(),
                         total - finished_so_far)
|
||||
|
||||
def test_schedule(self):
    '''Schedule ten fresh requests in a single step.

    Uses whatever ENABLE_PREFIX_CACHING / PROMPT_LOGPROBS the module holds
    at the time of the call (both None by default).
    '''
    scheduler = self.create_scheduler()
    scheduler.scheduler_config.chunked_prefill_enabled = False
    requests = create_requests(num_requests=10)
    for request in requests:
        scheduler.add_request(request)

    # First scheduling step: everything is new, nothing cached or finished.
    output = scheduler.schedule()
    self.assertEqual(len(output.scheduled_new_reqs), len(requests))
    self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
    self.assertEqual(len(output.finished_req_ids), 0)

    # Every request is scheduled for its whole prompt.
    for req_id, scheduled in output.num_scheduled_tokens.items():
        prompt = requests[int(req_id)].prompt_token_ids
        self.assertEqual(scheduled, len(prompt))

    # The waiting queue drained into the running list, order preserved.
    self.assertEqual(len(scheduler.waiting), 0)
    self.assertEqual(len(scheduler.running), len(requests))
    for position, request in enumerate(requests):
        self.assertEqual(scheduler.running[position], request)
|
||||
|
||||
def test_schedule_enable_prefix_caching(self):
    '''Schedule ten fresh requests with APC=True and prompt logprobs=5.

    ENABLE_PREFIX_CACHING and PROMPT_LOGPROBS are rebound for the duration
    of the test and restored afterwards, so the settings do not leak into
    tests that run later.
    '''
    global ENABLE_PREFIX_CACHING
    global PROMPT_LOGPROBS
    saved_knobs = (ENABLE_PREFIX_CACHING, PROMPT_LOGPROBS)
    ENABLE_PREFIX_CACHING = True
    PROMPT_LOGPROBS = 5
    try:
        scheduler = self.create_scheduler()
        scheduler.scheduler_config.chunked_prefill_enabled = False
        requests = create_requests(num_requests=10)
        for request in requests:
            scheduler.add_request(request)

        # Test initial scheduling
        output = scheduler.schedule()
        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
        self.assertEqual(output.scheduled_cached_reqs.num_reqs, 0)
        self.assertEqual(len(output.finished_req_ids), 0)
        # Verify all requests are scheduled.
        for req_id, num_tokens in output.num_scheduled_tokens.items():
            self.assertEqual(num_tokens,
                             len(requests[int(req_id)].prompt_token_ids))

        # Verify requests moved from waiting to running
        self.assertEqual(len(scheduler.waiting), 0)
        self.assertEqual(len(scheduler.running), len(requests))
        for i, request in enumerate(requests):
            self.assertEqual(scheduler.running[i], request)
    finally:
        # Restore the knobs even when an assertion above fails.
        ENABLE_PREFIX_CACHING, PROMPT_LOGPROBS = saved_knobs
|
||||
|
||||
def test_stop_via_update_from_output(self):
    """Test stopping behavior through update_from_output.

    Four cases are exercised, each on a fresh scheduler: stop on EOS, stop
    on a custom stop token, stop on max_tokens, and ignore_eos=True.
    NUM_SPECULATIVE_TOKENS is rebound along the way and restored on exit so
    it does not leak into other tests.
    """
    global NUM_SPECULATIVE_TOKENS
    saved_num_spec_tokens = NUM_SPECULATIVE_TOKENS
    legacy_api = vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")

    def build_scheduler_output(num_scheduled_tokens, spec_decode_tokens):
        # SchedulerOutput for already-running requests only. The keyword
        # for freed encoder inputs differs between the two version branches.
        kwargs = dict(
            scheduled_new_reqs=[],
            scheduled_cached_reqs=[],
            num_scheduled_tokens=num_scheduled_tokens,
            total_num_scheduled_tokens=sum(num_scheduled_tokens.values()),
            scheduled_encoder_inputs={},
            scheduled_spec_decode_tokens=spec_decode_tokens,
            num_common_prefix_blocks=0,
            finished_req_ids=set(),
            structured_output_request_ids={},
            grammar_bitmask=None,
        )
        if legacy_api:
            kwargs["free_encoder_input_ids"] = []
        else:
            kwargs["free_encoder_mm_hashes"] = []
        return SchedulerOutput(**kwargs)

    def build_model_output(requests, sampled_token_ids):
        # ModelRunnerOutput with one sampled-token list per request; only
        # the 0.10.1 / 0.10.1.1 branch passes spec_token_ids.
        kwargs = dict(
            req_ids=[req.request_id for req in requests],
            req_id_to_index={
                req.request_id: i
                for i, req in enumerate(requests)
            },
            sampled_token_ids=sampled_token_ids,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )
        if legacy_api:
            kwargs["spec_token_ids"] = None
        return ModelRunnerOutput(**kwargs)

    def start_running(scheduler, requests, set_status=True):
        # Register requests as if their prompts were already computed.
        for req in requests:
            req.num_computed_tokens = req.num_tokens
            scheduler.requests[req.request_id] = req
            scheduler.running.append(req)
            if set_status:
                req.status = RequestStatus.RUNNING

    try:
        # Test case 1: Stop on EOS token
        NUM_SPECULATIVE_TOKENS = 1
        scheduler = self.create_scheduler()
        requests = create_requests(num_requests=2, max_tokens=10)
        start_running(scheduler, requests)

        scheduler_output = build_scheduler_output(
            {
                requests[0].request_id: 1,
                requests[1].request_id: 2
            }, {
                requests[0].request_id: [],
                requests[1].request_id: [10]
            })
        # First request hits EOS, second continues
        model_output = build_model_output(requests,
                                          [[EOS_TOKEN_ID], [10, 11]])

        scheduler.update_from_output(scheduler_output, model_output)

        # Verify first request stopped, second continues
        self.assertEqual(len(scheduler.running), 1)
        self.assertEqual(scheduler.running[0].request_id,
                         requests[1].request_id)
        self.assertEqual(requests[0].status,
                         RequestStatus.FINISHED_STOPPED)
        self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
        self.assertEqual(list(requests[0].output_token_ids),
                         [EOS_TOKEN_ID])
        self.assertEqual(list(requests[1].output_token_ids), [10, 11])

        # Test case 2: Stop on custom stop token
        NUM_SPECULATIVE_TOKENS = 2
        scheduler = self.create_scheduler()
        requests = create_requests(num_requests=2,
                                   max_tokens=10,
                                   stop_token_ids=[42, 43])
        start_running(scheduler, requests)

        scheduler_output = build_scheduler_output(
            {
                requests[0].request_id: 3,
                requests[1].request_id: 2
            }, {
                requests[0].request_id: [10, 42],
                requests[1].request_id: [13]
            })
        # First request hits stop token
        model_output = build_model_output(requests,
                                          [[10, 42, 12], [13, 14]])

        scheduler.update_from_output(scheduler_output, model_output)

        # Verify first request stopped on custom token
        self.assertEqual(len(scheduler.running), 1)
        self.assertEqual(scheduler.running[0].request_id,
                         requests[1].request_id)
        self.assertEqual(requests[0].status,
                         RequestStatus.FINISHED_STOPPED)
        self.assertEqual(requests[0].stop_reason, 42)
        self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
        self.assertEqual(list(requests[0].output_token_ids), [10, 42])
        self.assertEqual(list(requests[1].output_token_ids), [13, 14])

        # Test case 3: Stop on max tokens
        NUM_SPECULATIVE_TOKENS = 2
        scheduler = self.create_scheduler()
        requests = create_requests(num_requests=2, max_tokens=2)
        start_running(scheduler, requests)

        scheduler_output = build_scheduler_output(
            {
                requests[0].request_id: 3,
                requests[1].request_id: 1
            }, {
                requests[0].request_id: [10, 11],
                requests[1].request_id: []
            })
        # First request exceeds max_tokens
        model_output = build_model_output(requests, [[10, 11, 12], [13]])

        scheduler.update_from_output(scheduler_output, model_output)

        # Verify first request stopped due to length
        self.assertEqual(len(scheduler.running), 1)
        self.assertEqual(scheduler.running[0].request_id,
                         requests[1].request_id)
        self.assertEqual(requests[0].status,
                         RequestStatus.FINISHED_LENGTH_CAPPED)
        self.assertIn(requests[0].request_id, scheduler.finished_req_ids)
        self.assertEqual(list(requests[0].output_token_ids), [10, 11])
        self.assertEqual(list(requests[1].output_token_ids), [13])

        # Test case 4: Ignore EOS flag
        scheduler = self.create_scheduler()
        requests = create_requests(num_requests=1, max_tokens=10)
        requests[0].sampling_params.ignore_eos = True
        # This case never assigned RequestStatus.RUNNING; keep it that way.
        start_running(scheduler, requests, set_status=False)

        scheduler_output = build_scheduler_output(
            {requests[0].request_id: 3},
            {requests[0].request_id: [EOS_TOKEN_ID, 10]})
        model_output = build_model_output(requests,
                                          [[EOS_TOKEN_ID, 10, 11]])

        scheduler.update_from_output(scheduler_output, model_output)

        # Verify request continues past EOS
        self.assertEqual(len(scheduler.running), 1)
        self.assertFalse(requests[0].is_finished())
        self.assertEqual(list(requests[0].output_token_ids),
                         [EOS_TOKEN_ID, 10, 11])
    finally:
        # Restore the knob even when an assertion above fails.
        NUM_SPECULATIVE_TOKENS = saved_num_spec_tokens
|
||||
|
||||
def test_schedule_concurrent_batches(self):
    """Schedule a second request while the first one's output is pending.

    Runs twice: once with default prefix caching and no prompt logprobs,
    once with prefix caching enabled and prompt_logprobs=5. All module-level
    knobs rebound here are restored on exit so they do not leak into other
    tests.
    """
    global MAX_NUM_BATCHED_TOKENS
    global ENABLE_PREFIX_CACHING
    global ENABLE_CHUNKED_PREFILL
    global MAX_NUM_SEQS
    global PROMPT_LOGPROBS
    saved_knobs = (MAX_NUM_BATCHED_TOKENS, ENABLE_PREFIX_CACHING,
                   ENABLE_CHUNKED_PREFILL, MAX_NUM_SEQS, PROMPT_LOGPROBS)
    legacy_api = vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")

    def single_request_output(request):
        # ModelRunnerOutput in which `request` alone sampled token 0; only
        # the 0.10.1 / 0.10.1.1 branch passes spec_token_ids.
        kwargs = dict(
            req_ids=[request.request_id],
            req_id_to_index={request.request_id: 0},
            sampled_token_ids=[[0]],
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )
        if legacy_api:
            kwargs["spec_token_ids"] = None
        return ModelRunnerOutput(**kwargs)

    try:
        MAX_NUM_BATCHED_TOKENS = 1024
        MAX_NUM_SEQS = 2
        ENABLE_CHUNKED_PREFILL = True

        # (enable_prefix_caching, prompt_logprobs) combinations to cover.
        for prefix_caching, prompt_logprobs in ((None, None), (True, 5)):
            ENABLE_PREFIX_CACHING = prefix_caching
            PROMPT_LOGPROBS = prompt_logprobs
            scheduler = self.create_scheduler()
            requests = create_requests(
                num_requests=2,
                num_tokens=512,
            )

            # Schedule the first request.
            scheduler.add_request(requests[0])
            scheduler_output0 = scheduler.schedule()
            self.assertEqual(len(scheduler_output0.scheduled_new_reqs), 1)
            self.assertEqual(
                scheduler_output0.num_scheduled_tokens[
                    requests[0].request_id], 512)

            # The first request is still running, so only schedule the
            # second request.
            scheduler.add_request(requests[1])
            scheduler_output1 = scheduler.schedule()
            self.assertEqual(len(scheduler_output1.scheduled_new_reqs), 1)
            self.assertEqual(
                scheduler_output1.num_scheduled_tokens[
                    requests[1].request_id], 512)

            # Model output of the first request.
            scheduler.update_from_output(
                scheduler_output0, single_request_output(requests[0]))

            # Schedule the next step.
            # The first request can be scheduled again while the second
            # request is still running.
            scheduler.schedule()

            # Model output of the second request.
            scheduler.update_from_output(
                scheduler_output1, single_request_output(requests[1]))
    finally:
        # Restore the knobs even when an assertion above fails.
        (MAX_NUM_BATCHED_TOKENS, ENABLE_PREFIX_CACHING,
         ENABLE_CHUNKED_PREFILL, MAX_NUM_SEQS,
         PROMPT_LOGPROBS) = saved_knobs
|
||||
|
||||
def test_schedule_spec_decoding_stats(self):
    """Test scheduling behavior with speculative decoding.

    This test verifies that:
    1. Speculated tokens get scheduled correctly
    2. Spec decoding stats properly count number of draft and accepted tokens

    Each case is a triple of per-request draft tokens, per-request sampled
    tokens for the validation step, and the expected
    (num_drafts, num_draft_tokens, num_accepted_tokens,
    num_accepted_tokens_per_pos). NUM_SPECULATIVE_TOKENS is rebound for each
    case and restored on exit so it does not leak into other tests.
    """
    global NUM_SPECULATIVE_TOKENS

    spec_tokens_list: List[List[List[int]]] = [[[1, 2, 3]], [[1, 2, 3]],
                                               [[1, 2], [3]], [[1]], [[]],
                                               [[1, 2, 3], [4, 5, 6]]]
    output_tokens_list: List[List[List[int]]] = [[[1, 2, 3, 4]], [[1, 5]],
                                                 [[1, 2, 5], [3, 4]],
                                                 [[1, 2]], [[5]],
                                                 [[1, 2, 7], [4, 8]]]
    expected_list: List[Tuple[int, int, int,
                              List[int]]] = [(1, 3, 3, [1, 1, 1]),
                                             (1, 3, 1, [1, 0, 0]),
                                             (2, 3, 3, [2, 1]),
                                             (1, 1, 1, [1]),
                                             (0, 0, 0, [0]),
                                             (2, 6, 3, [2, 1, 0])]

    legacy_api = vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1")

    def build_model_output(req_ids, req_to_index, sampled_token_ids,
                           spec_token_ids):
        # Only the 0.10.1 / 0.10.1.1 branch passes spec_token_ids to
        # ModelRunnerOutput; the other branch delivers drafts separately.
        kwargs = dict(
            req_ids=req_ids,
            req_id_to_index=req_to_index,
            sampled_token_ids=sampled_token_ids,
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[],
        )
        if legacy_api:
            kwargs["spec_token_ids"] = spec_token_ids
        return ModelRunnerOutput(**kwargs)

    def run_case(spec_tokens, output_tokens, expected):
        # Drive one prefill step that produces drafts, then one validation
        # step, and compare the reported spec-decoding stats.
        scheduler = self.create_scheduler()
        requests = create_requests(num_requests=len(spec_tokens),
                                   num_tokens=1)
        req_ids = []
        req_to_index = {}
        for i, request in enumerate(requests):
            scheduler.add_request(request)
            req_ids.append(request.request_id)
            req_to_index[request.request_id] = i

        # Schedule a decode, which will also draft speculative tokens
        output = scheduler.schedule()
        self.assertEqual(len(output.scheduled_new_reqs), len(requests))
        self.assertEqual(output.total_num_scheduled_tokens, len(requests))
        for request in requests:
            req_id = request.request_id
            self.assertEqual(output.num_scheduled_tokens[req_id], 1)
            self.assertNotIn(req_id, output.scheduled_spec_decode_tokens)

        model_runner_output = build_model_output(
            req_ids, req_to_index, [[0] for _ in range(len(requests))],
            spec_tokens)
        if not legacy_api:
            draft_token_ids = DraftTokenIds(req_ids, spec_tokens)

        engine_core_outputs = scheduler.update_from_output(
            output, model_runner_output)
        if not legacy_api:
            scheduler.update_draft_token_ids(draft_token_ids)

        for i in range(len(requests)):
            running_req = scheduler.running[i]
            # The prompt token
            self.assertEqual(running_req.num_computed_tokens, 1)
            # The prompt token and the sampled token
            self.assertEqual(running_req.num_tokens, 2)
            # The prompt token, the sampled token, and the speculated tokens
            self.assertEqual(running_req.num_tokens_with_spec,
                             2 + len(spec_tokens[i]))

        # No draft or accepted tokens counted yet
        self.assertTrue(
            not engine_core_outputs
            or (engine_core_outputs[0].scheduler_stats.spec_decoding_stats
                is None))

        # Schedule the speculated tokens for validation
        output = scheduler.schedule()
        self.assertEqual(len(output.scheduled_new_reqs), 0)
        # The sampled token and speculated tokens
        self.assertEqual(
            output.total_num_scheduled_tokens,
            len(requests) + sum(len(ids) for ids in spec_tokens))
        for i, request in enumerate(requests):
            req_id = request.request_id
            self.assertEqual(output.num_scheduled_tokens[req_id],
                             1 + len(spec_tokens[i]))
            if spec_tokens[i]:
                self.assertEqual(
                    len(output.scheduled_spec_decode_tokens[req_id]),
                    len(spec_tokens[i]))
            else:
                self.assertNotIn(req_id,
                                 output.scheduled_spec_decode_tokens)

        model_runner_output = build_model_output(req_ids, req_to_index,
                                                 output_tokens, None)

        engine_core_outputs = scheduler.update_from_output(
            output, model_runner_output)

        scheduler_stats = engine_core_outputs[0].scheduler_stats \
            if engine_core_outputs else None
        if expected[0] == 0:
            self.assertIsNone(scheduler_stats.spec_decoding_stats)
        else:
            self.assertIsNotNone(scheduler_stats.spec_decoding_stats)
            stats = scheduler_stats.spec_decoding_stats
            self.assertEqual(stats.num_drafts, expected[0])
            self.assertEqual(stats.num_draft_tokens, expected[1])
            self.assertEqual(stats.num_accepted_tokens, expected[2])
            self.assertEqual(stats.num_accepted_tokens_per_pos,
                             expected[3])

    saved_num_spec_tokens = NUM_SPECULATIVE_TOKENS
    try:
        for spec_tokens, output_tokens, expected in zip(
                spec_tokens_list, output_tokens_list, expected_list):
            # At least one speculative token so that a SpeculativeConfig
            # is created even for the empty-draft case.
            NUM_SPECULATIVE_TOKENS = max(
                1, max(len(t) for t in spec_tokens))
            run_case(spec_tokens, output_tokens, expected)
    finally:
        # Restore the knob even when an assertion above fails.
        NUM_SPECULATIVE_TOKENS = saved_num_spec_tokens
|
||||
|
||||
def assert_scheduler_empty(self, scheduler):
    """Confirm the scheduler is "empty" - i.e. no leaks.

    Args:
        scheduler: The scheduler instance to inspect. It is checked as
            passed in; it must not be replaced by a fresh scheduler,
            otherwise the check would pass trivially.
    """
    # Scheduler Metadata.
    self.assertEqual(len(scheduler.requests), 0)
    self.assertEqual(len(scheduler.waiting), 0)
    self.assertEqual(len(scheduler.running), 0)
    self.assertEqual(len(scheduler.finished_req_ids), 0)

    # EncoderCacheManager.
    self.assertEqual(len(scheduler.encoder_cache_manager.freed), 0)
    self.assertEqual(len(scheduler.encoder_cache_manager.cached), 0)

    # KVCache Manager.
    type_manager = (
        scheduler.kv_cache_manager.coordinator.single_type_managers[0])
    self.assertEqual(len(type_manager.req_to_blocks), 0)
    self.assertEqual(len(type_manager.num_cached_block), 0)

    block_pool = scheduler.kv_cache_manager.block_pool
    num_free_blocks = block_pool.free_block_queue.num_free_blocks
    # All but one block must be free.
    # NOTE(review): presumably the pool keeps one block reserved - confirm.
    self.assertEqual(num_free_blocks, block_pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for block in block_pool.blocks:
        self.assertEqual(block.ref_cnt, 0)
|
||||
|
||||
def test_memory_leak(self):
    """Run five requests to completion, then run the emptiness check."""
    scheduler = self.create_scheduler()
    requests = create_requests(num_requests=5,
                               num_tokens=10,
                               max_tokens=10)

    # Feed the requests in one at a time, stepping the scheduler after each.
    for request in requests:
        scheduler.add_request(request)
        step = scheduler.schedule()
        scheduler.update_from_output(step, make_output(scheduler))

    # Keep stepping until a schedule() call leaves nothing running.
    while True:
        step = scheduler.schedule()
        if not scheduler.running:
            break
        scheduler.update_from_output(step, make_output(scheduler))

    # Confirm no memory leak.
    self.assert_scheduler_empty(scheduler)
|
||||
188
tests/ut/device_allocator/test_camem.py
Normal file
188
tests/ut/device_allocator/test_camem.py
Normal file
@@ -0,0 +1,188 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.device_allocator.camem import (AllocationData, CaMemAllocator,
|
||||
create_and_map,
|
||||
find_loaded_library,
|
||||
get_pluggable_allocator,
|
||||
unmap_and_release)
|
||||
|
||||
|
||||
def dummy_malloc(args):
    """Stand-in malloc callback: accepts one argument and does nothing."""
    return None
|
||||
|
||||
|
||||
def dummy_free(ptr):
    """Stand-in free callback: ignores ``ptr``, returns an all-zero 4-tuple."""
    zero_handle = (0, 0, 0, 0)
    return zero_handle
|
||||
|
||||
|
||||
class TestCaMem(PytestBase):
    """Unit tests for the module-level helpers and ``CaMemAllocator``."""

    def test_find_loaded_library_success_and_not_found(self):
        """A loaded library resolves to a path; an unknown name yields None."""
        libc_path = find_loaded_library("libc")
        assert libc_path is not None, "Expected to find libc library"
        # ".so.6" already contains ".so", so a single containment check
        # accepts exactly the same paths as the two-part condition.
        assert ".so" in libc_path
        assert "libc" in libc_path

        missing = find_loaded_library("non_existent_library")
        assert missing is None, "Expected to not find non-existent library"

    @pytest.mark.parametrize("handle", [
        (1, 2, 3),
        ("device", 99),
        (None, ),
    ])
    def test_create_and_map_calls_python_create_and_map(self, handle):
        """``create_and_map`` forwards the unpacked handle exactly once."""
        target = "vllm_ascend.device_allocator.camem.python_create_and_map"
        with patch(target) as low_level_create:
            create_and_map(handle)
            low_level_create.assert_called_once_with(*handle)

    @pytest.mark.parametrize("handle", [
        (42, "bar"),
        ("foo", ),
    ])
    def test_unmap_and_release_calls_python_unmap_and_release(self, handle):
        """``unmap_and_release`` forwards the unpacked handle exactly once."""
        target = "vllm_ascend.device_allocator.camem.python_unmap_and_release"
        with patch(target) as low_level_release:
            unmap_and_release(handle)
            low_level_release.assert_called_once_with(*handle)

    @patch("vllm_ascend.device_allocator.camem.init_module")
    @patch(
        "vllm_ascend.device_allocator.camem.torch.npu.memory.NPUPluggableAllocator"
    )
    def test_get_pluggable_allocator(self, mock_allocator_class,
                                     mock_init_module):
        """The helper registers both callbacks and returns the allocator."""
        fake_allocator = MagicMock()
        mock_allocator_class.return_value = fake_allocator

        def invoke_callbacks(malloc_fn, free_fn):
            # Exercise both callbacks as soon as they are registered.
            malloc_fn((1, 2, 3))
            free_fn(123)

        mock_init_module.side_effect = invoke_callbacks

        allocator = get_pluggable_allocator(dummy_malloc, dummy_free)

        mock_init_module.assert_called_once_with(dummy_malloc, dummy_free)
        assert allocator == fake_allocator

    def test_singleton_behavior(self):
        """``get_instance`` hands out the same object on every call."""
        first = CaMemAllocator.get_instance()
        second = CaMemAllocator.get_instance()
        assert first is second

    def test_python_malloc_and_free_callback(self):
        """The malloc callback records an allocation; free removes it."""
        allocator = CaMemAllocator.get_instance()
        allocator.current_tag = "test_tag"

        # Fake allocation handle; its third field is used as the pointer.
        handle = (1, 100, 1234, 0)
        ptr = handle[2]

        allocator.python_malloc_callback(handle)

        # The allocation is stored under its pointer with the active tag.
        assert ptr in allocator.pointer_to_data
        record = allocator.pointer_to_data[ptr]
        assert record.handle == handle
        assert record.tag == "test_tag"

        # Freeing returns the handle, forgets the pointer and drops the
        # CPU backup tensor.
        record.cpu_backup_tensor = torch.zeros(1)
        freed_handle = allocator.python_free_callback(ptr)
        assert freed_handle == handle
        assert ptr not in allocator.pointer_to_data
        assert record.cpu_backup_tensor is None

    @patch("vllm_ascend.device_allocator.camem.unmap_and_release")
    @patch("vllm_ascend.device_allocator.camem.memcpy")
    def test_sleep_offload_and_discard(self, mock_memcpy, mock_unmap):
        """``sleep`` backs up matching tags and unmaps every allocation."""
        allocator = CaMemAllocator.get_instance()

        # Two allocations: "tag1" matches the offload tag, "tag2" does not.
        offloaded_handle = (1, 10, 1000, 0)
        discarded_handle = (2, 20, 2000, 0)
        offloaded = AllocationData(offloaded_handle, "tag1")
        discarded = AllocationData(discarded_handle, "tag2")
        allocator.pointer_to_data = {1000: offloaded, 2000: discarded}

        # Report pinned memory as unavailable, since some machines only
        # have a CPU.
        pin_check = (
            "vllm_ascend.device_allocator.camem.NPUPlatform.is_pin_memory_available"
        )
        with patch(pin_check, return_value=False):
            allocator.sleep(offload_tags="tag1")

        # Only tag1 gets a CPU backup; both handles are unmapped.
        assert offloaded.cpu_backup_tensor is not None
        assert discarded.cpu_backup_tensor is None
        mock_unmap.assert_any_call(offloaded_handle)
        mock_unmap.assert_any_call(discarded_handle)
        assert mock_unmap.call_count == 2
        assert mock_memcpy.called

    @patch("vllm_ascend.device_allocator.camem.create_and_map")
    @patch("vllm_ascend.device_allocator.camem.memcpy")
    def test_wake_up_loads_and_clears_cpu_backup(self, mock_memcpy,
                                                 mock_create_and_map):
        """``wake_up`` remaps the handle, copies data and drops the backup."""
        allocator = CaMemAllocator.get_instance()

        handle = (1, 10, 1000, 0)
        backup = torch.zeros(5, dtype=torch.uint8)
        record = AllocationData(handle, "tag1", cpu_backup_tensor=backup)
        allocator.pointer_to_data = {1000: record}

        allocator.wake_up(tags=["tag1"])

        mock_create_and_map.assert_called_once_with(handle)
        assert record.cpu_backup_tensor is None
        assert mock_memcpy.called

    def test_use_memory_pool_context_manager(self):
        """``use_memory_pool`` switches the tag inside and restores it after."""
        allocator = CaMemAllocator.get_instance()
        tag_before = allocator.current_tag

        # Stand-in for what use_memory_pool_with_allocator returns.
        fake_pool_ctx = MagicMock()
        fake_pool_ctx.__enter__.return_value = "data"
        fake_pool_ctx.__exit__.return_value = None

        with patch(
                "vllm_ascend.device_allocator.camem.use_memory_pool_with_allocator",
                return_value=fake_pool_ctx):
            with allocator.use_memory_pool(tag="my_tag"):
                assert allocator.current_tag == "my_tag"
            # The previous tag is back once the pool context has exited.
            assert allocator.current_tag == tag_before

    def test_get_current_usage(self):
        """Usage for the two recorded allocations (100 and 200) is 300."""
        allocator = CaMemAllocator.get_instance()
        allocator.pointer_to_data = {
            ptr: AllocationData((0, size, ptr, 0), "tag")
            for ptr, size in ((1, 100), (2, 200))
        }

        assert allocator.get_current_usage() == 300
|
||||
84
tests/ut/distributed/device_communicators/test_pyhccl.py
Normal file
84
tests/ut/distributed/device_communicators/test_pyhccl.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.distributed.device_communicators.pyhccl import \
|
||||
PyHcclCommunicator
|
||||
|
||||
|
||||
class MockHcclLib:
    """Empty placeholder class with no attributes or behaviour of its own."""
|
||||
|
||||
|
||||
class MockUniqueId:
    """Empty placeholder class with no attributes or behaviour of its own."""
|
||||
|
||||
|
||||
class TestPyHcclCommunicator(TestBase):
    """Construction-time behaviour of ``PyHcclCommunicator``."""

    @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "1"})
    def test_world_size_1_return_early(self):
        """A one-rank group yields a disabled, unavailable communicator."""
        single_rank_group = StatelessProcessGroup(0, 1, None, None)
        comm = PyHcclCommunicator(group=single_rank_group, device="npu:0")

        self.assertTrue(comm.disabled)
        self.assertFalse(comm.available)

    @patch.dict(os.environ, {"RANK": "0", "WORLD_SIZE": "2"})
    def test_load_hccl_fail(self):
        """A library path that does not exist leaves the communicator disabled."""
        two_rank_group = StatelessProcessGroup(0, 2, None, None)
        comm = PyHcclCommunicator(
            group=two_rank_group,
            device="npu:0",
            library_path="/not/exist/path/libhccl.so",
        )

        self.assertTrue(comm.disabled)

    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
        MockHcclLib)
    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
        MockUniqueId)
    @patch("torch.npu.device")
    @patch("vllm_ascend.utils.current_stream",
           return_value=MagicMock(npu_stream=5678))
    def test_stateless_group(self, *_):
        """Rank and world size are taken from a ``StatelessProcessGroup``."""
        stateless = StatelessProcessGroup(rank=3,
                                          world_size=4,
                                          store=None,
                                          socket=None)

        comm = PyHcclCommunicator(group=stateless, device=3)

        self.assertEqual(comm.rank, 3)
        self.assertEqual(comm.world_size, 4)

    @patch.dict(os.environ, {"RANK": "1", "WORLD_SIZE": "2"})
    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.HCCLLibrary",
        MockHcclLib)
    @patch(
        "vllm_ascend.distributed.device_communicators.pyhccl_wrapper.hcclUniqueId",
        MockUniqueId)
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_backend", return_value="nccl")
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.distributed.get_process_group_ranks", return_value=[0, 1])
    @patch("torch.distributed.broadcast")
    @patch("torch.npu.device")
    @patch("vllm_ascend.utils.current_stream",
           return_value=MagicMock(npu_stream=1234))
    def test_multi_gpu_pg_torch(self, *_):
        """With a mocked torch group, rank/world size come from torch.distributed."""
        torch_group = MagicMock()

        comm = PyHcclCommunicator(group=torch_group, device="npu:1")

        self.assertEqual(comm.rank, 1)
        self.assertEqual(comm.world_size, 2)
        self.assertFalse(comm.available)
        self.assertTrue(comm.disabled)
|
||||
173
tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py
Normal file
173
tests/ut/distributed/device_communicators/test_pyhccl_wrapper.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.distributed.device_communicators.pyhccl_wrapper import (
|
||||
Function, HCCLLibrary, aclrtStream_t, buffer_type, hcclComm_t,
|
||||
hcclDataType_t, hcclDataTypeEnum, hcclRedOp_t, hcclRedOpTypeEnum,
|
||||
hcclResult_t, hcclUniqueId)
|
||||
|
||||
|
||||
class TestHcclUniqueId(TestBase):
    """Checks the ``internal`` buffer of ``hcclUniqueId``."""

    def test_construct(self):
        """The buffer has 4108 entries and keeps a value written to it."""
        unique_id = hcclUniqueId()
        unique_id.internal[0] = 12

        buffer = unique_id.internal
        self.assertEqual(len(buffer), 4108)
        self.assertEqual(buffer[0], 12)
|
||||
|
||||
|
||||
class TestHcclDataTypeEnum(TestBase):
    """Mapping from torch dtypes to ``hcclDataTypeEnum`` values."""

    def test_torch_dtype_mapping(self):
        """Each listed torch dtype converts to its HCCL counterpart."""
        cases = [
            (torch.int8, hcclDataTypeEnum.hcclInt8),
            (torch.uint8, hcclDataTypeEnum.hcclUint8),
            (torch.int32, hcclDataTypeEnum.hcclInt32),
            (torch.int64, hcclDataTypeEnum.hcclInt64),
            (torch.float16, hcclDataTypeEnum.hcclFloat16),
            (torch.float32, hcclDataTypeEnum.hcclFloat32),
            (torch.float64, hcclDataTypeEnum.hcclFloat64),
            (torch.bfloat16, hcclDataTypeEnum.hcclBfloat16),
        ]

        for torch_dtype, expected_enum in cases:
            with self.subTest(torch_dtype=torch_dtype):
                actual = hcclDataTypeEnum.from_torch(torch_dtype)
                self.assertEqual(actual, expected_enum)

    def test_unsupported_dtype_raises(self):
        """``torch.complex64`` is rejected with ``ValueError``."""
        self.assertRaises(ValueError, hcclDataTypeEnum.from_torch,
                          torch.complex64)
|
||||
|
||||
|
||||
class TestHcclRedOpTypeEnum(TestBase):
    """Mapping from torch ``ReduceOp`` values to ``hcclRedOpTypeEnum``."""

    def test_torch_reduce_op_mapping(self):
        """SUM, PRODUCT, MAX and MIN convert to their HCCL counterparts."""
        expected = {
            ReduceOp.SUM: hcclRedOpTypeEnum.hcclSum,
            ReduceOp.PRODUCT: hcclRedOpTypeEnum.hcclProd,
            ReduceOp.MAX: hcclRedOpTypeEnum.hcclMax,
            ReduceOp.MIN: hcclRedOpTypeEnum.hcclMin,
        }

        for torch_op in expected:
            with self.subTest(torch_op=torch_op):
                converted = hcclRedOpTypeEnum.from_torch(torch_op)
                self.assertEqual(converted, expected[torch_op])

    def test_unsupported_op_raises(self):
        """A value that is not a known reduce op raises ``ValueError``."""
        self.assertRaises(ValueError, hcclRedOpTypeEnum.from_torch,
                          "NOT_EXIST")
|
||||
|
||||
|
||||
class TestFunction(TestBase):
    """``Function`` stores the values passed to its constructor."""

    def test_construct_with_valid_args(self):
        """name, restype and argtypes are exposed unchanged as attributes."""
        arg_types = [int, str, float]
        func = Function(name="foo", restype=int, argtypes=arg_types)

        self.assertEqual(func.name, "foo")
        self.assertIs(func.restype, int)
        self.assertEqual(func.argtypes, [int, str, float])
|
||||
|
||||
|
||||
class TestHCLLLibrary(TestBase):
    """Tests for ``HCCLLibrary`` using mocked native entry points."""

    @staticmethod
    def _library_with(func_name):
        """Build an ``HCCLLibrary`` without running ``__init__``.

        The instance's ``_funcs`` table holds a single entry, ``func_name``,
        mocked to return 0.
        """
        lib = HCCLLibrary.__new__(HCCLLibrary)
        lib._funcs = {func_name: MagicMock(return_value=0)}
        return lib

    def test_init_with_nonexistent_so(self):
        """Loading a shared object that does not exist raises ``OSError``."""
        fake_path = "/definitely/not/exist/libhccl.so"
        with self.assertRaises(OSError):
            HCCLLibrary(fake_path)

    def test_hccl_get_error_string(self):
        """A stubbed ``hcclGetErrorString`` is called once and its value returned."""
        # ``spec`` was misspelled ``sepc``; MagicMock accepted that silently
        # as an ordinary attribute, leaving the mock unconstrained. With the
        # correct keyword, reading an attribute HCCLLibrary does not have
        # raises AttributeError.
        lib = MagicMock(spec=HCCLLibrary)
        mock_fn = MagicMock()
        mock_fn.return_value = "HCCL internal error"
        lib.hcclGetErrorString = mock_fn

        result = hcclResult_t(1)
        msg = lib.hcclGetErrorString(result)

        self.assertEqual(msg, "HCCL internal error")
        mock_fn.assert_called_once()

    def test_hccl_check(self):
        """``HCCL_CHECK`` raises ``RuntimeError`` carrying the error string."""
        lib = HCCLLibrary.__new__(HCCLLibrary)
        mock_fn = MagicMock()
        mock_fn.return_value = "fake error"
        lib.hcclGetErrorString = mock_fn
        result = hcclResult_t(123)

        with self.assertRaises(RuntimeError) as cm:
            lib.HCCL_CHECK(result)

        self.assertEqual(str(cm.exception), "HCCL error: fake error")

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_get_uniqueId(self, mock_HCCL_CHECK):
        """``hcclGetUniqueId`` calls ``HcclGetRootInfo`` and checks its result."""
        lib = self._library_with("HcclGetRootInfo")

        unique_id = lib.hcclGetUniqueId()

        self.assertIsInstance(unique_id, hcclUniqueId)
        lib._funcs["HcclGetRootInfo"].assert_called_once()
        mock_HCCL_CHECK.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_comm_initRank(self, mock_hccl_check):
        """``hcclCommInitRank`` calls ``HcclCommInitRootInfo`` and returns a comm."""
        lib = self._library_with("HcclCommInitRootInfo")

        world_size = 4
        unique_id = hcclUniqueId()
        rank = 1

        comm = lib.hcclCommInitRank(world_size, unique_id, rank)

        self.assertIsInstance(comm, hcclComm_t)
        lib._funcs["HcclCommInitRootInfo"].assert_called_once()
        mock_hccl_check.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_all_reduce(self, mock_hccl_check):
        """``hcclAllReduce`` forwards its arguments to ``HcclAllReduce``."""
        lib = self._library_with("HcclAllReduce")
        sendbuff = buffer_type()
        recvbuff = buffer_type()
        count = 10
        datatype = hcclDataType_t(1)
        op = hcclRedOp_t(0)
        comm = hcclComm_t()
        stream = aclrtStream_t()

        lib.hcclAllReduce(sendbuff, recvbuff, count, datatype, op, comm,
                          stream)

        lib._funcs["HcclAllReduce"].assert_called_once_with(
            sendbuff, recvbuff, count, datatype, op, comm, stream)
        mock_hccl_check.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hccl_broad_cast(self, mock_hccl_check):
        """``hcclBroadcast`` forwards its arguments to ``HcclBroadcast``."""
        lib = self._library_with("HcclBroadcast")
        buff = buffer_type()
        count = 10
        datatype = 1
        root = 0
        comm = hcclComm_t()
        stream = aclrtStream_t()

        lib.hcclBroadcast(buff, count, datatype, root, comm, stream)

        lib._funcs["HcclBroadcast"].assert_called_once_with(
            buff, count, datatype, root, comm, stream)
        mock_hccl_check.assert_called_once_with(0)

    @patch.object(HCCLLibrary, "HCCL_CHECK")
    def test_hcclCommDestroy_success(self, mock_hccl_check):
        """``hcclCommDestroy`` forwards the communicator to ``HcclCommDestroy``."""
        lib = self._library_with("HcclCommDestroy")
        comm = hcclComm_t()

        lib.hcclCommDestroy(comm)

        lib._funcs["HcclCommDestroy"].assert_called_once_with(comm)
        mock_hccl_check.assert_called_once_with(0)
|
||||
89
tests/ut/distributed/test_communicator.py
Normal file
89
tests/ut/distributed/test_communicator.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm_ascend.distributed.communicator import NPUCommunicator
|
||||
|
||||
|
||||
class TestNPUCommunicator(unittest.TestCase):
    """Tests for ``NPUCommunicator.all_to_all`` with torch.distributed mocked.

    Each target is patched once per test. The original decorator stacks
    patched ``is_initialized``, ``get_rank`` and ``get_process_group_ranks``
    twice; stacked patches are entered bottom-up, so the outermost one was the
    effective one, and those are the ones kept here.
    """

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    # NOTE(review): the real API returns a list of ranks; the dict is kept
    # because it was the value in effect before the duplicate patch was
    # removed - confirm whether [0, 1] was intended.
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.npu.device")
    def test_all_to_all_with_sizes(self, *_):
        """With explicit split sizes the received chunks are concatenated."""

        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            # Pretend the collective delivered these two chunks.
            output_tensor_list[:] = ([
                torch.tensor([10, 20]),
                torch.tensor([50, 60])
            ])

        scatter_sizes = [2, 2]
        gather_sizes = [2, 2]
        input_ = torch.tensor([10, 20, 30, 40])

        # Patch through mock so the real collective is restored afterwards;
        # a plain assignment would leak the fake into every later test.
        with patch("torch.distributed.all_to_all", new=patched_all_to_all):
            comm = NPUCommunicator(cpu_group=dist.group.WORLD)
            output = comm.all_to_all(input_,
                                     scatter_sizes=scatter_sizes,
                                     gather_sizes=gather_sizes)

        assert output.tolist() == [10, 20, 50, 60]

    @patch("vllm.config.get_current_vllm_config", return_value=None)
    @patch("torch.npu.current_device", return_value=MagicMock())
    @patch("torch.npu.set_device", return_value=MagicMock())
    # NOTE(review): dict kept as the previously effective value; see the
    # note on test_all_to_all_with_sizes.
    @patch("torch.distributed.get_process_group_ranks",
           return_value={
               0: 0,
               1: 1
           })
    @patch("torch.distributed.get_group_rank", return_value={0: 0, 1: 1})
    @patch("torch.distributed.is_initialized", return_value=True)
    @patch("torch.distributed.get_rank", return_value=1)
    @patch("torch.distributed.get_backend", return_value="hccl")
    @patch("torch.distributed.get_world_size", return_value=2)
    @patch("torch.npu.device")
    def test_all_to_all_without_sizes(self, *_):
        """Without split sizes the received chunks are concatenated on dim 0."""

        def patched_all_to_all(output_tensor_list,
                               input_tensor_list,
                               group=None,
                               async_op=False):
            # Pretend the collective delivered these two row chunks.
            output_tensor_list[:] = ([
                torch.tensor([[10, 20]]),
                torch.tensor([[50, 60]])
            ])

        input_ = torch.tensor([[10, 20], [30, 40]])

        # Scoped patch: the real collective is restored when the block exits.
        with patch("torch.distributed.all_to_all", new=patched_all_to_all):
            comm = NPUCommunicator(cpu_group=dist.group.WORLD)
            output = comm.all_to_all(input_, scatter_dim=0, gather_dim=0)

        assert output.tolist() == [[10, 20], [50, 60]]
|
||||
139
tests/ut/distributed/test_distributed_tensor_parallel.py
Normal file
139
tests/ut/distributed/test_distributed_tensor_parallel.py
Normal file
@@ -0,0 +1,139 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import importlib
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
_gather_along_first_dim, _gather_along_last_dim,
|
||||
_reduce_scatter_along_first_dim, _reduce_scatter_along_last_dim,
|
||||
all_to_all_hp2sp, all_to_all_sp2hp)
|
||||
|
||||
|
||||
class TestDistributedCommunication(PytestBase):
    """Output-shape checks for the tensor-parallel communication helpers."""

    @pytest.fixture(autouse=True)
    def context(self, mocker: MockerFixture):
        """Patch device and process-group queries: world size 4, rank 0."""
        mocker.patch("torch.npu.current_device", return_value="cpu")
        mocker.patch("torch.distributed.get_world_size", return_value=4)
        mocker.patch("torch.distributed.get_rank", return_value=0)

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16), (8, 16)),
                              (4, torch.randn(8, 16), (32, 16))])
    def test_gather_along_first_dim(self, test_tensor, expected, world_size,
                                    mocker: MockerFixture):
        """Gathering multiplies the first dimension by the world size."""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)
        group = mocker.MagicMock()

        gathered = _gather_along_first_dim(test_tensor, group)

        assert gathered.shape == expected

    @pytest.mark.parametrize("test_tensor, output_split_sizes, expected", [
        (torch.randn(8, 16), [5, 10, 15, 2], (32, 16)),
    ])
    def test_gather_along_first_dim_unequal_split(self, test_tensor, expected,
                                                  output_split_sizes,
                                                  mocker: MockerFixture):
        """With explicit split sizes the first dimension is their sum."""
        group = mocker.MagicMock()

        gathered = _gather_along_first_dim(test_tensor, group,
                                           output_split_sizes)

        assert gathered.shape == expected

    @pytest.mark.parametrize("world_size, test_tensor, expected",
                             [(1, torch.randn(8, 16, 32), (8, 16, 32)),
                              (4, torch.randn(8, 16, 32), (8, 16, 32 * 4))])
    def test_gather_along_last_dim(self, test_tensor, expected, world_size,
                                   mocker: MockerFixture):
        """Gathering multiplies the last dimension by the world size."""
        mocker.patch("torch.distributed.get_world_size",
                     return_value=world_size)
        group = mocker.MagicMock()

        gathered = _gather_along_last_dim(test_tensor, group)

        assert gathered.shape == expected

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((32, 16), (8, 16)),
        ((40, 10), (10, 10)),
    ])
    def test_reduce_scatter_along_first_dim(self, input_shape, expected_shape,
                                            mocker: MockerFixture):
        """Reduce-scatter divides the first dimension by the world size."""
        scattered = _reduce_scatter_along_first_dim(torch.randn(*input_shape),
                                                    mocker.MagicMock())
        assert scattered.shape == expected_shape

    @pytest.mark.parametrize("input_shape,expected_shape", [
        ((8, 16, 32), (8, 16, 8)),
    ])
    def test_reduce_scatter_along_last_dim(self, input_shape, expected_shape,
                                           mocker: MockerFixture):
        """Reduce-scatter divides the last dimension by the world size."""
        scattered = _reduce_scatter_along_last_dim(torch.randn(*input_shape),
                                                   mocker.MagicMock())
        assert scattered.shape == expected_shape

    @pytest.mark.parametrize("func,input_shape,expected_shape", [
        ("all_gather_last_dim_from_tensor_parallel_region", (8, 16, 32),
         (8, 16, 128)),
        ("reduce_scatter_to_sequence_parallel_region", (32, 16), (8, 16)),
        ("reduce_scatter_last_dim_to_tensor_parallel_region", (8, 16, 32),
         (8, 16, 8)),
        ("gather_from_sequence_parallel_region", (8, 16), (32, 16)),
    ])
    def test_wrapper_functions(self, func, input_shape, expected_shape,
                               mocker: MockerFixture):
        """Each public wrapper, looked up by name, yields the expected shape."""
        module = importlib.import_module(
            'vllm_ascend.distributed.tensor_parallel')
        # Look the wrapper up in the module namespace by its string name.
        wrapper = vars(module)[func]

        result = wrapper(torch.randn(*input_shape), mocker.MagicMock())

        assert result.shape == expected_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((8, 16), (32, 4)),  # [num_tokens/TP, H] -> [num_tokens, H/TP]
        ])
    def test_all_to_all_sp2hp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        """sp2hp turns an (8, 16) input into a (32, 4) output."""
        converted = all_to_all_sp2hp(torch.randn(*input_shape),
                                     mocker.MagicMock())
        assert converted.shape == output_shape

    @pytest.mark.parametrize(
        "input_shape,output_shape",
        [
            ((32, 4), (8, 16)),  # [num_tokens, H/TP] -> [num_tokens/TP, H]
        ])
    def test_all_to_all_hp2sp(self, input_shape, output_shape,
                              mocker: MockerFixture):
        """hp2sp turns a (32, 4) input into an (8, 16) output."""
        converted = all_to_all_hp2sp(torch.randn(*input_shape),
                                     mocker.MagicMock())
        assert converted.shape == output_shape
|
||||
44
tests/ut/distributed/test_parallel_state.py
Normal file
44
tests/ut/distributed/test_parallel_state.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from vllm.config import ParallelConfig
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import (
|
||||
_LMTP, _MC2, destroy_ascend_model_parallel, get_lmhead_tp_group,
|
||||
get_mc2_group, init_ascend_model_parallel)
|
||||
|
||||
|
||||
@pytest.fixture
def parallel_config():
    """Provide a ``ParallelConfig`` with DP, TP and PP sizes all set to 2."""
    sizes = {
        "data_parallel_size": 2,
        "tensor_parallel_size": 2,
        "pipeline_parallel_size": 2,
    }
    return ParallelConfig(**sizes)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_distributed():
    """Patch torch.distributed queries and the world group for one test.

    While the fixture is active: the process group reports as initialised,
    the world size is 8, the backend is 'nccl', and ``get_world_group``
    returns a mock with ``local_rank`` 0 and a mocked ``device_group``.
    """
    with patch('torch.distributed.is_initialized', return_value=True), \
         patch('torch.distributed.get_world_size', return_value=8), \
         patch('torch.distributed.get_backend', return_value='nccl'), \
         patch('vllm_ascend.distributed.parallel_state.get_world_group') as world_group_getter:
        fake_world = world_group_getter.return_value
        fake_world.local_rank = 0
        fake_world.device_group = MagicMock()
        yield
|
||||
|
||||
|
||||
def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    """Initialise the Ascend parallel groups, then verify they are torn down.

    Group creation is mocked, so this only checks that the MC2 and LM-head TP
    groups are populated by ``init_ascend_model_parallel`` and reset to None
    by ``destroy_ascend_model_parallel``.
    """
    # Imported locally to read the module's *current* globals after teardown.
    import vllm_ascend.distributed.parallel_state as parallel_state

    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
         patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
         patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
        init_ascend_model_parallel(parallel_config)
        try:
            mc2_group = get_mc2_group()
            assert mc2_group is not None
            lmheadtp_group = get_lmhead_tp_group()
            assert lmheadtp_group is not None
        finally:
            # Always tear down so a failed assertion cannot leave the
            # module-level groups populated for later tests.
            destroy_ascend_model_parallel()

        # ``_MC2`` / ``_LMTP`` imported at file level are snapshots bound at
        # import time, so they cannot show whether destroy reset anything.
        # The live module attributes are the meaningful check.
        assert parallel_state._MC2 is None
        assert parallel_state._LMTP is None
        # Original snapshot assertions, retained unchanged.
        assert _MC2 is None
        assert _LMTP is None
|
||||
28
tests/ut/fake_weight/config.json
Normal file
28
tests/ut/fake_weight/config.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"_name_or_path": "facebook/opt-125m",
|
||||
"activation_dropout": 0.0,
|
||||
"activation_function": "relu",
|
||||
"architectures": [
|
||||
"OPTForCausalLM"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 2,
|
||||
"do_layer_norm_before": true,
|
||||
"dropout": 0.1,
|
||||
"eos_token_id": 2,
|
||||
"ffn_dim": 3072,
|
||||
"hidden_size": 768,
|
||||
"init_std": 0.02,
|
||||
"layerdrop": 0.0,
|
||||
"max_position_embeddings": 2048,
|
||||
"model_type": "opt",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"prefix": "</s>",
|
||||
"torch_dtype": "float16",
|
||||
"transformers_version": "4.21.0.dev0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 50272,
|
||||
"word_embed_proj_dim": 768
|
||||
}
|
||||
96
tests/ut/kv_connector/test_llmdatadist_connector.py
Normal file
96
tests/ut/kv_connector/test_llmdatadist_connector.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
import os
|
||||
import types
|
||||
|
||||
from tests.ut.kv_connector.utils import (create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
from vllm_ascend.distributed.llmdatadist_c_mgr_connector import (
|
||||
LLMDataDistCMgrConnectorMetadata, LLMDataDistCMgrConnectorWorker, LLMRole)
|
||||
|
||||
|
||||
def test_basic_inferface():
    """Schedule one remote-prefill request and check the connector metadata.

    The metadata must be an ``LLMDataDistCMgrConnectorMetadata`` holding
    exactly this request, and its local block ids must match the blocks the
    KV cache manager recorded for the request.
    """
    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prompt length: two full blocks plus half a block.
    block_size = vllm_config.cache_config.block_size
    full_blocks = 2
    num_tokens = int(block_size * (full_blocks + 0.5))

    request = create_request(request_id=1,
                             num_tokens=num_tokens,
                             do_remote_prefill=True)
    request_id = request.request_id
    scheduler.add_request(request)

    # Scheduling a remote-prefill request produces connector metadata.
    metadata = scheduler.schedule().kv_connector_metadata
    assert metadata is not None
    assert isinstance(metadata, LLMDataDistCMgrConnectorMetadata)

    assert len(metadata.requests) == 1
    assert request_id in metadata.requests
    req_meta = metadata.requests[request_id]

    manager = scheduler.kv_cache_manager.coordinator.single_type_managers[0]
    allocated = manager.req_to_blocks[request_id]
    for block_id, block in zip(req_meta.local_block_ids, allocated):
        assert block_id == block.block_id
|
||||
|
||||
|
||||
def test_read_agent_metadata():
    """Check which device IP ``read_agent_metadata`` picks for a worker.

    The rank table lists two prefill devices on each of two servers. Each
    case varies the worker's local IP, its TP rank and the value of
    ``ASCEND_RT_VISIBLE_DEVICES``.
    """
    # Local import: only this test needs it.
    from unittest.mock import patch

    rank_table = {
        "version":
        "1.2",
        "server_count":
        "2",
        "prefill_device_list": [{
            "server_id": "192.168.1.1",
            "device_id": "0",
            "device_ip": "10.30.0.1",
            "cluster_id": "0",
        }, {
            "server_id": "192.168.1.1",
            "device_id": "1",
            "device_ip": "10.30.0.2",
            "cluster_id": "1",
        }, {
            "server_id": "192.168.1.2",
            "device_id": "0",
            "device_ip": "10.30.0.3",
            "cluster_id": "2",
        }, {
            "server_id": "192.168.1.2",
            "device_id": "1",
            "device_ip": "10.30.0.4",
            "cluster_id": "3",
        }]
    }

    def get_device_ip(worker_local_ip, worker_tp_rank, worker_visible_devices):
        """Return the device IP resolved for a minimal stand-in worker."""
        worker = types.SimpleNamespace(local_ip=worker_local_ip,
                                       tp_rank=worker_tp_rank,
                                       llm_datadist_role=LLMRole.PROMPT)
        # patch.dict restores the environment even if the call raises, and
        # removes the variable again when it was not set beforehand. The
        # manual save/restore left it behind as "" in that case.
        with patch.dict(
                os.environ,
            {"ASCEND_RT_VISIBLE_DEVICES": worker_visible_devices}):
            agent_metadata = LLMDataDistCMgrConnectorWorker.read_agent_metadata(
                worker, rank_table)
        return agent_metadata.device_ip

    # (local_ip, tp_rank, visible_devices, expected device_ip)
    cases = [
        ("192.168.1.1", 0, "0", "10.30.0.1"),
        ("192.168.1.1", 0, "1", "10.30.0.2"),
        ("192.168.1.2", 0, "0", "10.30.0.3"),
        ("192.168.1.2", 0, "1", "10.30.0.4"),
        ("192.168.1.1", 0, "0,1", "10.30.0.1"),
        ("192.168.1.1", 1, "0,1", "10.30.0.2"),
        ("192.168.1.1", 0, "", "10.30.0.1"),
        ("192.168.1.1", 1, "", "10.30.0.2"),
    ]
    for local_ip, tp_rank, visible_devices, expected_ip in cases:
        assert get_device_ip(local_ip, tp_rank,
                             visible_devices) == expected_ip
|
||||
998
tests/ut/kv_connector/test_mooncake_connector.py
Normal file
998
tests/ut/kv_connector/test_mooncake_connector.py
Normal file
@@ -0,0 +1,998 @@
|
||||
import os
|
||||
import queue
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
import zmq
|
||||
from vllm.utils import make_zmq_path
|
||||
|
||||
# Register a stand-in ``mooncake.engine`` module whose ``TransferEngine`` is a
# MagicMock. NOTE(review): presumably this lets the connector import below run
# without the real mooncake package installed - confirm.
fake_engine = types.ModuleType("mooncake.engine")
setattr(fake_engine, "TransferEngine", MagicMock())
sys.modules[fake_engine.__name__] = fake_engine
|
||||
|
||||
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
|
||||
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
|
||||
KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector,
|
||||
MooncakeConnectorMetadata, MooncakeConnectorScheduler,
|
||||
MooncakeConnectorWorker, ReqMeta, ensure_zmq_recv, ensure_zmq_send,
|
||||
group_concurrent_contiguous, string_to_int64_hash, zmq_ctx)
|
||||
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
DONE_RECVING_MSG = b"done_recving_msg"
|
||||
|
||||
|
||||
class TestKVCacheTaskTrackerInit(unittest.TestCase):
    """Checks the attributes exposed by a freshly built KVCacheTaskTracker."""

    def test_init_basic_properties(self):
        """A new tracker carries a lock, a finished set and a delay deque."""
        tracker = KVCacheTaskTracker()
        expected_types = (
            ("done_task_lock", type(threading.Lock())),
            ("finished_requests", set),
            ("delayed_free_requests", deque),
        )
        for attr_name, attr_type in expected_types:
            self.assertIsInstance(getattr(tracker, attr_name), attr_type)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
    """Exercises ``KVCacheTaskTracker.get_and_clear_finished_requests``."""

    def setUp(self):
        # Start every test from a tracker with a fresh finished set and lock.
        tracker = KVCacheTaskTracker()
        tracker.finished_requests = set()
        tracker.done_task_lock = threading.Lock()
        self.tracker = tracker

    def test_empty_requests(self):
        """Draining an empty tracker yields an empty set."""
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertEqual(drained, set())
        self.assertEqual(len(self.tracker.finished_requests), 0)

    def test_single_request(self):
        """A single finished request is returned and then forgotten."""
        self.tracker.finished_requests = {"req_123"}
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertEqual(drained, {"req_123"})
        self.assertEqual(len(self.tracker.finished_requests), 0)

    def test_multiple_requests(self):
        """All finished requests are returned together and then forgotten."""
        self.tracker.finished_requests = {"req_1", "req_2", "req_3"}
        drained = self.tracker.get_and_clear_finished_requests()
        self.assertSetEqual(drained, {"req_1", "req_2", "req_3"})
        self.assertEqual(len(self.tracker.finished_requests), 0)

    @patch("vllm_ascend.distributed.mooncake_connector.logger")
    def test_concurrent_access(self, mock_logger):
        """Only one of three concurrent drains may see the finished set."""
        from concurrent.futures import ThreadPoolExecutor

        self.tracker.finished_requests = {"req_1", "req_2"}
        drain = self.tracker.get_and_clear_finished_requests
        with ThreadPoolExecutor(max_workers=3) as pool:
            pending = [pool.submit(drain) for _ in range(3)]
            outcomes = [task.result() for task in pending]
        non_empty = [outcome for outcome in outcomes if outcome]
        self.assertEqual(len(non_empty), 1)
        self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
|
||||
|
||||
class TestKVCacheSendingThreadInit(unittest.TestCase):
    """Checks the thread attributes set by ``KVCacheSendingThread.__init__``."""

    def setUp(self):
        self.common_args = dict(
            tp_rank=1,
            decode_tp_size=4,
            local_engine_id='engine_1',
            side_channel_host='localhost',
            side_channel_port=5555,
            metadata=MagicMock(),
            ready_event=threading.Event(),
        )
        self.threads = []

    def tearDown(self):
        # Release whatever each constructed thread may still be holding.
        for thread in self.threads:
            has_socket = hasattr(thread, 'task_tracker') and hasattr(
                thread.task_tracker, 'socket')
            if has_socket:
                thread.task_tracker.socket.close()
            still_running = hasattr(thread, 'is_alive') and thread.is_alive()
            if still_running:
                thread.join(timeout=0.1)

    def _make_thread(self, **overrides):
        """Build a sending thread from the common args and track it."""
        thread = KVCacheSendingThread(**{**self.common_args, **overrides})
        self.threads.append(thread)
        return thread

    def test_thread_daemon_property(self):
        """The sending thread is created as a daemon thread."""
        self.assertTrue(self._make_thread().daemon)

    def test_thread_name_format(self):
        """The sending thread carries its fixed name."""
        self.assertEqual(self._make_thread().name, "KVCacheSendingThread")

    def test_ready_event_reference(self):
        """The thread keeps the very event object it was given."""
        custom_event = threading.Event()
        thread = self._make_thread(ready_event=custom_event)
        self.assertIs(thread.ready_event, custom_event)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedRequests(unittest.TestCase):
    """Checks that the sending thread forwards the drain call to its tracker."""

    def setUp(self):
        self.common_args = dict(
            tp_rank=1,
            decode_tp_size=4,
            local_engine_id='engine_1',
            side_channel_host='localhost',
            side_channel_port=5555,
            metadata={"test": "metadata"},
            ready_event=threading.Event(),
        )
        self.thread = KVCacheSendingThread(**self.common_args)

    def test_get_and_clear_finished_requests(self):
        """The thread returns whatever the patched tracker method returns."""
        expected_requests = {'req1', 'req2'}
        with patch.object(KVCacheTaskTracker,
                          'get_and_clear_finished_requests') as mock_get_clear:
            mock_get_clear.return_value = expected_requests
            result = self.thread.get_and_clear_finished_requests()
            mock_get_clear.assert_called_once()
        self.assertEqual(result, expected_requests)
|
||||
|
||||
|
||||
class TestKVCacheSendingThread(unittest.TestCase):
    """Round-trip check of the sending thread's side-channel messages."""

    def test_run_handles_get_meta_and_done_recv_msgs(self):
        """Send a metadata query and a done-receiving notice to a live thread.

        The thread is started on a free local port. A DEALER client then
        checks the decoded metadata reply, the ``ACK`` reply, and that the
        request id was recorded as finished.
        """
        ready_event = threading.Event()
        metadata = MooncakeAgentMetadata(
            engine_id="engine1",
            te_rpc_port=9090,
            kv_caches_base_addr=[12345678],
            num_blocks=2,
        )
        host = "127.0.0.1"
        # Ask the OS for an unused port; the probe socket is closed again
        # before the server thread is created.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(('', 0))
            free_port = s.getsockname()[1]

        thread = KVCacheSendingThread(
            tp_rank=0,
            decode_tp_size=1,
            local_engine_id="engine1",
            side_channel_host=host,
            side_channel_port=free_port,
            metadata=metadata,
            ready_event=ready_event,
        )
        thread.start()
        self.assertTrue(ready_event.wait(timeout=3),
                        "Server thread startup timeout")

        context = zmq.Context()  # type: ignore
        # Cleanups run in reverse order: the socket is closed first, then the
        # context is terminated, even when an assertion below fails.
        self.addCleanup(context.term)
        sock = context.socket(zmq.DEALER)  # type: ignore
        self.addCleanup(sock.close, linger=0)
        # Bound the wait for replies so a silent server fails the test
        # instead of hanging it.
        sock.setsockopt(zmq.RCVTIMEO, 3000)  # type: ignore
        sock.connect(f"tcp://{host}:{free_port}")
        encoder = msgspec.msgpack.Encoder()
        decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata)

        sock.send_multipart([b"", encoder.encode((GET_META_MSG, ))])
        frames = sock.recv_multipart()
        self.assertEqual(frames[0], b"")
        meta = decoder.decode(frames[1])
        self.assertEqual(meta.engine_id, "engine1")
        self.assertEqual(meta.kv_caches_base_addr, [12345678])
        self.assertEqual(meta.num_blocks, 2)

        req_id = "request_42"
        sock.send_multipart(
            [b"", encoder.encode((DONE_RECVING_MSG, req_id, 0))])
        frames = sock.recv_multipart()
        self.assertEqual(frames[0], b"")
        self.assertEqual(frames[1], b"ACK")
        self.assertIn(req_id, thread.task_tracker.finished_requests)
|
||||
|
||||
|
||||
class TestKVCacheRecvingThreadBasic(unittest.TestCase):
    """Basic queueing and drain behaviour of ``KVCacheRecvingThread``."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        thread_kwargs = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**thread_kwargs)

    def test_add_request(self):
        """A request added to the thread shows up on its request queue."""
        self.thread.add_request(
            request_id="req1",
            local_block_ids=[1, 2],
            remote_block_ids=[3, 4],
            remote_engine_id="remote_engine",
            remote_host="localhost",
            remote_handshake_port=6666,
        )
        queued = self.thread.request_queue.get_nowait()
        self.assertEqual(queued["request_id"], "req1")
        self.assertEqual(queued["remote_host"], "localhost")

    @patch.object(KVCacheTaskTracker, 'get_and_clear_finished_requests')
    def test_get_finished_requests(self, mock_tracker):
        """The thread returns the patched tracker's finished set."""
        mock_tracker.return_value = {"req1", "req2"}
        finished = self.thread.get_and_clear_finished_requests()
        self.assertEqual(finished, {"req1", "req2"})
|
||||
|
||||
|
||||
class TestSocketManagement(unittest.TestCase):
    """Socket pooling helpers of ``KVCacheRecvingThread``."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        thread_kwargs = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**thread_kwargs)
        # Replace the pool and poller so no real sockets are involved.
        self.thread.remote_sockets = defaultdict(deque)
        self.thread.remote_poller = MagicMock()

    @patch('vllm_ascend.distributed.mooncake_connector.zmq.Context')
    @patch('vllm_ascend.distributed.mooncake_connector.make_zmq_socket')
    def test_get_remote_socket(self, mock_make_socket, mock_context):
        """With an empty pool a REQ socket is created and registered."""
        created_sock = MagicMock()
        mock_make_socket.return_value = created_sock

        sock = self.thread._get_remote_socket("test_host", 12345)

        self.assertEqual(sock, created_sock)
        mock_make_socket.assert_called_once()
        call_kwargs = mock_make_socket.call_args.kwargs
        self.assertEqual(call_kwargs.get('path'), 'tcp://test_host:12345')
        self.assertEqual(call_kwargs.get('socket_type'),
                         zmq.REQ)  # type: ignore
        self.assertFalse(call_kwargs.get('bind', True))
        self.thread.remote_poller.register.assert_called_with(
            created_sock, zmq.POLLIN)  # type: ignore

    def test_return_socket_to_pool(self):
        """A returned socket is pooled under its path, not re-registered."""
        returned_sock = MagicMock()
        host, port = "test_host", 12345
        pool_key = make_zmq_path("tcp", host, port)

        self.thread._return_remote_socket(returned_sock, host, port)

        pooled = self.thread.remote_sockets[pool_key]
        self.assertEqual(len(pooled), 1)
        self.assertEqual(pooled[0], returned_sock)
        self.thread.remote_poller.register.assert_not_called()
|
||||
|
||||
|
||||
class TestCoreFunctionality(unittest.TestCase):
    """Request handling and KV transfer of ``KVCacheRecvingThread``."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        self.mock_queue = MagicMock()
        thread_kwargs = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**thread_kwargs)
        self.thread.request_queue = self.mock_queue
        self.test_req = dict(
            request_id="req1",
            local_block_ids=[1, 2],
            remote_block_ids=[3, 4],
            remote_engine_id="remote_engine",
            remote_host="localhost",
            remote_handshake_port=6666,
            remote_transfer_port=7777,
        )
        self.thread.task_tracker = MagicMock()
        # A zero return value is the non-failing case for these tests.
        self.engine.batch_transfer_sync_read.return_value = 0
        self.thread.remote_te_port = {"remote_engine": {6666: 7777}}

    def _register_remote_base_addrs(self):
        """Pre-populate the remote base addresses for the test request."""
        self.thread.kv_caches_base_addr["remote_engine"] = {
            6666: [0x3000, 0x4000]
        }

    @patch.object(KVCacheRecvingThread, '_transfer_kv_cache')
    @patch.object(KVCacheRecvingThread, '_send_done_recv_signal')
    def test_handle_request(self, mock_send, mock_transfer):
        """Handling a request transfers, signals, records and dequeues it."""
        self.thread._handle_request(self.test_req)

        mock_transfer.assert_called_once_with(self.test_req)
        mock_send.assert_called_once_with("req1", "localhost", 6666)
        done_counter = self.thread.task_tracker.update_done_task_count
        done_counter.assert_called_once_with("req1")
        self.mock_queue.task_done.assert_called_once()

    @patch.object(KVCacheRecvingThread, '_get_remote_metadata')
    def test_transfer_kv_cache(self, mock_get_meta):
        """With known remote addresses one batched read is issued."""
        self._register_remote_base_addrs()

        self.thread._transfer_kv_cache(self.test_req)

        sync_read = self.engine.batch_transfer_sync_read
        sync_read.assert_called_once()
        positional = sync_read.call_args.args
        self.assertEqual(positional[0], "localhost:7777")
        for index in (1, 2, 3):
            self.assertIsInstance(positional[index], list)
        for index in (2, 3):
            self.assertEqual(len(positional[1]), len(positional[index]))
        mock_get_meta.assert_not_called()

    def test_transfer_kv_cache_failure(self):
        """A negative engine return value surfaces as ``RuntimeError``."""
        self.engine.batch_transfer_sync_read.return_value = -1
        self._register_remote_base_addrs()

        with self.assertRaises(RuntimeError):
            self.thread._transfer_kv_cache(self.test_req)
|
||||
|
||||
|
||||
class TestMetadataHandling(unittest.TestCase):
    """Remote metadata lookup of ``KVCacheRecvingThread``."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        thread_kwargs = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**thread_kwargs)
        self.test_metadata = MooncakeAgentMetadata(
            engine_id="remote_engine",
            te_rpc_port=9090,
            kv_caches_base_addr=[0x3000, 0x4000],
            num_blocks=2,
        )

    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv')
    def test_get_remote_metadata_success(self, mock_recv, mock_send):
        """A decoded reply is stored under the remote engine and port."""
        mock_recv.return_value = msgspec.msgpack.encode(self.test_metadata)

        with patch.object(self.thread,
                          '_get_remote_socket') as mock_get_socket:
            with patch.object(self.thread,
                              '_return_remote_socket') as mock_return_socket:
                mock_socket = MagicMock()
                mock_get_socket.return_value = mock_socket

                self.thread._get_remote_metadata("host1", 5555)

                mock_get_socket.assert_called_once_with("host1", 5555)
                mock_return_socket.assert_called_once_with(
                    mock_socket, "host1", 5555)
                expected_query = self.thread.encoder.encode(
                    (GET_META_MSG, ""))
                mock_send.assert_called_once_with(mock_socket, expected_query)
                mock_recv.assert_called_once_with(mock_socket,
                                                  self.thread.remote_poller)
                stored = self.thread.kv_caches_base_addr["remote_engine"]
                self.assertEqual(stored[5555], [0x3000, 0x4000])

    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_send')
    @patch('vllm_ascend.distributed.mooncake_connector.ensure_zmq_recv',
           side_effect=Exception("Network error"))
    def test_get_remote_metadata_failure(self, mock_recv, mock_send):
        """A receive error propagates and the socket is still returned."""
        with patch.object(self.thread,
                          '_get_remote_socket') as mock_get_socket:
            with patch.object(self.thread,
                              '_return_remote_socket') as mock_return_socket:
                mock_socket = MagicMock()
                mock_get_socket.return_value = mock_socket

                with self.assertRaises(Exception) as context:
                    self.thread._get_remote_metadata("host1", 5555)

                self.assertEqual(str(context.exception), "Network error")
                mock_return_socket.assert_called_once()
|
||||
|
||||
|
||||
class TestMainThreadLoop(unittest.TestCase):
    """Runs the receiving thread's loop against a real queue."""

    def setUp(self):
        self.engine = MagicMock()
        self.ready_event = threading.Event()
        thread_kwargs = dict(
            tp_rank=0,
            tp_size=4,
            engine=self.engine,
            local_engine_id="local_engine",
            local_handshake_port=5555,
            local_kv_caches_base_addr=[0x1000, 0x2000],
            block_len=[1024, 2048],
            ready_event=self.ready_event,
        )
        self.thread = KVCacheRecvingThread(**thread_kwargs)
        self.thread.request_queue = queue.Queue()

    @patch.object(KVCacheRecvingThread, '_handle_request')
    def test_run_loop_normal(self, mock_handle):
        """One real request followed by ``None`` drains the queue."""
        test_request = dict(
            request_id="req1",
            local_block_ids=[1, 2],
            remote_block_ids=[3, 4],
            remote_engine_id="remote_engine",
            remote_host="localhost",
            remote_handshake_port=6666,
            remote_transfer_port=7777,
        )

        for item in (test_request, None):
            self.thread.request_queue.put(item)

        self.thread.start()
        time.sleep(0.1)
        self.thread.join(timeout=1.0)

        self.assertTrue(self.thread.ready_event.is_set())
        mock_handle.assert_called_once_with(test_request)
        self.assertTrue(self.thread.request_queue.empty())
|
||||
|
||||
|
||||
class MockVllmConfig:
    """Stand-in config object whose sections are pre-configured MagicMocks."""

    def __init__(self):
        model_config = MagicMock()
        model_config.use_mla = True

        parallel_config = MagicMock()
        parallel_config.tensor_parallel_size = 2
        parallel_config.data_parallel_rank_local = 0
        parallel_config.data_parallel_size_local = 1

        cache_config = MagicMock()
        cache_config.block_size = 16

        def lookup_extra_config(k, d):
            # Build the table on every call so callers never share a dict.
            extra_config = {
                "prefill": {
                    "tp_size": 2,
                    "dp_size": 1
                },
                "decode": {
                    "tp_size": 2,
                    "dp_size": 1
                }
            }
            return extra_config.get(k, d)

        kv_transfer_config = MagicMock()
        kv_transfer_config.kv_port = 5000
        kv_transfer_config.kv_role = 'kv_producer'
        kv_transfer_config.get_from_extra_config = MagicMock(
            side_effect=lookup_extra_config)

        self.model_config = model_config
        self.parallel_config = parallel_config
        self.cache_config = cache_config
        self.kv_transfer_config = kv_transfer_config
|
||||
|
||||
|
||||
class MockRequest:
    """Lightweight request object with defaults for every optional field."""

    def __init__(self,
                 request_id,
                 prompt_token_ids=None,
                 kv_transfer_params=None,
                 status=None):
        self.request_id = request_id
        # Any falsy argument (None, empty list, empty dict, "") falls back
        # to the default value.
        self.prompt_token_ids = (prompt_token_ids
                                 if prompt_token_ids else [1, 2, 3, 4])
        self.kv_transfer_params = (kv_transfer_params
                                   if kv_transfer_params else {})
        self.status = status if status else "running"
        self.output_token_ids = [101, 102]
|
||||
|
||||
|
||||
class TestKVCacheTaskTracker(unittest.TestCase):
    """Bookkeeping of finished and delayed requests in the task tracker."""

    def setUp(self):
        self.tracker = KVCacheTaskTracker()

    def test_update_done_task_count(self):
        """Marking a delayed request done moves it to the finished set."""
        tracker = self.tracker
        self.assertEqual(len(tracker.finished_requests), 0)
        self.assertEqual(len(tracker.delayed_free_requests), 0)

        now = time.time()
        tracker.add_delayed_request("req_1", now)
        delayed = tracker.delayed_free_requests
        self.assertEqual(len(delayed), 1)
        self.assertEqual(delayed[0], ("req_1", now))

        tracker.update_done_task_count("req_1")
        self.assertEqual(tracker.finished_requests, {"req_1"})
        self.assertEqual(len(tracker.delayed_free_requests), 0)

    def test_retrieve_expired_requests(self):
        """Only the request delayed 600 seconds ago is reported expired."""
        now = time.time()
        self.tracker.add_delayed_request("req_1", now - 600)
        self.tracker.add_delayed_request("req_2", now)

        expired = self.tracker._retrieve_expired_requests()

        self.assertEqual(expired, {"req_1"})
        remaining = self.tracker.delayed_free_requests
        self.assertEqual(len(remaining), 1)
        self.assertEqual(remaining[0], ("req_2", now))

    def test_duplicate_task_update(self):
        """Repeated done-updates for one request yield a single entry."""
        for _ in range(3):
            self.tracker.update_done_task_count("req1")

        finished = self.tracker.get_and_clear_finished_requests()
        self.assertEqual(finished, {"req1"})
|
||||
|
||||
|
||||
class TestMooncakeConnectorMetadata(unittest.TestCase):
    """Checks how ``MooncakeConnectorMetadata`` records a new request."""

    def test_add_new_req(self):
        """``add_new_req`` stores a ``ReqMeta`` built from the parameters."""
        meta = MooncakeConnectorMetadata()
        self.assertEqual(len(meta.requests), 0)
        self.assertEqual(len(meta.requests_to_send), 0)

        transfer_params = {
            "remote_block_ids": [4, 5, 6],
            "remote_engine_id": "remote_engine",
            "remote_host": "localhost",
            "remote_port": 5000
        }
        meta.add_new_req(request_id="req1",
                         local_block_ids=[1, 2, 3],
                         kv_transfer_params=transfer_params)

        self.assertEqual(len(meta.requests), 1)
        req_meta = meta.requests["req1"]
        self.assertIsInstance(req_meta, ReqMeta)
        expected_fields = (
            ("local_block_ids", [1, 2, 3]),
            ("remote_block_ids", [4, 5, 6]),
            ("remote_engine_id", "remote_engine"),
            ("remote_host", "localhost"),
            ("remote_port", 5000),
        )
        for field_name, expected in expected_fields:
            self.assertEqual(getattr(req_meta, field_name), expected)
|
||||
|
||||
|
||||
class TestMooncakeConnectorSchedulerMatchedTokens(unittest.TestCase):
    """Scheduler checks: matched tokens, metadata building, finished count."""

    def setUp(self):
        config = MockVllmConfig()
        self.scheduler = MooncakeConnectorScheduler(config, "test_engine")

    def test_get_num_new_matched_tokens(self):
        """Only a request flagged for remote prefill reports matched tokens."""
        request = MockRequest("req1")
        tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
            request, 0)
        self.assertEqual(tokens, 0)
        self.assertFalse(async_flag)

        request.kv_transfer_params = {"do_remote_prefill": True}
        tokens, async_flag = self.scheduler.get_num_new_matched_tokens(
            request, 0)
        self.assertEqual(tokens, 3)
        self.assertTrue(async_flag)

    def test_build_connector_meta(self):
        """Pending receive requests are moved into the built metadata."""
        request = MockRequest("req1")
        # Queue the request directly, paired with its local block ids.
        self.scheduler._reqs_need_recv["req1"] = (request, [4, 5, 6])
        request.kv_transfer_params = {
            "remote_block_ids": [1, 2, 3],
            "remote_engine_id": "remote",
            "remote_host": "localhost",
            "remote_port": 5000
        }

        meta = self.scheduler.build_connector_meta(MagicMock())
        self.assertIsInstance(meta, MooncakeConnectorMetadata)
        self.assertEqual(len(meta.requests), 1)
        self.assertEqual(meta.requests["req1"].local_block_ids, [4, 5, 6])
        self.assertEqual(meta.requests["req1"].remote_block_ids, [1, 2, 3])
        # Building the metadata must drain the pending-receive map.
        self.assertEqual(len(self.scheduler._reqs_need_recv), 0)

    def test_get_finished_count(self):
        """The finished count for the mock configuration is 2."""
        count = self.scheduler.get_finished_count()
        self.assertEqual(count, 2)
|
||||
|
||||
|
||||
class TestHelperFunctions(unittest.TestCase):
    """Tests for block grouping and string hashing helpers."""

    def test_group_concurrent_contiguous(self):
        """Paired id lists are split where the sequence breaks."""
        src_groups, dst_groups = group_concurrent_contiguous(
            [1, 2, 3, 5, 6], [10, 11, 12, 14, 15])

        self.assertEqual(len(src_groups), 2)
        first_src, second_src = src_groups[0], src_groups[1]
        self.assertEqual(first_src, [1, 2, 3])
        self.assertEqual(second_src, [5, 6])
        self.assertEqual(dst_groups[0], [10, 11, 12])
        self.assertEqual(dst_groups[1], [14, 15])

    def test_group_concurrent_contiguous_empty(self):
        """Empty inputs produce empty group lists."""
        no_src: list[int] = []
        no_dst: list[int] = []
        src_groups, dst_groups = group_concurrent_contiguous(no_src, no_dst)
        self.assertEqual(src_groups, [])
        self.assertEqual(dst_groups, [])

    def test_string_to_int64_hash(self):
        """Equal strings hash equally; different strings hash differently."""
        baseline = string_to_int64_hash("test_string")
        repeated = string_to_int64_hash("test_string")
        self.assertEqual(baseline, repeated)

        different = string_to_int64_hash("different_string")
        self.assertNotEqual(baseline, different)
|
||||
|
||||
|
||||
class TestMooncakeConnectorForScheduler(unittest.TestCase):
    """``MooncakeConnector`` built with the scheduler role."""

    @staticmethod
    def _scheduler_connector():
        """Build a connector in scheduler role from the mock config."""
        return MooncakeConnector(MockVllmConfig(), KVConnectorRole.SCHEDULER)

    def test_scheduler_role(self):
        """Scheduler role creates a scheduler helper and no worker helper."""
        connector = self._scheduler_connector()
        self.assertIsNotNone(connector.connector_scheduler)
        self.assertIsNone(connector.connector_worker)

    @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
    def test_scheduler_methods(self, mock_method):
        """The connector forwards the matched-token query unchanged."""
        connector = self._scheduler_connector()
        request = MockRequest("req1")
        connector.get_num_new_matched_tokens(request, 0)
        mock_method.assert_called_once_with(request, 0)
|
||||
|
||||
|
||||
class MockKVCacheBlocks:
    """Stand-in block container that always reports the same block ids."""

    def get_unhashed_block_ids(self):
        """Return a fresh list holding the fixed ids 4, 5 and 6."""
        return list((4, 5, 6))
|
||||
|
||||
|
||||
class MockSchedulerOutput:
    """Empty placeholder passed where a scheduler output is expected."""
|
||||
|
||||
|
||||
class MockForwardContext:
    """Empty placeholder passed where a forward context is expected."""
|
||||
|
||||
|
||||
class TestMooncakeConnector(unittest.TestCase):
    """Checks that a scheduler-role connector delegates to its scheduler."""

    def setUp(self):
        self.config = MockVllmConfig()
        # Override the visible-device list for this test only. patch.dict
        # puts the previous environment back on cleanup, so the value does
        # not leak into tests that run afterwards.
        env_patcher = patch.dict(os.environ,
                                 {"ASCEND_RT_VISIBLE_DEVICES": "0,1"})
        env_patcher.start()
        self.addCleanup(env_patcher.stop)

    def test_scheduler_initialization(self):
        """Scheduler role creates a scheduler helper and no worker helper."""
        connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
        self.assertIsNotNone(connector.connector_scheduler)
        self.assertIsNone(connector.connector_worker)

    @patch.object(MooncakeConnectorScheduler, "get_num_new_matched_tokens")
    def test_get_num_new_matched_tokens(self, mock_method):
        """The matched-token query is forwarded with the same arguments."""
        connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
        request = MockRequest("req1")
        connector.get_num_new_matched_tokens(request, 0)
        mock_method.assert_called_once_with(request, 0)

    @patch.object(MooncakeConnectorScheduler, "update_state_after_alloc")
    def test_update_state_after_alloc(self, mock_method):
        """The post-allocation update is forwarded with the same arguments."""
        connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
        request = MockRequest("req1")
        blocks = MockKVCacheBlocks()
        connector.update_state_after_alloc(request, blocks, 3)
        mock_method.assert_called_once_with(request, blocks, 3)

    @patch.object(MooncakeConnectorScheduler, "build_connector_meta")
    def test_build_connector_meta(self, mock_method):
        """Metadata building is forwarded with the same scheduler output."""
        connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
        scheduler_output = MockSchedulerOutput()
        connector.build_connector_meta(scheduler_output)
        mock_method.assert_called_once_with(scheduler_output)

    @patch.object(MooncakeConnectorScheduler, "request_finished")
    def test_request_finished(self, mock_method):
        """The finished notification is forwarded with the same arguments."""
        connector = MooncakeConnector(self.config, KVConnectorRole.SCHEDULER)
        request = MockRequest("req1")
        connector.request_finished(request, [1, 2, 3])
        mock_method.assert_called_once_with(request, [1, 2, 3])
|
||||
|
||||
|
||||
class TestMooncakeConnectorScheduler(unittest.TestCase):
    """Direct tests of ``MooncakeConnectorScheduler`` state handling."""

    def setUp(self):
        self.config = MockVllmConfig()
        self.scheduler = MooncakeConnectorScheduler(self.config, "test_engine")

    def test_get_num_new_matched_tokens_no_remote_prefill(self):
        """Without remote prefill nothing is matched and nothing is async."""
        matched, is_async = self.scheduler.get_num_new_matched_tokens(
            MockRequest("req1"), 0)
        self.assertEqual(matched, 0)
        self.assertFalse(is_async)

    def test_get_num_new_matched_tokens_with_remote_prefill(self):
        """With remote prefill three tokens are matched asynchronously."""
        transfer_params = {"do_remote_prefill": True}
        request = MockRequest("req1", kv_transfer_params=transfer_params)
        matched, is_async = self.scheduler.get_num_new_matched_tokens(
            request, 0)
        self.assertEqual(matched, 3)
        self.assertTrue(is_async)

    def test_update_state_after_alloc_no_remote_prefill(self):
        """Without remote prefill no request is queued for receiving."""
        self.scheduler.update_state_after_alloc(MockRequest("req1"),
                                                MagicMock(), 0)
        self.assertEqual(len(self.scheduler._reqs_need_recv), 0)

    def test_update_state_after_alloc_with_remote_prefill(self):
        """With remote prefill the request and block ids are queued."""
        transfer_params = {
            "do_remote_prefill": True,
            "remote_block_ids": [1, 2, 3],
            "remote_engine_id": "remote",
            "remote_host": "localhost",
            "remote_port": 5000
        }
        request = MockRequest("req1", kv_transfer_params=transfer_params)
        self.scheduler.update_state_after_alloc(request, MockKVCacheBlocks(),
                                                3)
        pending = self.scheduler._reqs_need_recv
        self.assertEqual(len(pending), 1)
        self.assertEqual(pending["req1"][0], request)
        self.assertEqual(pending["req1"][1], [4, 5, 6])

    def test_request_finished_no_remote_decode(self):
        """Without remote decode there is no delayed free and no params."""
        delay_free, params = self.scheduler.request_finished(
            MockRequest("req1"), [1, 2, 3])
        self.assertFalse(delay_free)
        self.assertIsNone(params)
|
||||
|
||||
|
||||
class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_string_to_int64_hash(self):
|
||||
h1 = string_to_int64_hash("hello")
|
||||
h2 = string_to_int64_hash("hello")
|
||||
h3 = string_to_int64_hash("world")
|
||||
self.assertEqual(h1, h2)
|
||||
self.assertNotEqual(h1, h3)
|
||||
self.assertIsInstance(h1, int)
|
||||
|
||||
def test_group_concurrent_contiguous(self):
|
||||
src: list[int] = [1, 2, 3, 5, 6]
|
||||
dst: list[int] = [10, 11, 12, 20, 21]
|
||||
src_g, dst_g = group_concurrent_contiguous(src, dst)
|
||||
self.assertEqual(src_g, [[1, 2, 3], [5, 6]])
|
||||
self.assertEqual(dst_g, [[10, 11, 12], [20, 21]])
|
||||
|
||||
def test_group_empty(self):
|
||||
src_g, dst_g = group_concurrent_contiguous([], [])
|
||||
self.assertEqual(src_g, [])
|
||||
self.assertEqual(dst_g, [])
|
||||
|
||||
def test_zmq_ctx_invalid_type(self):
|
||||
with self.assertRaises(ValueError):
|
||||
with zmq_ctx("INVALID", "tcp://127.0.0.1:5555"):
|
||||
pass
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.make_zmq_socket")
|
||||
def test_zmq_ctx_ok(self, mock_make_socket):
|
||||
mock_socket = MagicMock()
|
||||
mock_make_socket.return_value = mock_socket
|
||||
with zmq_ctx(zmq.REQ, "tcp://localhost:1234") as s: # type: ignore
|
||||
self.assertEqual(s, mock_socket)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_ensure_zmq_send_success(self, mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
ensure_zmq_send(mock_socket, b"hello")
|
||||
mock_socket.send.assert_called_once_with(b"hello")
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
|
||||
def test_ensure_zmq_send_retry_and_fail(self, mock_logger):
|
||||
mock_socket = MagicMock()
|
||||
mock_socket.send.side_effect = zmq.ZMQError( # type: ignore
|
||||
"send failed")
|
||||
with self.assertRaises(RuntimeError):
|
||||
ensure_zmq_send(mock_socket, b"hello", max_retries=2)
|
||||
self.assertEqual(mock_socket.send.call_count, 2)
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_recv_success(self, mock_logger):
    """Check that data is returned when the poller reports the socket readable."""
    ready_socket = MagicMock()
    ready_socket.recv.return_value = b"response"
    poller = MagicMock()
    readable_event = (ready_socket, zmq.POLLIN)  # type: ignore
    poller.poll.return_value = [readable_event]

    received = ensure_zmq_recv(ready_socket, poller)

    self.assertEqual(received, b"response")
|
||||
|
||||
@patch("vllm_ascend.distributed.mooncake_connector.logger")
def test_ensure_zmq_recv_timeout_and_fail(self, mock_logger):
    """Check that RuntimeError is raised when the poller never reports data."""
    silent_socket = MagicMock()
    idle_poller = MagicMock()
    # An empty poll result means no socket became readable.
    idle_poller.poll.return_value = []
    recv_options = {"timeout": 0.01, "max_retries": 2}
    with self.assertRaises(RuntimeError):
        ensure_zmq_recv(silent_socket, idle_poller, **recv_options)
|
||||
|
||||
|
||||
class MockMooncakeAgentMetadata:
    """Inert test double that accepts and ignores any keyword arguments."""

    def __init__(self, **kwargs):
        # Nothing is stored; the keywords are deliberately discarded.
        return None
|
||||
|
||||
|
||||
class MockMooncakeConnectorMetadata:
    """Test double exposing an initially empty ``requests`` mapping."""

    def __init__(self):
        # A fresh dict per instance so tests never share state.
        self.requests = dict()
|
||||
|
||||
|
||||
class MockKVCacheSendingThread(threading.Thread):
    """Thread-shaped test double for the KV-cache sending thread.

    Constructor arguments are accepted and ignored, and ``start`` is a
    no-op, so no OS thread is ever launched.
    """

    def __init__(self, *args, **kwargs):
        # Passing ``daemon=True`` is equivalent to assigning ``self.daemon``
        # after construction.
        super().__init__(daemon=True)
        self._finished_requests = set()

    def get_and_clear_finished_requests(self):
        """Return the finished-request set (this mock does not clear it)."""
        finished = self._finished_requests
        return finished

    def start(self):
        """Do nothing instead of launching a thread."""
        return None
|
||||
|
||||
|
||||
class MockKVCacheRecvingThread(threading.Thread):
    """Thread-shaped test double for the KV-cache receiving thread.

    Constructor arguments are accepted and ignored, ``add_request`` is a
    ``MagicMock`` so calls can be inspected, and ``start`` is a no-op.
    """

    def __init__(self, *args, **kwargs):
        # Passing ``daemon=True`` is equivalent to assigning ``self.daemon``
        # after construction.
        super().__init__(daemon=True)
        self._finished_requests = set()
        self.add_request = MagicMock()

    def get_and_clear_finished_requests(self):
        """Return the finished-request set (this mock does not clear it)."""
        finished = self._finished_requests
        return finished

    def start(self):
        """Do nothing instead of launching a thread."""
        return None
|
||||
|
||||
|
||||
class MockTensor:
    """Tensor stand-in exposing fixed size, element size, shape and pointer."""

    def __init__(self, *args, **kwargs):
        dims = (10, 16, 8, 16)
        # ``size()`` and ``shape`` report the same dimensions.
        self.size = MagicMock(return_value=dims)
        self.element_size = MagicMock(return_value=4)
        self.shape = dims
        self.data_ptr = MagicMock(return_value=0x1000)
|
||||
|
||||
|
||||
# Module-level MagicMock carrying a MOONCAKE_CONNECTOR_PROTOCOL attribute.
# NOTE(review): presumably a stand-in for ``vllm_ascend.envs``; no use of it
# is visible in this part of the file - confirm it is still needed.
mock_envs_ascend = MagicMock()
mock_envs_ascend.MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"

# Module-level MagicMock logger.
# NOTE(review): the tests in view patch the connector's logger themselves -
# confirm this object is referenced elsewhere.
mock_logger = MagicMock()
|
||||
|
||||
|
||||
class MockTransferEngine:
    """Transfer-engine stand-in whose methods return fixed integer codes."""

    def initialize(self, *args, **kwargs):
        """Accept any arguments and return 0."""
        status = 0
        return status

    def register_memory(self, *args, **kwargs):
        """Accept any arguments and return 1."""
        status = 1
        return status
|
||||
|
||||
|
||||
class MockEnvsAscend:
    """Attribute bag injected in place of the ``vllm_ascend.envs`` module."""

    # Protocol name handed to the (mocked) transfer engine.
    MOONCAKE_CONNECTOR_PROTOCOL = "mock_protocol"
    # Comma-separated device list; the tests expect the first entry (10)
    # to be selected for tp_rank 0.
    PHYSICAL_DEVICES = "10,11"
|
||||
|
||||
|
||||
def mock_get_tensor_model_parallel_rank():
    """Return a fixed tensor-parallel rank of 0."""
    rank = 0
    return rank
|
||||
|
||||
|
||||
def mock_get_tp_group():
    """Return a fresh ``MagicMock`` standing in for the TP group."""
    tp_group = MagicMock()
    return tp_group
|
||||
|
||||
|
||||
def mock_get_ip():
    """Return the loopback address used as this host's IP in the tests."""
    loopback = "127.0.0.1"
    return loopback
|
||||
|
||||
|
||||
def mock_string_to_int64_hash(s):
    """Return Python's built-in ``hash`` of *s* as a stand-in hash value."""
    digest = hash(s)
    return digest
|
||||
|
||||
|
||||
class TestMooncakeConnectorWorker(unittest.TestCase):
    """Unit tests for ``MooncakeConnectorWorker`` with its collaborators mocked."""

    def setUp(self):
        """Start every patch the worker needs and build the shared fixtures.

        Each patch registers its ``stop`` through ``addCleanup`` as soon as
        it is started. unittest runs cleanups even when ``setUp`` itself
        raises, whereas ``tearDown`` is skipped in that case, so global
        patches such as ``os.getenv`` or ``math.prod`` can no longer leak
        into other tests.
        """
        self.envs_ascend_mock = MockEnvsAscend()
        self.mock_transfer_engine = MagicMock()
        self.mock_transfer_engine.get_rpc_port.return_value = 9090
        self.mock_transfer_engine.initialize.return_value = 0
        self.mock_transfer_engine.register_memory.return_value = 0

        self.patches = [
            patch('os.getenv', return_value="10,11"),
            patch('torch.Tensor.size', return_value=(10, 16, 8, 16)),
            patch('torch.Tensor.element_size', return_value=4),
            patch('torch.Tensor.data_ptr', return_value=0x1000),
            patch('math.prod', return_value=128),
            patch('random.Random'),
            patch(
                'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
                mock_get_tensor_model_parallel_rank),
            patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
                  mock_get_tp_group),
            patch('vllm_ascend.distributed.mooncake_connector.get_ip',
                  mock_get_ip),
            patch(
                'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
                mock_string_to_int64_hash),
            patch('vllm_ascend.distributed.mooncake_connector.TransferEngine',
                  return_value=self.mock_transfer_engine),
            patch(
                'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',
                MagicMock()),
            patch(
                'vllm_ascend.distributed.mooncake_connector.KVCacheRecvingThread',
                MagicMock()),
            patch('vllm_ascend.distributed.mooncake_connector.logger',
                  MagicMock()),
            patch('vllm_ascend.distributed.mooncake_connector.threading.Event',
                  MagicMock()),
            patch.dict('sys.modules',
                       {'vllm_ascend.envs': self.envs_ascend_mock}),
        ]

        for p in self.patches:
            p.start()  # type: ignore
            # Cleanups run in LIFO order, so patches are undone in reverse
            # order of activation.
            self.addCleanup(p.stop)  # type: ignore

        self.vllm_config = MockVllmConfig()
        self.engine_id = "test_engine"
        self.kv_caches = {"layer1": (MagicMock(), MagicMock())}

    def test_register_kv_caches_producer(self):
        """With the default config only the sending thread is created."""
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        worker.register_kv_caches(self.kv_caches)
        self.assertEqual(len(worker.kv_caches), 1)
        self.assertIsNotNone(worker.kv_send_thread)
        self.assertIsNone(worker.kv_recv_thread)

    def test_register_kv_caches_consumer(self):
        """With kv_role 'kv_consumer' only the receiving thread is created."""
        self.vllm_config.kv_transfer_config.kv_role = 'kv_consumer'
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        worker.register_kv_caches(self.kv_caches)
        self.assertIsNone(worker.kv_send_thread)
        self.assertIsNotNone(worker.kv_recv_thread)

    def test_register_kv_caches_mla_case(self):
        """Two caches with different last dimensions enable the MLA path."""
        mla_cache1 = MagicMock()
        mla_cache1.size.return_value = (10, 16, 1, 16)
        mla_cache2 = MagicMock()
        mla_cache2.size.return_value = (10, 16, 1, 8)
        mla_caches = {"layer1": (mla_cache1, mla_cache2)}

        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        worker.register_kv_caches(mla_caches)
        self.assertTrue(worker.use_mla)
        self.assertEqual(len(worker.block_len), 2)

    def test_device_id_selection_with_physical_devices(self):
        """The worker picks the first physical device for tp_rank 0."""
        # Test with physical devices set
        worker = MooncakeConnectorWorker(self.vllm_config, self.engine_id)
        # Default tp_rank is 0, so device_id should be 10
        self.assertEqual(worker.device_id, 10)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
|
||||
169
tests/ut/kv_connector/test_remote_decode_lifecycle.py
Normal file
169
tests/ut/kv_connector/test_remote_decode_lifecycle.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
#
|
||||
import copy
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
||||
from vllm.v1.request import FinishReason, RequestStatus
|
||||
|
||||
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
def test_basic_lifecycle():
    """Test lifecycle of a Remote Decode request.

    Drives the scheduler through prefill, reporting of the finished
    request, and the final "finished sending" notification, then checks
    that nothing is left behind in the scheduler.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request = create_request(request_id=1,
                             max_tokens=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_decode=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Prefill.
    # (1a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1
    assert len(scheduler_output.scheduled_new_reqs) == 1

    # (1b): execute_model()
    model_runner_output = create_model_runner_output(reqs=[request])

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)

    # Ensure the request is finished after 1 token.
    assert request.is_finished()
    assert request.status == RequestStatus.FINISHED_LENGTH_CAPPED
    output = engine_core_outputs[0].outputs[0]
    assert output.finish_reason == FinishReason.LENGTH
    assert output.kv_transfer_params is not None

    # Request is recorded as finished in the Scheduler and has left both
    # the running and waiting queues ...
    assert request_id in scheduler.finished_req_ids
    assert len(scheduler.running) == 0
    assert len(scheduler.waiting) == 0

    # ... but blocks should not be freed.
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1

    # STEP (2): The finished request id is reported in the scheduler output.
    # (2a): schedule()
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 1
    assert request_id in scheduler_output.finished_req_ids
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (2b): execute_model()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (2c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # STEP (3): Finished sending.
    # (3a): schedule() - pass finished request to PB.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 0
    assert len(scheduler_output.finished_req_ids) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler.finished_req_ids) == 0

    # (3b): execute_model()
    # Deep-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated below.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore  # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending=[request_id])

    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Confirm we do not have any memory leaks after req lifecycle.
    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_prefix_cache_lifecycle():
    """Test that remote decode params still work with a prefix cache hit."""

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Prime the KVCache.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 3
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    # This priming request sets neither do_remote_decode nor
    # do_remote_prefill.
    request_remote_a = create_request(request_id=1, num_tokens=NUM_TOKENS)

    scheduler.add_request(request_remote_a)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote_a],
                                                     use_eos=True)
    scheduler.update_from_output(scheduler_output, model_runner_output)
    # NOTE(review): the result of this schedule() call is discarded and the
    # earlier scheduler_output is reused on the next line - confirm intended.
    scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)

    #####################
    # Actual Test: confirm we send all blocks.

    # Step (1): Send the KV Transfer.
    # Use one fewer full block than the priming request.
    NUM_EXTERNAL_FULL_BLOCKS -= 1
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    request_remote = create_request(request_id=1,
                                    num_tokens=NUM_TOKENS,
                                    do_remote_decode=True)

    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output(reqs=[request_remote])
    eco = scheduler.update_from_output(scheduler_output, model_runner_output)
    kv_transfer_params = eco[0].outputs[0].kv_transfer_params
    # Ensure we send all block ids, even if there is a cache hit.
    # (full blocks plus the one partially filled block)
    assert (len(
        kv_transfer_params["remote_block_ids"]) == (NUM_EXTERNAL_FULL_BLOCKS +
                                                    1))

    # STEP (2): Ensure it is freed.
    scheduler_output = scheduler.schedule()
    # NOTE(review): this second schedule() result is discarded; the
    # scheduler_output from the line above is what gets passed to
    # update_from_output below - confirm intended.
    scheduler.schedule()
    # Deep-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated below.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_sending=[request_remote.request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    _ = scheduler.schedule()
    assert_scheduler_empty(scheduler)
|
||||
239
tests/ut/kv_connector/test_remote_prefill_lifecycle.py
Normal file
239
tests/ut/kv_connector/test_remote_prefill_lifecycle.py
Normal file
@@ -0,0 +1,239 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
#
|
||||
import copy
|
||||
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
from tests.ut.kv_connector.utils import (assert_scheduler_empty,
|
||||
create_model_runner_output,
|
||||
create_request, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
def test_basic_lifecycle():
    """Test lifecycle of a remote prefill.

    The request first waits for remote KVs with blocks allocated but
    unhashed, is scheduled once the receive is reported as finished, and
    is finally run to completion.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 Full Blocks and 1 Half Block.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))
    # Free-block count before any request is added, used as a baseline.
    START_FREE_BLOCK_QUEUE_SIZE = (
        scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)

    # block_size is passed explicitly so the request's block hasher uses the
    # scheduler's block size rather than create_request's default of 16.
    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True,
                             block_size=BLOCK_SIZE)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1):
    # (1a): schedule()
    scheduler_output = scheduler.schedule()

    # Nothing running and empty scheduler output.
    assert len(scheduler.running) == 0
    assert len(scheduler_output.scheduled_new_reqs) == 0
    assert scheduler_output.scheduled_cached_reqs.num_reqs == 0
    assert len(scheduler_output.num_scheduled_tokens) == 0
    assert scheduler_output.total_num_scheduled_tokens == 0

    # Req waiting for KVs with no computed/scheduled toks ...
    assert len(scheduler.waiting) == 1
    assert request in scheduler.waiting
    assert (request.status == RequestStatus.WAITING_FOR_REMOTE_KVS)
    assert (request.num_computed_tokens == 0)

    # ... but should have (uncached) blocks allocated to it.
    block_pool = scheduler.kv_cache_manager.block_pool
    assert (block_pool.free_block_queue.num_free_blocks
            < START_FREE_BLOCK_QUEUE_SIZE)
    assert len(block_pool.cached_block_hash_to_block) == 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block._block_hash is None

    # (1b): forward()
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT

    # (1c): update_from_output()
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert not engine_core_outputs or not engine_core_outputs[0].outputs

    # STEP (2):
    # (2a): schedule(): nothing happens!
    scheduler_output = scheduler.schedule()
    assert len(scheduler.waiting) == 1
    assert len(scheduler.running) == 0

    # (2b): forward(): request finishes recv.
    # Deep-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated below.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore  # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving=[request_id])

    # (2c): update_from_output():
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # STEP (3):
    # (3a): schedule(): this should actually schedule.
    scheduler_output = scheduler.schedule()
    assert len(scheduler.running) == 1

    # Confirm the blocks are actually allocated, and that exactly the full
    # blocks now carry a hash.
    num_hashed_blocks = 0
    blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_id]
    for block in blocks:
        assert block.ref_cnt == 1
        num_hashed_blocks += (1 if block._block_hash is not None else 0)
    assert num_hashed_blocks == NUM_EXTERNAL_FULL_BLOCKS

    # Confirm the rest of the prompt is scheduled in this step.
    scheduled_req = scheduler_output.scheduled_new_reqs[0]
    num_scheduled_tokens = scheduler_output.num_scheduled_tokens[request_id]
    num_computed_tokens = scheduled_req.num_computed_tokens
    total_prompt_tokens = len(scheduled_req.prompt_token_ids)
    assert (num_scheduled_tokens == total_prompt_tokens - num_computed_tokens)

    # (3b): execute_model()
    model_runner_output = create_model_runner_output([request])
    # (3c): update_from_output()
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    engine_core_outputs = scheduler.update_from_output(scheduler_output,
                                                       model_runner_output)
    scheduler.schedule()

    assert_scheduler_empty(scheduler)
|
||||
|
||||
|
||||
def test_no_spurious_prefix_caching():
    """
    With P/D, blocks can be allocated but uncomputed for
    multiple engine steps. This test confirms that we do
    not accidentally have cache hits against uncomputed
    blocks.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # 2 and a half full external blocks.
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * (NUM_EXTERNAL_FULL_BLOCKS + 0.5))

    # The request has a prompt like [1,1,1,1,1, ...]
    request_remote = create_request(
        request_id=1,
        num_tokens=NUM_TOKENS,
        do_remote_prefill=True,
        use_all_1s_for_prompt_tokens=True,
    )

    # Schedule the remote prefill request. This should not
    # cause any blocks to be cached.
    scheduler.add_request(request_remote)
    scheduler_output = scheduler.schedule()
    scheduler.update_from_output(scheduler_output, EMPTY_MODEL_RUNNER_OUTPUT)
    assert len(scheduler.waiting) == 1

    remote_blocks = scheduler.kv_cache_manager.coordinator.single_type_managers[
        0].req_to_blocks[request_remote.request_id]

    # Remote blocks should not be cached.
    for block in remote_blocks:
        assert block.ref_cnt == 1
        assert block._block_hash is None
|
||||
|
||||
|
||||
def test_full_block_prompt():
    """Test that we handle a prompt that is the full block size.

    The prompt length is an exact multiple of the block size, so the
    remote prefill covers every prompt token and only the last token has
    to be recomputed locally to produce the first output token.
    """

    vllm_config = create_vllm_config()
    scheduler = create_scheduler(vllm_config)

    # Exactly 2 Full Blocks (no partial block).
    BLOCK_SIZE = vllm_config.cache_config.block_size
    NUM_EXTERNAL_FULL_BLOCKS = 2
    NUM_TOKENS = int(BLOCK_SIZE * NUM_EXTERNAL_FULL_BLOCKS)

    request = create_request(request_id=1,
                             num_tokens=NUM_TOKENS,
                             do_remote_prefill=True)

    scheduler.add_request(request)
    request_id = request.request_id

    # STEP (1): Initialize a recv.
    scheduler_output = scheduler.schedule()
    # All blocks should be allocated.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    model_runner_output = EMPTY_MODEL_RUNNER_OUTPUT
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # # STEP (2): Recv.
    scheduler_output = scheduler.schedule()
    # Deep-copy so the shared EMPTY_MODEL_RUNNER_OUTPUT is not mutated below.
    model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore  # noqa
    model_runner_output.kv_connector_output = KVConnectorOutput(
        finished_recving=[request_id])
    scheduler.update_from_output(scheduler_output, model_runner_output)
    assert len(scheduler.waiting) == 1
    assert (request_id in scheduler.finished_recving_kv_req_ids)

    # # STEP (3): Run as usual.
    scheduler_output = scheduler.schedule()

    # We need to recompute the final token of the prompt to generate
    # the first new token, so we should not have a new block.
    num_blocks = len(scheduler.kv_cache_manager.coordinator.
                     single_type_managers[0].req_to_blocks[request_id])
    assert num_blocks == NUM_EXTERNAL_FULL_BLOCKS
    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
            NUM_TOKENS - 1)
    assert (scheduler_output.num_scheduled_tokens[request_id] == 1)

    model_runner_output = create_model_runner_output([request])
    scheduler.update_from_output(scheduler_output, model_runner_output)

    # # Step (4): Hit EOS.
    scheduler_output = scheduler.schedule()
    model_runner_output = create_model_runner_output([request], use_eos=True)
    # Feed the EOS output back to the scheduler. Previously the output was
    # built but never applied, so this step did not exercise EOS at all;
    # this mirrors step (4) of test_basic_lifecycle in this module.
    scheduler.update_from_output(scheduler_output, model_runner_output)
    scheduler.schedule()

    assert_scheduler_empty(scheduler)
|
||||
233
tests/ut/kv_connector/utils.py
Normal file
233
tests/ut/kv_connector/utils.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# This code is from: https://github.com/vllm-project/vllm/tests/v1/kv_connector/unit/utils.py
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import (CacheConfig, DeviceConfig, KVTransferConfig,
|
||||
ModelConfig, SchedulerConfig, VllmConfig)
|
||||
from vllm.v1.core.kv_cache_utils import (get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec)
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.request import Request
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
# Token id used as the end-of-sequence marker, both in created requests and
# in the sampled output when ``use_eos`` is requested.
EOS_TOKEN_ID = 50256
# Set at import time of this helper module.
# NOTE(review): this runs after the vllm imports above - confirm the variable
# is read late enough for the assignment to take effect.
os.environ["VLLM_USE_V1"] = "1"
|
||||
|
||||
|
||||
def assert_scheduler_empty(scheduler: Scheduler):
    """Assert that *scheduler* holds no leftover state - i.e. no leaks.

    Checks the request bookkeeping, the encoder cache, the per-request
    block tables of the first KV-cache manager, the free-block count and
    the reference count of every block.
    """
    # Scheduler Metadata.
    assert len(scheduler.requests) == 0
    assert len(scheduler.waiting) == 0
    assert len(scheduler.running) == 0
    assert len(scheduler.finished_req_ids) == 0
    assert len(scheduler.finished_recving_kv_req_ids) == 0

    # EncoderCacheManager.
    encoder_cache = scheduler.encoder_cache_manager
    assert len(encoder_cache.freed) == 0
    assert len(encoder_cache.cached) == 0

    # KVCache Manager.
    kv_manager = scheduler.kv_cache_manager
    assert len(kv_manager.coordinator.single_type_managers[0].req_to_blocks
               ) == 0
    assert len(kv_manager.coordinator.single_type_managers[0].
               num_cached_block) == 0
    pool = kv_manager.block_pool
    free_count = pool.free_block_queue.num_free_blocks
    # All blocks but one are expected in the free queue.
    # NOTE(review): presumably one block is permanently reserved - confirm.
    assert free_count == (pool.num_gpu_blocks - 1)

    # NOTE(rob): just the ref count on blocks will be 0. The hash
    # value, etc will remain since we lazily evict for prefix cache.
    for blk in pool.blocks:
        assert blk.ref_cnt == 0
|
||||
|
||||
|
||||
def create_vllm_config(
    max_num_seqs: int = 16,
    max_num_batched_tokens: int = 1024,
    block_size: int = 128,
) -> VllmConfig:
    """Build a CPU-device VllmConfig wired to the LLMDataDist connector.

    Args:
        max_num_seqs: Forwarded to ``SchedulerConfig``.
        max_num_batched_tokens: Used for both the batched-token limit and
            ``max_model_len`` of the scheduler config.
        block_size: KV-cache block size for ``CacheConfig``.

    Returns:
        A ``VllmConfig`` with prefix caching enabled and ``kv_role`` set to
        ``"kv_both"``.
    """
    sched_cfg = SchedulerConfig(
        max_num_seqs=max_num_seqs,
        max_num_batched_tokens=max_num_batched_tokens,
        max_model_len=max_num_batched_tokens,
    )

    # The model is loaded from the "fake_weight" directory one level above
    # this file; tokenizer initialisation is skipped.
    this_dir = os.path.dirname(__file__)
    weights_dir = os.path.join(this_dir, "..", "fake_weight")
    model_cfg = ModelConfig(
        model=weights_dir,
        skip_tokenizer_init=True,
    )

    # Cache config with automatic prefix caching switched on.
    cache_cfg = CacheConfig(
        block_size=block_size,
        gpu_memory_utilization=0.9,
        swap_space=0,
        cache_dtype="auto",
        enable_prefix_caching=True,
    )

    connector_module = "vllm_ascend.distributed.llmdatadist_c_mgr_connector"
    kv_cfg = KVTransferConfig(
        kv_connector="LLMDataDistCMgrConnector",
        kv_role="kv_both",
        kv_connector_module_path=connector_module,
    )

    return VllmConfig(
        scheduler_config=sched_cfg,
        model_config=model_cfg,
        cache_config=cache_cfg,
        kv_transfer_config=kv_cfg,
        device_config=DeviceConfig("cpu"),
    )
|
||||
|
||||
|
||||
def create_scheduler(
    vllm_config: VllmConfig,
    num_blocks: int = 10000,
) -> Scheduler:
    """Build a Scheduler with a single full-attention KV-cache group.

    Args:
        vllm_config: Configuration to schedule against; its
            ``cache_config.num_gpu_blocks`` is overwritten with
            *num_blocks* as a side effect.
        num_blocks: Number of KV-cache blocks; the default is large enough
            to hold all requests.

    Returns:
        A ``Scheduler`` with stats logging enabled.
    """
    attention_spec = FullAttentionSpec(vllm_config.cache_config.block_size, 1,
                                       1, torch.float16, False)
    layer_group = KVCacheGroupSpec(['layer'], attention_spec)
    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=[],
        kv_cache_groups=[layer_group],
    )

    vllm_config.cache_config.num_gpu_blocks = num_blocks

    output_manager = StructuredOutputManager(vllm_config)
    return Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        log_stats=True,
        structured_output_manager=output_manager,
    )
|
||||
|
||||
|
||||
# Guards the one-time ``init_none_hash`` call made by ``create_request``.
_none_hash_initialized = False


def create_request(
    request_id: int,
    num_tokens: int = 10,
    max_tokens: int = 128,
    do_remote_decode: bool = False,
    do_remote_prefill: bool = False,
    use_all_1s_for_prompt_tokens: bool = False,
    num_remote_blocks: int = 3,
    block_size: int = 16,
) -> Request:
    """Build a dummy ``Request`` for the KV-connector tests.

    Args:
        request_id: Numeric id; the request is named ``"id-<request_id>"``
            and the default prompt tokens are multiples of it.
        num_tokens: Prompt length in tokens.
        max_tokens: Generation limit; forced to 1 for remote decode.
        do_remote_decode: Attach remote-decode transfer params. Must not be
            combined with *do_remote_prefill*.
        do_remote_prefill: Attach remote-prefill transfer params.
        use_all_1s_for_prompt_tokens: Use a prompt of all 1s instead of
            multiples of *request_id*.
        num_remote_blocks: Number of remote block ids advertised for a
            remote prefill.
        block_size: Block size handed to the request block hasher.

    Returns:
        The request, with ``kv_transfer_params`` set (``None`` when neither
        remote mode is requested).
    """
    global _none_hash_initialized
    if not _none_hash_initialized:
        # The "none" hash only needs initialising once per process.
        init_none_hash(hash)
        _none_hash_initialized = True

    block_hasher = get_request_block_hasher(block_size, hash)

    kv_transfer_params: Optional[dict[str, Any]] = None
    if do_remote_decode:
        assert not do_remote_prefill
        kv_transfer_params = {
            "do_remote_prefill": False,
            "do_remote_decode": True,
        }
        # Remote-decode requests generate exactly one token locally.
        max_tokens = 1
    elif do_remote_prefill:
        kv_transfer_params = {
            "do_remote_prefill": True,
            "do_remote_decode": False,
            "remote_engine_id": "my-engine-id",
            "remote_block_ids": list(range(num_remote_blocks)),
            "remote_host": "my-host",
            "remote_port": 1234,
            "remote_tp_size": 1,
        }

    sampling_params = SamplingParams(max_tokens=max_tokens)

    prompt_token_ids = ([1] * num_tokens if use_all_1s_for_prompt_tokens else
                        [i * request_id for i in range(num_tokens)])

    # Arguments shared by every supported vLLM version.
    # NOTE(review): pooling_params is an empty list rather than None -
    # confirm Request treats that as "no pooling".
    request_kwargs: dict[str, Any] = {
        "request_id": f"id-{request_id}",
        "prompt_token_ids": prompt_token_ids,
        "sampling_params": sampling_params,
        "pooling_params": [],
        "eos_token_id": EOS_TOKEN_ID,
        "block_hasher": block_hasher,
    }
    if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
        # Only on these versions are the multi-modal arguments passed.
        request_kwargs.update(
            multi_modal_kwargs=None,
            multi_modal_placeholders=None,
            multi_modal_hashes=None,
        )

    req = Request(**request_kwargs)
    req.kv_transfer_params = kv_transfer_params
    return req
|
||||
|
||||
|
||||
def create_model_runner_output(
    reqs: list[Request],
    finished_sending: Optional[list[str]] = None,
    finished_recving: Optional[list[str]] = None,
    use_eos: bool = False,
) -> ModelRunnerOutput:
    """Build a dummy ``ModelRunnerOutput`` with one sampled token per request.

    Args:
        reqs: Requests to produce output for, in batch order.
        finished_sending: Request ids reported as finished sending.
        finished_recving: Request ids reported as finished receiving.
        use_eos: Sample ``EOS_TOKEN_ID`` for every request instead of 0.

    Returns:
        The output object, carrying a ``KVConnectorOutput`` built from the
        two finished lists.
    """
    # Request ids and their position in the batch.
    req_ids = [req.request_id for req in reqs]
    req_id_to_index = {rid: pos for pos, rid in enumerate(req_ids)}

    # One single-token list per request.
    token = EOS_TOKEN_ID if use_eos else 0
    sampled_token_ids = [[token] for _ in req_ids]

    from vllm.v1.worker.kv_connector_model_runner_mixin import \
        KVConnectorOutput  # type: ignore  # noqa
    kv_connector_output = KVConnectorOutput(finished_sending=finished_sending,
                                            finished_recving=finished_recving)

    # Fields shared by every supported vLLM version.
    output_kwargs = {
        "req_ids": req_ids,
        "req_id_to_index": req_id_to_index,
        "sampled_token_ids": sampled_token_ids,
        "logprobs": None,
        "prompt_logprobs_dict": {},
        "pooler_output": [],
        "kv_connector_output": kv_connector_output,
    }
    if vllm_version_is("0.10.1.1") or vllm_version_is("0.10.1"):
        # Only on these versions is spec_token_ids passed.
        output_kwargs["spec_token_ids"] = None

    return ModelRunnerOutput(**output_kwargs)
|
||||
0
tests/ut/models/__init__.py
Normal file
0
tests/ut/models/__init__.py
Normal file
195
tests/ut/models/test_deepseek_mtp.py
Normal file
195
tests/ut/models/test_deepseek_mtp.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.models.deepseek_mtp import (
|
||||
CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor,
|
||||
CustomDeepSeekMultiTokenPredictorLayer)
|
||||
|
||||
|
||||
class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
    """Tests for ``CustomDeepSeekMultiTokenPredictorLayer`` with the
    constructors of its sub-modules stubbed out."""

    @pytest.fixture
    def setup_mtp_layer(self, mocker: MockerFixture):
        """Build a predictor layer while sub-module ``__init__``s are mocked.

        Also checks that the decoder-layer constructor ran exactly once.
        """
        hf_config = PretrainedConfig(vocab_size=1000,
                                     hidden_size=768,
                                     rms_norm_eps=1e-5)

        decoder_init = ("vllm_ascend.models.deepseek_v2."
                        "CustomDeepseekV2DecoderLayer.__init__")
        # Patched in this order; every constructor becomes a no-op.
        stubbed_inits = (
            "vllm.model_executor.layers.vocab_parallel_embedding."
            "VocabParallelEmbedding.__init__",
            "vllm.model_executor.layers.layernorm.RMSNorm.__init__",
            "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
            decoder_init,
            "vllm_ascend.ops.vocab_parallel_embedding."
            "AscendVocabParallelEmbedding.__init__",
        )
        init_mocks = {
            target: mocker.patch(target, return_value=None)
            for target in stubbed_inits
        }
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        layer = CustomDeepSeekMultiTokenPredictorLayer(hf_config, "", None)
        init_mocks[decoder_init].assert_called_once()
        return layer

    def test_init(self, mocker: MockerFixture, setup_mtp_layer):
        """The fixture yields an instance of the layer class."""
        assert isinstance(setup_mtp_layer,
                          CustomDeepSeekMultiTokenPredictorLayer)

    def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
        """Calling the layer returns a tensor of shape (2, 3, 768)."""
        layer = setup_mtp_layer
        # Neutralise nn.Module attribute hooks before touching attributes.
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)
        mocker.patch.object(layer,
                            'eh_proj',
                            return_value=torch.randn(2, 3, 768))
        mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
        layer.mtp_block.return_value = (torch.randn(2, 3, 768),
                                        torch.randn(2, 3, 768))

        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
        kv_cache = torch.randn(2, 3, 768)
        previous_hidden_states = torch.randn(2, 3, 768)
        inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])

        result = layer(input_ids, positions, kv_cache, None,
                       previous_hidden_states, inputs_embeds, 0)
        assert result.shape == (2, 3, 768)
|
||||
|
||||
|
||||
class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
    """Tests for ``CustomDeepSeekMultiTokenPredictor`` built from mocked
    configuration objects."""

    @pytest.fixture
    def setup_predictor(self, mocker: MockerFixture):
        """Return a predictor built from a mocked ``VllmConfig``."""
        hf_config = mocker.MagicMock()
        hf_config.num_hidden_layers = 12
        hf_config.num_nextn_predict_layers = 3
        hf_config.vocab_size = 30000

        model_config = mocker.MagicMock(spec=ModelConfig)
        model_config.hf_config = hf_config

        vllm_cfg = mocker.MagicMock(spec=VllmConfig)
        vllm_cfg.model_config = model_config
        vllm_cfg.cache_config = CacheConfig()
        vllm_cfg.quant_config = mocker.MagicMock()

        # Constructors of the heavy sub-modules become no-ops.
        for init_target in (
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "VocabParallelEmbedding.__init__",
                "vllm_ascend.models.deepseek_mtp."
                "CustomDeepSeekMultiTokenPredictorLayer.__init__",
                "vllm_ascend.ops.vocab_parallel_embedding."
                "AscendVocabParallelEmbedding.__init__",
        ):
            mocker.patch(init_target, return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        return CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_cfg)

    def test_init(self, mocker: MockerFixture, setup_predictor):
        """The predictor records three MTP layers from the mocked config."""
        predictor = setup_predictor
        assert predictor.num_mtp_layers == 3
        assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor)

    @pytest.mark.parametrize(
        'kv_caches, inputs_embeds',
        [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
    def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
                     inputs_embeds):
        """``forward`` calls the single stub layer once and returns its
        output."""
        predictor = setup_predictor
        stub_layer = mocker.MagicMock()
        stub_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [stub_layer]

        # TODO: decide whether predictor.num_mtp_layers must be forced to 1.
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        mocker.patch(
            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=torch.tensor([[1.0, 2.0, 3.0]]))

        result = predictor.forward(input_ids, positions, kv_caches, None, None,
                                   inputs_embeds, 0)
        stub_layer.assert_called_once()
        assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))

    def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
        """``compute_logits`` calls the logits processor once and returns
        its value."""
        hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
        predictor = setup_predictor

        stub_layer = mocker.MagicMock()
        stub_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
        predictor.layers_list = [stub_layer]

        # Neutralise nn.Module attribute hooks before touching attributes.
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)
        mocker.patch(
            "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
            return_value=None)
        predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])

        logits = predictor.compute_logits(hidden_states=hidden_states,
                                          sampling_metadata=None)
        predictor.logits_processor.assert_called_once()
        assert torch.allclose(logits, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
class TestCustomDeepSeekMTP(PytestBase):
    """Tests for the top-level ``CustomDeepSeekMTP`` wrapper."""

    @pytest.fixture
    def setup_mtp(self, mocker: MockerFixture):
        """Return a ``CustomDeepSeekMTP`` built from a fully mocked config."""
        cfg = mocker.MagicMock()
        hf_config = cfg.model_config.hf_config
        hf_config.num_hidden_layers = 12
        hf_config.num_nextn_predict_layers = 3
        cfg.cache_config = mocker.MagicMock()
        cfg.quant_config = mocker.MagicMock()

        # Neutralise nn.Module attribute hooks for the whole test.
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)
        # Each of these targets is replaced by a mock returning None.
        for target in (
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "VocabParallelEmbedding.__init__",
                "vllm_ascend.models.deepseek_mtp."
                "CustomDeepSeekMultiTokenPredictorLayer.__call__",
                "vllm.model_executor.layers.sampler.get_sampler",
                "vllm_ascend.ops.vocab_parallel_embedding."
                "AscendVocabParallelEmbedding.__init__",
        ):
            mocker.patch(target, return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        return CustomDeepSeekMTP(vllm_config=cfg)

    def test_init(self, mocker: MockerFixture, setup_mtp):
        """The fixture yields an instance of ``CustomDeepSeekMTP``."""
        assert isinstance(setup_mtp, CustomDeepSeekMTP)

    def test_forward(self, mocker: MockerFixture, setup_mtp):
        """``forward`` returns whatever the inner ``model`` returns."""
        row = [[0.1, 0.2, 0.3]]
        input_ids = torch.tensor([[1, 2, 3]])
        positions = torch.tensor([[0, 1, 2]])
        kv_caches = [torch.tensor(row)]
        previous_hidden_states = torch.tensor(row)
        inputs_embeds = torch.tensor(row)
        spec_step_idx = 0
        setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])

        result = setup_mtp.forward(input_ids, positions, kv_caches, None,
                                   previous_hidden_states, inputs_embeds,
                                   spec_step_idx)
        assert torch.allclose(result, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
295
tests/ut/models/test_deepseek_v2.py
Normal file
295
tests/ut/models/test_deepseek_v2.py
Normal file
@@ -0,0 +1,295 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
from vllm_ascend.models.deepseek_v2 import (
|
||||
CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention,
|
||||
CustomDeepseekV2MLP, CustomDeepseekV2MoE,
|
||||
CustomDeepseekV2RowParallelLinear,
|
||||
CustomDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
CustomDeepseekV2SiluAndMul, LogitsProcessor, ParallelLMHead)
|
||||
|
||||
|
||||
@pytest.fixture
def base_config():
    """Return a small ``PretrainedConfig`` shared by the DeepSeek-V2 tests."""
    settings = {
        # Core transformer dimensions.
        "hidden_size": 128,
        "num_attention_heads": 8,
        "num_hidden_layers": 2,
        "intermediate_size": 256,
        "hidden_act": "silu",
        "rms_norm_eps": 1e-6,
        "rope_theta": 10000.0,
        "max_position_embeddings": 2048,
        # Mixture-of-experts settings.
        "n_routed_experts": 4,
        "n_shared_experts": 1,
        "moe_intermediate_size": 256,
        "num_experts_per_tok": 2,
        "routed_scaling_factor": 1.0,
        "first_k_dense_replace": 0,
        "moe_layer_freq": 1,
        # Attention head / LoRA-rank settings.
        "kv_lora_rank": 16,
        "qk_nope_head_dim": 16,
        "qk_rope_head_dim": 16,
        "v_head_dim": 32,
        # Expert routing.
        "topk_method": "noaux_tc",
        "scoring_func": "softmax",
        "norm_topk_prob": True,
        "n_group": 1,
        "topk_group": 1,
        "vocab_size": 10000,
    }
    return PretrainedConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def vllm_config(base_config):
    """Return a ``Mock`` vLLM config wrapping ``base_config``.

    ``model_config`` is a plain namespace, ``cache_config`` is a default
    ``CacheConfig`` and ``quant_config`` is ``None``.
    """
    model_cfg = SimpleNamespace(
        hf_config=base_config,
        tensor_parallel_size=1,
        dtype=torch.float32,
        use_mla=False,
        quant_config=None,
        max_model_len=2048,
    )
    cache_cfg = CacheConfig()

    config = Mock()
    config.model_config = model_cfg
    config.cache_config = cache_cfg
    config.quant_config = None
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def mock_distributed():
    """Patch the distributed helpers so tests see single-rank groups.

    While the fixture is active:

    * the tensor-parallel rank / world-size helpers imported by
      ``vllm_ascend.models.deepseek_v2`` return ``0`` / ``1``;
    * ``get_tp_group``, ``get_ep_group`` and ``get_dp_group`` return
      ``GroupCoordinator``-spec'd mocks with rank 0 and world size 1;
    * ``get_pp_group`` returns a bare mock with ``is_first_rank`` and
      ``is_last_rank`` both ``False``;
    * the ``_TP`` / ``_EP`` / ``_DP`` / ``_PP`` entries of
      ``vllm.distributed.parallel_state`` and ``_MC2`` of
      ``vllm_ascend.distributed.parallel_state`` are replaced;
    * ``torch.npu.current_device`` returns ``0``.
    """

    def _single_rank_group():
        # Rank 0 of a one-member group, spec'd on GroupCoordinator.
        group = Mock(spec=GroupCoordinator)
        group.rank_in_group = 0
        group.world_size = 1
        return group

    tp_group = _single_rank_group()
    tp_group.device_group = Mock()
    dp_group = _single_rank_group()
    ep_group = _single_rank_group()
    pp_group = _single_rank_group()

    mock_vllm_config = Mock()
    mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
    mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)

    # ``get_pp_group`` is patched exactly once. A second patch on the same
    # target would simply shadow the first, so only the mock exposing
    # ``is_first_rank`` / ``is_last_rank`` is installed; ``pp_group`` is
    # still installed as ``parallel_state._PP`` below.
    with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \
            patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \
            patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \
            patch("vllm_ascend.models.deepseek_v2.get_ep_group", return_value=ep_group), \
            patch("vllm_ascend.models.deepseek_v2.get_dp_group", return_value=dp_group), \
            patch("vllm_ascend.models.deepseek_v2.get_pp_group",
                  return_value=Mock(is_first_rank=False, is_last_rank=False)), \
            patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \
            patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group,
                       _PP=pp_group), \
            patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \
            patch("torch.npu.current_device", return_value=0):
        yield
|
||||
|
||||
|
||||
@pytest.fixture
def mock_forward_context():
    """Make ``get_forward_context`` in ``deepseek_v2`` return a mock context
    that is neither in a profile run nor in prefill."""
    fake_context = Mock(in_profile_run=False, with_prefill=False)
    context_patch = patch("vllm_ascend.models.deepseek_v2.get_forward_context",
                          return_value=fake_context)
    with context_patch:
        yield
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_silu_and_mul():
    """Check ``CustomDeepseekV2SiluAndMul`` on a plain tensor and on a
    (quantized tensor, dynamic scale) pair with the NPU kernel patched."""
    torch.set_default_device("cpu")

    # Plain path: no weight scale; a (2, 4) input yields a (2, 2) output.
    plain_act = CustomDeepseekV2SiluAndMul()
    assert plain_act.weight_scale is None
    x = torch.randn(2, 4)
    assert plain_act.forward_oot(x).shape == (2, 2)

    # Quantized path: the NPU kernel is replaced by a stub.
    weight_scale = Mock(return_value=torch.tensor(0.1))
    quant_act = CustomDeepseekV2SiluAndMul(weight_scale=weight_scale)
    quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
    dynamic_scale = torch.randn(2, 1)
    kernel_patch = patch("torch_npu.npu_dequant_swiglu_quant",
                         return_value=torch.randn(2, 4))
    with kernel_patch:
        quant_out = quant_act.forward_oot((quant_x, dynamic_scale))
    assert quant_out.shape == (2, 4)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_merged_replicated_linear(mock_distributed):
    """The merged linear keeps its output sizes, loads a (128, 64) shard and
    raises ``AssertionError`` for a (128, 32) shard."""
    merged = CustomDeepseekV2MergedReplicatedLinear(input_size=128,
                                                    output_sizes=[64, 64],
                                                    bias=False,
                                                    quant_config=None)
    assert merged.output_sizes == [64, 64]

    # Stand-in parameter exposing only the attributes set here.
    fake_param = Mock()
    fake_param.data = torch.zeros(128, 128)
    fake_param.output_dim = 1
    fake_param.is_gguf_weight = False
    fake_param.is_gguf_weight_type = False

    good_shard = torch.randn(128, 64)
    merged.weight_loader(fake_param, good_shard, loaded_shard_id=0)

    with pytest.raises(AssertionError):
        merged.weight_loader(fake_param,
                             torch.randn(128, 32),
                             loaded_shard_id=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [
    CustomDeepseekV2RowParallelLinearReplaceAllreduce,
    CustomDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed):
    """Both row-parallel linear variants return a (2, 4, 64) first output,
    with and without ``input_is_parallel``."""
    layer = cls(input_size=128, output_size=64, bias=False, quant_config=None)
    layer.quant_method = Mock()
    layer.quant_method.apply.return_value = torch.randn(2, 4, 64)
    expected_shape = (2, 4, 64)

    input_ = torch.randn(2, 4, 128)

    # Non-parallel input: the split helper is stubbed for this call.
    split_patch = patch(
        "vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim",
        return_value=[torch.randn(2, 4, 64)])
    with split_patch:
        layer.input_is_parallel = False
        result = layer(input_, is_prefill=True)
    assert result[0].shape == expected_shape

    # Already-parallel input.
    layer.input_is_parallel = True
    result = layer(input_, is_prefill=False)
    assert result[0].shape == expected_shape
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_mlp(mock_distributed, base_config):
    """``CustomDeepseekV2MLP`` keeps the input shape, raises
    ``NotImplementedError`` for the mocked quant config with
    ``force_replicate=False`` and ``ValueError`` for ``hidden_act="relu"``."""
    mlp = CustomDeepseekV2MLP(hidden_size=128,
                              intermediate_size=256,
                              hidden_act="silu",
                              quant_config=None)
    assert isinstance(mlp.act_fn, CustomDeepseekV2SiluAndMul)

    hidden = torch.randn(2, 4, 128)
    assert mlp(hidden).shape == (2, 4, 128)

    quant_patch = patch("vllm_ascend.models.deepseek_v2.QuantizationConfig")
    with quant_patch as mock_quant_config:
        mock_quant_config.name = "w8a8dynamic"
        with pytest.raises(NotImplementedError):
            CustomDeepseekV2MLP(hidden_size=128,
                                intermediate_size=256,
                                hidden_act="silu",
                                quant_config=mock_quant_config,
                                force_replicate=False)

    with pytest.raises(ValueError):
        CustomDeepseekV2MLP(hidden_size=128,
                            intermediate_size=256,
                            hidden_act="relu",
                            quant_config=None)
|
||||
|
||||
|
||||
def test_custom_deepseek_v2_moe(mock_distributed, base_config,
                                mock_forward_context):
    """``CustomDeepseekV2MoE`` has ``top_k == 2`` and keeps the input shape
    when the fused-MoE call is stubbed."""
    base_config.n_shared_experts = 1
    moe_layer = CustomDeepseekV2MoE(config=base_config,
                                    quant_config=None,
                                    prefix="mlp")
    assert moe_layer.top_k == 2

    hidden = torch.randn(2, 4, 128)
    attn_metadata = Mock(num_prefills=1)
    # The fused-MoE call is stubbed to return a pair of tensors.
    fused_patch = patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__",
                        return_value=(torch.randn(2, 4, 128),
                                      torch.randn(2, 4, 128)))
    with fused_patch:
        result = moe_layer(hidden, attn_metadata)
    assert result.shape == (2, 4, 128)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                          base_config):
    """Build ``CustomDeepseekV2MLAAttention`` with and without a q-LoRA
    rank and check the attributes each variant exposes."""
    mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))

    # Arguments common to both constructions below.
    shared_kwargs = dict(config=base_config,
                         hidden_size=128,
                         num_heads=8,
                         qk_nope_head_dim=16,
                         qk_rope_head_dim=16,
                         v_head_dim=32,
                         kv_lora_rank=16)

    lora_attn = CustomDeepseekV2MLAAttention(q_lora_rank=16,
                                             cache_config=CacheConfig(),
                                             quant_config=None,
                                             prefix="layers.0.self_attn",
                                             **shared_kwargs)
    assert lora_attn.debug_layer_idx == 0

    hidden = torch.randn(2, 4, 128)
    positions = torch.arange(4).repeat(2, 1)
    call_patch = patch.object(lora_attn.mla_attn,
                              "__call__",
                              return_value=torch.randn(2, 4, 128))
    with call_patch:
        # The forward call is expected to fail an assertion here.
        with pytest.raises(AssertionError):
            lora_attn(positions, hidden)

    plain_attn = CustomDeepseekV2MLAAttention(q_lora_rank=None,
                                              prefix="layers.1.self_attn",
                                              **shared_kwargs)
    assert hasattr(plain_attn, "q_proj")
|
||||
|
||||
|
||||
def test_deepseek_v2_lmhead(mock_distributed, vllm_config):
    """``LogitsProcessor`` applied to a ``ParallelLMHead`` returns logits of
    shape (2, 4, vocab_size) when the head's quant method and the gather
    step are stubbed."""
    # Minimal model dimensions.
    vocab_size, hidden_size = 10000, 128

    # Build the LM head and the logits processor directly.
    lmhead = ParallelLMHead(vocab_size, hidden_size)
    logits_processor = LogitsProcessor(vocab_size)

    # Fake hidden states and the logits both stubs will hand back.
    mock_output = torch.randn(2, 4, hidden_size)
    mock_logits = torch.randn(2, 4, vocab_size)

    # Exercise the logits processor with both stubs active.
    with patch.object(lmhead.quant_method, "apply", return_value=mock_logits), \
            patch.object(logits_processor,
                         "_gather_logits",
                         return_value=mock_logits):
        logits = logits_processor(lmhead, mock_output)
    assert logits.shape == (2, 4, vocab_size)
|
||||
424
tests/ut/models/test_qwen2_5_vl.py
Normal file
424
tests/ut/models/test_qwen2_5_vl.py
Normal file
@@ -0,0 +1,424 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.models.qwen2_5_vl import (
|
||||
AscendQwen2_5_VisionAttention, AscendQwen2_5_VisionBlock,
|
||||
AscendQwen2_5_VisionPatchEmbed, AscendQwen2_5_VisionRotaryEmbedding,
|
||||
AscendQwen2_5_VisionTransformer, AscendQwen2_5_VLForConditionalGeneration)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionAttention(PytestBase):
    """Tests for ``AscendQwen2_5_VisionAttention`` with the parent
    constructor and the NPU kernels mocked."""

    def init_attention(
        self,
        mocker,
        embed_dim=1000,
        num_heads=10,
        projection_size=100,
        quant_config=None,
        prefix="",
    ):
        """Build an attention module with the parent ``__init__`` mocked and
        check the positional arguments forwarded to it."""
        parent_init = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionAttention.__init__")

        module = AscendQwen2_5_VisionAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            projection_size=projection_size,
            quant_config=quant_config,
            prefix=prefix,
        )
        forwarded_args, forwarded_kwargs = parent_init.call_args
        assert forwarded_args == (embed_dim, num_heads, projection_size, None,
                                  "")
        assert not forwarded_kwargs
        module.num_attention_heads_per_partition = num_heads
        return module

    def test_attn_init_should_normal(self, mocker: MockerFixture):
        """A divisible projection size gives a per-head hidden size of 10."""
        vit = self.init_attention(
            embed_dim=1000,
            num_heads=10,
            projection_size=100,
            quant_config=None,
            prefix="",
            mocker=mocker,
        )
        assert vit.embed_dim == 1000
        assert vit.hidden_size_per_attention_head == 10

    def test_attn_init_should_raise_error(self, mocker: MockerFixture):
        """Construction fails when projection_size=100 and num_heads=7."""
        with pytest.raises(AssertionError):
            # projection_size must be divisible by the number of heads.
            self.init_attention(
                mocker=mocker,
                embed_dim=1000,
                num_heads=7,
                projection_size=100,
                quant_config=None,
                prefix="",
            )

    def test_split_qkv(self, mocker: MockerFixture):
        """``split_qkv`` turns a (100, 10, 300) tensor into three
        (100, 10, 10, 10) tensors."""
        attention = self.init_attention(mocker=mocker)
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)
        parts = attention.split_qkv(torch.rand((100, 10, 300)))
        q, k, v = parts
        assert q.shape == (100, 10, 10, 10)
        assert k.shape == (100, 10, 10, 10)
        assert v.shape == (100, 10, 10, 10)

    def test_attn_forward(self, mocker: MockerFixture):
        """``forward`` feeds the stubbed projections and kernels and returns
        a (100, 3, 1280) tensor."""
        attention = self.init_attention(mocker=mocker)
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)
        x = torch.rand((100, 3, 10 * 3 * 128))  # s, b, head*3*head_dim
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))

        # Stand-ins for the projections and the NPU kernels.
        def fake_qkv(x):
            return (x, 0)

        def fake_split_qkv(x):
            return [torch.rand((100, 3, 10, 128)) for i in range(3)]

        def fake_rotary_mul(q, cos, sin):
            return q

        def fake_flash_attention(**kwargs):
            return kwargs["out"]

        def fake_proj(x):
            return (x, 0)

        mocker_qkv = mocker.patch.object(attention,
                                         "qkv",
                                         side_effect=fake_qkv)
        mocker_split_qkv = mocker.patch.object(
            attention,
            "split_qkv",
            side_effect=fake_split_qkv,
        )
        mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
                                             side_effect=fake_rotary_mul)
        mocker_npu_flash_attention_unpad = mocker.patch(
            "torch_npu._npu_flash_attention_unpad",
            side_effect=fake_flash_attention,
        )
        mocker_proj = mocker.patch.object(attention,
                                          "proj",
                                          side_effect=fake_proj)
        # Also store the mocks directly in the instance dict.
        attention.__dict__["qkv"] = mocker_qkv
        attention.__dict__["split_qkv"] = mocker_split_qkv
        attention.__dict__["npu_rotary_mul"] = mocker_npu_rotary_mul
        attention.__dict__["_npu_flash_attention_unpad"] = (
            mocker_npu_flash_attention_unpad)
        attention.__dict__["proj"] = mocker_proj

        output = attention.forward(
            x=x,
            cu_seqlens=cu_seqlens,
            cos=cos,
            sin=sin,
        )

        qkv_args, qkv_kwargs = mocker_qkv.call_args
        assert qkv_args == (x, )
        assert not qkv_kwargs

        split_qkv_args, split_qkv_kwargs = mocker_split_qkv.call_args
        assert split_qkv_args == (x, )
        assert not split_qkv_kwargs

        rotary_args, rotary_kwargs = mocker_npu_rotary_mul.call_args
        assert rotary_args[1:] == (cos, sin)
        assert rotary_args[0].shape == torch.Size([3, 100, 10, 128])
        assert not rotary_kwargs

        assert output.shape == torch.Size([100, 3, 1280])
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionBlock(PytestBase):
    """Tests for ``AscendQwen2_5_VisionBlock`` with the parent block and the
    attention constructor mocked."""

    def init_vision_block(
        self,
        mocker,
        dim=100,
        num_heads=10,
        mlp_hidden_dim=100,
    ):
        """Build a vision block and check the arguments passed to the mocked
        parent and attention constructors."""
        parent_init = mocker.patch(
            "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__",
            return_value=None,
        )
        attn_init = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionAttention.__init__",
            return_value=None,
        )
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)

        block = AscendQwen2_5_VisionBlock(
            dim=dim,
            num_heads=num_heads,
            mlp_hidden_dim=mlp_hidden_dim,
        )

        # The parent constructor receives everything positionally.
        parent_args, parent_kwargs = parent_init.call_args
        assert parent_args == (dim, num_heads, mlp_hidden_dim, F.silu, None,
                               None, "")
        assert not parent_kwargs

        # The attention constructor receives keywords only.
        attn_args, attn_kwargs = attn_init.call_args
        assert not attn_args
        assert attn_kwargs == {
            "embed_dim": dim,
            "num_heads": num_heads,
            "projection_size": dim,
            "quant_config": None,
            "prefix": ".attn",
        }
        return block

    def test_init_vision_block_should_normal(
        self,
        mocker: MockerFixture,
    ):
        """The helper returns an ``AscendQwen2_5_VisionBlock`` instance."""
        assert isinstance(self.init_vision_block(mocker),
                          AscendQwen2_5_VisionBlock)

    def test_vision_block_forward(self, mocker: MockerFixture):
        """With ``attn`` and ``mlp`` both returning ``x``, ``forward``
        returns ``3 * x``."""
        x = torch.randint(1, 100, (100, 3, 1280))  # s, b, d
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))
        block = self.init_vision_block(mocker)
        mocker_attn = mocker.patch.object(block, "attn", return_value=x)
        mocker_mlp = mocker.patch.object(block, "mlp", return_value=x)
        block.__dict__["attn"] = mocker_attn
        block.__dict__["mlp"] = mocker_mlp

        output = block.forward(x.clone(), cu_seqlens, cos, sin)

        _, attn_kwargs = mocker_attn.call_args
        expected_kwargs = {
            "cu_seqlens": cu_seqlens,
            "cos": cos,
            "sin": sin,
        }
        assert attn_kwargs == expected_kwargs

        assert torch.all(x * 3 == output)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionPatchEmbed(PytestBase):
    """Tests for ``AscendQwen2_5_VisionPatchEmbed``."""

    def test_forward(self):
        """A (120, 1176) input is embedded to shape (120, 1152)."""
        embedder = AscendQwen2_5_VisionPatchEmbed()

        patches = torch.rand((120, 1176))
        embedded = embedder(patches)
        assert embedded.shape == (120, 1152)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionRotaryEmbedding(PytestBase):
    """Tests for ``AscendQwen2_5_VisionRotaryEmbedding``."""

    def init_rotary_embedding(
        self,
        mocker,
        dim=128,
    ):
        """Build the embedding with the parent ``__init__`` mocked and check
        that it was called with ``(dim, 10000.0)``."""
        parent_init = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl.Qwen2_5_VisionRotaryEmbedding.__init__",
            return_value=None,
        )
        for hook in ("torch.nn.Module.__setattr__",
                     "torch.nn.Module.__getattr__",
                     "torch.nn.Module.__delattr__"):
            mocker.patch(hook)

        embedding = AscendQwen2_5_VisionRotaryEmbedding(dim=dim, )
        init_args, init_kwargs = parent_init.call_args
        assert init_args == (dim, 10000.0)
        assert not init_kwargs
        return embedding

    def test_init_rotary_embedding_should_normal(self, mocker: MockerFixture):
        """The helper returns an instance of the rotary-embedding class."""
        embedding = self.init_rotary_embedding(mocker)
        assert isinstance(embedding, AscendQwen2_5_VisionRotaryEmbedding)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionTransformer(PytestBase):
|
||||
|
||||
input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]])
|
||||
|
||||
def init_vision_transformer(
    self,
    mocker,
):
    """Build an ``AscendQwen2_5_VisionTransformer`` from a mocked vision
    config while the tensor-parallel helpers are patched.

    Asserts that the resulting transformer is not ``interleaved``.
    """
    norm_eps = 1e-6
    vision_config = mocker.MagicMock()
    vision_config.patch_size = 16
    vision_config.temporal_patch_size = 2
    vision_config.in_channels = 3
    vision_config.hidden_act = "gelu"
    vision_config.depth = 0
    vision_config.num_heads = 10
    vision_config.hidden_size = 300

    # (target, return value) pairs, patched in this order.
    patched_helpers = (
        ("vllm_ascend.models.qwen2_5_vl.parallel_state."
         "get_tensor_model_parallel_rank", 0),
        ("vllm.distributed.utils.divide", 100),
        ("vllm.model_executor.layers.linear."
         "get_tensor_model_parallel_world_size", 2),
        ("vllm.model_executor.layers.linear.divide", 2),
        ("vllm.model_executor.layers.linear."
         "get_tensor_model_parallel_rank", 0),
        ("vllm_ascend.models.qwen2_5_vl.parallel_state."
         "get_tensor_model_parallel_world_size", 2),
    )
    for target, value in patched_helpers:
        mocker.patch(target, return_value=value)

    transformer = AscendQwen2_5_VisionTransformer(
        vision_config,
        norm_eps,
    )

    assert not transformer.interleaved
    return transformer
|
||||
|
||||
def test_init_vision_transformer(self, mocker: MockerFixture):
    """The helper returns an ``AscendQwen2_5_VisionTransformer`` instance."""
    transformer = self.init_vision_transformer(mocker)
    assert isinstance(transformer, AscendQwen2_5_VisionTransformer)
|
||||
|
||||
@pytest.mark.parametrize(
    "interleaved, expected",
    [
        (
            False,
            torch.tensor([
                input_data[0, 0].cos(),
                input_data[0, 1].cos(),
                input_data[0, 0].cos(),
                input_data[0, 1].cos(),
                input_data[1, 0].cos(),
                input_data[1, 1].cos(),
                input_data[1, 0].cos(),
                input_data[1, 1].cos(),
            ]),
        ),
        (
            True,
            torch.tensor([
                input_data[0, 0].cos(),
                input_data[0, 0].cos(),
                input_data[0, 1].cos(),
                input_data[0, 1].cos(),
                input_data[1, 0].cos(),
                input_data[1, 0].cos(),
                input_data[1, 1].cos(),
                input_data[1, 1].cos(),
            ]),
        ),
    ],
)
def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
    """``cal_cos_sin`` returns a cosine tensor of shape (1, 32, 1, 2) for
    both ``interleaved`` settings.

    ``expected`` is supplied by the parametrization but is not compared
    against the result here.
    """
    transformer = self.init_vision_transformer(mocker)
    for hook in ("torch.nn.Module.__setattr__",
                 "torch.nn.Module.__getattr__",
                 "torch.nn.Module.__delattr__"):
        mocker.patch(hook)
    # Written through __dict__ because __setattr__ is mocked above.
    transformer.__dict__["interleaved"] = interleaved
    transformer.__dict__["hidden_size_per_attention_head"] = 2
    transformer.hidden_size_per_attention_head = 4
    cos_new, _unused_sin = transformer.cal_cos_sin(self.input_data)
    assert cos_new.shape == (1, 32, 1, 2)
|
||||
|
||||
def test_forward(self, mocker: MockerFixture):
    """Drive ``forward`` with every collaborator stubbed and check the wiring.

    The sub-modules are replaced by mocks that return random tensors, so only
    the output shape and the calls made to the stubs are asserted.
    """
    vision_transformer = self.init_vision_transformer(mocker)
    for target in (
            "torch.nn.Module.__setattr__",
            "torch.nn.Module.__getattr__",
            "torch.nn.Module.__delattr__",
    ):
        mocker.patch(target)

    def stub(attr_name, fake):
        # Patch one attribute of the transformer with a side-effect mock.
        return mocker.patch.object(vision_transformer,
                                   attr_name,
                                   side_effect=fake)

    x = torch.randn(1, 3, 224, 224)
    grid_thw = torch.tensor([[1, 4, 4]])

    mocker_patch_embed = stub(
        "patch_embed",
        lambda _: torch.randn(16, 512),  # noqa
    )
    mocker_rot_pos_emb = stub(
        "rot_pos_emb",
        lambda _: torch.randn(16, 64),  # noqa
    )
    mocker_get_window_index = stub(
        "get_window_index",
        lambda _: (torch.arange(8), [4, 8, 12, 16]),  # noqa
    )
    mocker_cal_cos_sin = stub(
        "cal_cos_sin",
        lambda _: (torch.randn(16, 32), torch.randn(16, 32)),  # noqa
    )
    mocker_merger = stub(
        "merger",
        lambda _: torch.randn(16, 256),  # noqa
    )

    # Install the stubs and remaining configuration directly in __dict__.
    vision_transformer.__dict__.update({
        "vision_blocks": [
            lambda *args, **kwargs: torch.randn(16, 1, 512)  # noqa
        ],
        "patch_embed": mocker_patch_embed,
        "rot_pos_emb": mocker_rot_pos_emb,
        "get_window_index": mocker_get_window_index,
        "cal_cos_sin": mocker_cal_cos_sin,
        "merger": mocker_merger,
        "fullatt_block_indexes": [0, 2],
        "spatial_merge_unit": 2,
    })

    ret = vision_transformer.forward(x, grid_thw)

    assert ret.shape == (8, 256)
    mocker_patch_embed.assert_called_with(x)
    mocker_rot_pos_emb.assert_called_with(grid_thw)
    mocker_get_window_index.assert_called_with(grid_thw)
    mocker_cal_cos_sin.assert_called_once()
    mocker_merger.assert_called_once()
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VLForConditionalGeneration(PytestBase):
    """Constructor test for ``AscendQwen2_5_VLForConditionalGeneration``."""

    def test_init_vl_for_conditional_generation(self, mocker: MockerFixture):
        """Build the model with both parent/visual constructors mocked out."""
        vllm_config = mocker.MagicMock()
        vllm_config.vision_config = "vision_config"
        vllm_config.rms_norm_eps = 1e-5

        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)

        mocker_vl = mocker.patch(
            "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__",
            return_value=None,
        )
        mocker_vit = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionTransformer.__init__",
            return_value=None,
        )

        model = AscendQwen2_5_VLForConditionalGeneration(
            vllm_config=vllm_config)

        # The parent constructor must have been called with keywords only.
        recorded = mocker_vl.call_args
        assert not recorded.args
        assert recorded.kwargs == {"vllm_config": vllm_config, "prefix": ""}
        mocker_vit.assert_called_once()
        assert isinstance(model, AscendQwen2_5_VLForConditionalGeneration)
|
||||
422
tests/ut/models/test_qwen2_5_vl_without_padding.py
Normal file
422
tests/ut/models/test_qwen2_5_vl_without_padding.py
Normal file
@@ -0,0 +1,422 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.models.qwen2_5_vl import \
|
||||
Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.models.qwen2_5_vl_without_padding import (
|
||||
AscendQwen2_5_VisionAttention_Without_Padding,
|
||||
AscendQwen2_5_VisionBlock_Without_Padding,
|
||||
AscendQwen2_5_VisionPatchEmbed_Without_Padding,
|
||||
AscendQwen2_5_VisionTransformer_Without_Padding,
|
||||
AscendQwen2_5_VLForConditionalGeneration_Without_Padding)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionAttention_Without_Padding(PytestBase):
    """Tests for ``AscendQwen2_5_VisionAttention_Without_Padding``."""

    def init_attention(
        self,
        mocker,
        embed_dim=1000,
        num_heads=10,
        projection_size=100,
        quant_config=None,
        prefix="",
    ):
        """Build the attention layer with the parent ``__init__`` mocked.

        Also asserts that the parent constructor received the expected
        positional arguments, then returns the constructed layer.
        """
        mocker_attn = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl_without_padding.Qwen2_5_VisionAttention.__init__"
        )

        layer = AscendQwen2_5_VisionAttention_Without_Padding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            projection_size=projection_size,
            quant_config=quant_config,
            prefix=prefix,
        )

        recorded = mocker_attn.call_args
        assert recorded.args == (embed_dim, num_heads, projection_size, None,
                                 "")
        assert not recorded.kwargs
        layer.num_attention_heads_per_partition = num_heads
        return layer

    def test_vit_init_should_normal(self, mocker: MockerFixture):
        """A divisible projection size yields the expected derived sizes."""
        vit = self.init_attention(
            mocker=mocker,
            embed_dim=1000,
            num_heads=10,
            projection_size=100,
            quant_config=None,
            prefix="",
        )
        assert vit.embed_dim == 1000
        assert vit.hidden_size_per_attention_head == 10

    def test_vit_init_should_raise_error(self, mocker: MockerFixture):
        """Construction is expected to raise ``AssertionError`` for 7 heads."""
        params = {
            "embed_dim": 1000,
            "num_heads": 7,
            "projection_size": 100,
            "quant_config": None,
            "prefix": "",
        }
        with pytest.raises(AssertionError):
            # projection_size should be divisible by num_heads
            self.init_attention(mocker=mocker, **params)

    def test_vit_forward(self, mocker: MockerFixture):
        """Run ``forward`` with projections and NPU kernels replaced by fakes."""
        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)
        attention = self.init_attention(mocker=mocker)

        x = torch.rand((100, 3, 10 * 3 * 128))  # s,b, head*3*head_dim
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))

        def fake_qkv(x):
            return (x, 0)

        def fake_split_qkv(x):
            return [torch.rand((100, 3, 10, 128)) for _ in range(3)]

        def fake_rotary_mul(q, cos, sin):
            return q

        def fake_flash_attention_unpad(**kwargs):
            return kwargs["out"]

        def fake_proj(x):
            return (x, 0)

        mocker_qkv = mocker.patch.object(attention,
                                         "qkv",
                                         side_effect=fake_qkv)
        mocker_split_qkv = mocker.patch.object(
            attention,
            "split_qkv",
            side_effect=fake_split_qkv,
        )
        mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
                                             side_effect=fake_rotary_mul)
        mocker_npu_flash_attention_unpad = mocker.patch(
            "torch_npu._npu_flash_attention_unpad",
            side_effect=fake_flash_attention_unpad,
        )
        mocker_proj = mocker.patch.object(attention,
                                          "proj",
                                          side_effect=fake_proj)

        # Install the stubs directly in the instance dictionary.
        attention.__dict__.update({
            "qkv": mocker_qkv,
            "split_qkv": mocker_split_qkv,
            "npu_rotary_mul": mocker_npu_rotary_mul,
            "_npu_flash_attention_unpad": mocker_npu_flash_attention_unpad,
            "proj": mocker_proj,
        })

        output = attention.forward(
            x=x,
            cu_seqlens=cu_seqlens,
            cos=cos,
            sin=sin,
        )

        qkv_call = mocker_qkv.call_args
        assert qkv_call.args == (x, )
        assert not qkv_call.kwargs

        split_call = mocker_split_qkv.call_args
        assert split_call.args == (x, )
        assert not split_call.kwargs

        rotary_call = mocker_npu_rotary_mul.call_args
        assert rotary_call.args[1:] == (cos, sin)
        assert rotary_call.args[0].shape == torch.Size([3, 100, 10, 128])
        assert not rotary_call.kwargs

        assert output.shape == torch.Size([100, 3, 1280])
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionBlock_Without_Padding(PytestBase):
    """Tests for ``AscendQwen2_5_VisionBlock_Without_Padding``."""

    def init_vision_block(
        self,
        mocker,
        dim=100,
        num_heads=10,
        mlp_hidden_dim=100,
    ):
        """Build a vision block with parent and attention ``__init__`` mocked.

        Asserts the arguments forwarded to both mocked constructors and
        returns the constructed block.
        """
        mocker_vit = mocker.patch(
            "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionBlock.__init__",
            return_value=None,
        )
        mocker_attn = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionAttention_Without_Padding.__init__",
            return_value=None,
        )
        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)

        vision_block = AscendQwen2_5_VisionBlock_Without_Padding(
            dim=dim,
            num_heads=num_heads,
            mlp_hidden_dim=mlp_hidden_dim,
        )

        parent_call = mocker_vit.call_args
        assert parent_call.args == (dim, num_heads, mlp_hidden_dim, F.silu,
                                    None, None, "")
        assert not parent_call.kwargs

        attn_call = mocker_attn.call_args
        assert not attn_call.args
        assert attn_call.kwargs == {
            "embed_dim": dim,
            "num_heads": num_heads,
            "projection_size": dim,
            "quant_config": None,
            "prefix": ".attn",
        }
        return vision_block

    def test_init_vision_block_should_normal(
        self,
        mocker: MockerFixture,
    ):
        """The helper returns an instance of the class under test."""
        built = self.init_vision_block(mocker)
        assert isinstance(built, AscendQwen2_5_VisionBlock_Without_Padding)

    def test_vision_block_forward(self, mocker: MockerFixture):
        """With ``attn`` and ``mlp`` both returning ``x``, output equals 3 * x."""
        x = torch.randint(1, 100, (100, 3, 1280))  # s,b,d
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))

        vision_block = self.init_vision_block(mocker)
        mocker_attn = mocker.patch.object(vision_block,
                                          "attn",
                                          return_value=x)
        mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
        vision_block.__dict__.update({"attn": mocker_attn, "mlp": mocker_mlp})

        output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)

        attn_kwargs = mocker_attn.call_args.kwargs
        assert attn_kwargs == {
            "cu_seqlens": cu_seqlens,
            "cos": cos,
            "sin": sin,
        }
        assert torch.all(x * 3 == output)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionPatchEmbed_Without_Padding(PytestBase):
    """Shape test for ``AscendQwen2_5_VisionPatchEmbed_Without_Padding``."""

    def test_forward(self):
        """A (120, 1176) input is mapped to a (120, 1152) output."""
        embedder = AscendQwen2_5_VisionPatchEmbed_Without_Padding()
        pixels = torch.rand((120, 1176))

        ret = embedder(pixels)

        assert ret.shape == (120, 1152)
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VisionTransformer_Without_Padding(PytestBase):
    """Tests for ``AscendQwen2_5_VisionTransformer_Without_Padding``."""

    # Shared 2x2 input used by the cos/sin parametrization below.
    input_data = torch.tensor([[0.1, 0.2], [0.3, 0.4]])

    def init_vision_transformer(
        self,
        mocker,
    ):
        """Build the transformer with its collaborators' constructors mocked.

        Asserts the arguments forwarded to the mocked parent ``__init__`` and
        that the rotary embedding constructor ran once.
        """
        norm_eps = 1e-6
        vision_config = mocker.MagicMock()
        vision_config.patch_size = 16
        vision_config.temporal_patch_size = 2
        vision_config.in_channels = 3
        vision_config.hidden_act = "gelu"
        vision_config.depth = 0
        vision_config.hidden_size = 1280
        vision_config.num_heads = 16

        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)

        mocker_vit = mocker.patch(
            "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VisionTransformer.__init__",
            return_value=None,
        )
        mocker_vision_rotary_embedding = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl.AscendQwen2_5_VisionRotaryEmbedding.__init__",
            return_value=None,
        )
        # Constructors whose mocks are not inspected afterwards.
        for target in (
                "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionBlock_Without_Padding.__init__",
                "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionPatchEmbed_Without_Padding.__init__",
        ):
            mocker.patch(target, return_value=None)
        mocker.patch(
            "vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_world_size",
            return_value=1,
        )
        mocker.patch(
            "vllm_ascend.models.qwen2_5_vl_without_padding.parallel_state.get_tensor_model_parallel_rank",
            return_value=0,
        )
        mocker.patch("vllm.distributed.utils.divide", return_value=100)

        vision_transformer = AscendQwen2_5_VisionTransformer_Without_Padding(
            vision_config,
            norm_eps,
        )

        parent_call = mocker_vit.call_args
        assert parent_call.args == (vision_config, norm_eps, None, "")
        assert not parent_call.kwargs
        mocker_vision_rotary_embedding.assert_called_once()
        return vision_transformer

    def test_init_vision_transformer(self, mocker: MockerFixture):
        """The helper returns an instance of the class under test."""
        built = self.init_vision_transformer(mocker)
        assert isinstance(built,
                          AscendQwen2_5_VisionTransformer_Without_Padding)

    @pytest.mark.parametrize(
        "interleaved, expected",
        [
            (
                False,
                torch.tensor([
                    input_data[0, 0].cos(), input_data[0, 1].cos(),
                    input_data[0, 0].cos(), input_data[0, 1].cos(),
                    input_data[1, 0].cos(), input_data[1, 1].cos(),
                    input_data[1, 0].cos(), input_data[1, 1].cos(),
                ]),
            ),
            (
                True,
                torch.tensor([
                    input_data[0, 0].cos(), input_data[0, 0].cos(),
                    input_data[0, 1].cos(), input_data[0, 1].cos(),
                    input_data[1, 0].cos(), input_data[1, 0].cos(),
                    input_data[1, 1].cos(), input_data[1, 1].cos(),
                ]),
            ),
        ],
    )
    def test_cal_cos_sin(self, interleaved, expected, mocker: MockerFixture):
        """Compare the flattened cos output with the expected element order."""
        vision_transformer = self.init_vision_transformer(mocker)
        state = vision_transformer.__dict__
        state["interleaved"] = interleaved
        state["hidden_size_per_attention_head"] = 2
        # NOTE(review): nn.Module.__setattr__ is patched by the helper, so this
        # assignment is presumably swallowed by the mock - confirm intent.
        vision_transformer.hidden_size_per_attention_head = 4

        result = vision_transformer.cal_cos_sin(self.input_data)
        cos_new = result[0]

        assert cos_new.shape == (1, 4, 1, 2)
        assert torch.allclose(cos_new.view(-1), expected)

    def test_forward(self, mocker: MockerFixture):
        """Drive ``forward`` with all collaborators stubbed; check the wiring."""
        vision_transformer = self.init_vision_transformer(mocker)

        def stub(attr_name, fake):
            # Patch one attribute of the transformer with a side-effect mock.
            return mocker.patch.object(vision_transformer,
                                       attr_name,
                                       side_effect=fake)

        x = torch.randn(1, 3, 224, 224)
        grid_thw = torch.tensor([[1, 4, 4]])

        mocker_patch_embed = stub(
            "patch_embed",
            lambda _: torch.randn(16, 512),  # noqa
        )
        mocker_rot_pos_emb = stub(
            "rot_pos_emb",
            lambda _: torch.randn(16, 64),  # noqa
        )
        mocker_get_window_index = stub(
            "get_window_index",
            lambda _: (torch.arange(8), [4, 8, 12, 16]),  # noqa
        )
        mocker_cal_cos_sin = stub(
            "cal_cos_sin",
            lambda _: (torch.randn(16, 32), torch.randn(16, 32)),  # noqa
        )
        mocker_merger = stub(
            "merger",
            lambda _: torch.randn(16, 256),  # noqa
        )

        # Install the stubs and remaining configuration directly in __dict__.
        vision_transformer.__dict__.update({
            "vision_blocks": [
                lambda *args, **kwargs: torch.randn(16, 1, 512)  # noqa
            ],
            "patch_embed": mocker_patch_embed,
            "rot_pos_emb": mocker_rot_pos_emb,
            "get_window_index": mocker_get_window_index,
            "cal_cos_sin": mocker_cal_cos_sin,
            "merger": mocker_merger,
            "fullatt_block_indexes": [0, 2],
            "spatial_merge_unit": 2,
        })

        ret = vision_transformer.forward(x, grid_thw)

        assert ret.shape == (8, 256)
        mocker_patch_embed.assert_called_with(x)
        mocker_rot_pos_emb.assert_called_with(grid_thw)
        mocker_get_window_index.assert_called_with(grid_thw)
        mocker_cal_cos_sin.assert_called_once()
        mocker_merger.assert_called_once()
|
||||
|
||||
|
||||
class TestAscendQwen2_5_VLForConditionalGeneration_Without_Padding(PytestBase):
    """Tests for ``AscendQwen2_5_VLForConditionalGeneration_Without_Padding``."""

    def test_init_vl_for_conditional_generation(self, mocker: MockerFixture):
        """Build the model with both parent/visual constructors mocked out."""
        vllm_config = mocker.MagicMock()
        vllm_config.vision_config = "vision_config"
        vllm_config.rms_norm_eps = 1e-5

        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)

        mocker_vl = mocker.patch(
            "vllm.model_executor.models.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.__init__",
            return_value=None,
        )
        mocker_vit = mocker.patch(
            "vllm_ascend.models.qwen2_5_vl_without_padding.AscendQwen2_5_VisionTransformer_Without_Padding.__init__",
            return_value=None,
        )

        model = AscendQwen2_5_VLForConditionalGeneration_Without_Padding(
            vllm_config=vllm_config)

        # The parent constructor must have been called with keywords only.
        recorded = mocker_vl.call_args
        assert not recorded.args
        assert recorded.kwargs == {"vllm_config": vllm_config, "prefix": ""}
        mocker_vit.assert_called_once()
        assert isinstance(
            model,
            AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
        )

    def test_overridden_methods(self):
        """Both input-processing hooks must be redefined on the subclass."""
        for method_name in ("_process_image_input", "_process_video_input"):
            self.assert_method_overridden(
                AscendQwen2_5_VLForConditionalGeneration_Without_Padding,
                Qwen2_5_VLForConditionalGeneration,
                method_name,
            )

    @staticmethod
    def assert_method_overridden(subclass, parent, method_name: str):
        """Assert that ``subclass`` defines its own ``method_name``.

        Looks the name up in each class's own ``__dict__`` (inherited
        attributes are not considered) and requires the subclass entry to
        exist and to be a different object from the parent's entry.
        """
        inherited = parent.__dict__.get(method_name)
        own = subclass.__dict__.get(method_name)

        assert own is not None, f"{subclass.__name__} should defined {method_name}"
        assert own is not inherited, f"{method_name} should override in {subclass.__name__}"
|
||||
200
tests/ut/models/test_qwen2_vl.py
Normal file
200
tests/ut/models/test_qwen2_vl.py
Normal file
@@ -0,0 +1,200 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.models.qwen2_vl import (AscendQwen2VisionAttention,
|
||||
AscendQwen2VisionBlock)
|
||||
|
||||
|
||||
class TestAscendQwen2VisionAttention(PytestBase):
    """Tests for ``AscendQwen2VisionAttention``."""

    def init_attention(
        self,
        mocker,
        embed_dim=1000,
        num_heads=10,
        projection_size=100,
        quant_config=None,
        prefix="",
    ):
        """Build the attention layer with the parent ``__init__`` mocked.

        Also asserts that the parent constructor received the expected
        positional arguments, then returns the constructed layer.
        """
        mocker_attn = mocker.patch(
            "vllm_ascend.models.qwen2_vl.Qwen2VisionAttention.__init__")

        layer = AscendQwen2VisionAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            projection_size=projection_size,
            quant_config=quant_config,
            prefix=prefix,
        )

        recorded = mocker_attn.call_args
        assert recorded.args == (embed_dim, num_heads, projection_size, None,
                                 "")
        assert not recorded.kwargs
        layer.num_attention_heads_per_partition = num_heads
        return layer

    def test_attn_init_should_normal(self, mocker: MockerFixture):
        """A divisible projection size yields the expected per-head size."""
        vit = self.init_attention(
            mocker=mocker,
            embed_dim=1000,
            num_heads=10,
            projection_size=100,
            quant_config=None,
            prefix="",
        )
        assert vit.hidden_size_per_attention_head == 10

    def test_attn_init_should_raise_error(self, mocker: MockerFixture):
        """Construction is expected to raise ``AssertionError`` for 7 heads."""
        params = {
            "embed_dim": 1000,
            "num_heads": 7,
            "projection_size": 100,
            "quant_config": None,
            "prefix": "",
        }
        with pytest.raises(AssertionError):
            # projection_size should be divisible by num_heads
            self.init_attention(mocker=mocker, **params)

    def test_attn_forward(self, mocker: MockerFixture):
        """Run ``forward`` with projections and NPU kernels replaced by fakes."""
        attention = self.init_attention(mocker=mocker)
        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)

        x = torch.rand((100, 3, 10 * 3 * 128))  # s,b, head*3*head_dim
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))

        def fake_qkv(x):
            return (x, 0)

        def fake_split_qkv(x):
            return [torch.rand((100, 3, 10, 128)) for _ in range(3)]

        def fake_rotary_mul(q, cos, sin):
            return q

        def fake_flash_attention_unpad(**kwargs):
            return kwargs["out"]

        def fake_proj(x):
            return (x, 0)

        mocker_qkv = mocker.patch.object(attention,
                                         "qkv",
                                         side_effect=fake_qkv)
        mocker_split_qkv = mocker.patch.object(
            attention,
            "split_qkv",
            side_effect=fake_split_qkv,
        )
        mocker_npu_rotary_mul = mocker.patch("torch_npu.npu_rotary_mul",
                                             side_effect=fake_rotary_mul)
        mocker_npu_flash_attention_unpad = mocker.patch(
            "torch_npu._npu_flash_attention_unpad",
            side_effect=fake_flash_attention_unpad,
        )
        mocker_proj = mocker.patch.object(attention,
                                          "proj",
                                          side_effect=fake_proj)

        # Install the stubs directly in the instance dictionary.
        attention.__dict__.update({
            "qkv": mocker_qkv,
            "split_qkv": mocker_split_qkv,
            "npu_rotary_mul": mocker_npu_rotary_mul,
            "_npu_flash_attention_unpad": mocker_npu_flash_attention_unpad,
            "proj": mocker_proj,
        })

        output = attention.forward(
            x=x,
            cu_seqlens=cu_seqlens,
            cos=cos,
            sin=sin,
        )

        qkv_call = mocker_qkv.call_args
        assert qkv_call.args == (x, )
        assert not qkv_call.kwargs

        split_call = mocker_split_qkv.call_args
        assert split_call.args == (x, )
        assert not split_call.kwargs

        rotary_call = mocker_npu_rotary_mul.call_args
        assert rotary_call.args[1:] == (cos, sin)
        assert rotary_call.args[0].shape == torch.Size([3, 100, 10, 128])
        assert not rotary_call.kwargs

        assert output.shape == torch.Size([100, 3, 1280])
|
||||
|
||||
|
||||
class TestAscendQwen2VisionBlock(PytestBase):
    """Tests for ``AscendQwen2VisionBlock``."""

    def init_vision_block(
        self,
        mocker,
        dim=100,
        num_heads=10,
        mlp_ratio=0.5,
    ):
        """Build a vision block with parent and attention ``__init__`` mocked.

        Asserts the arguments forwarded to both mocked constructors and
        returns the constructed block.
        """
        mocker_vit = mocker.patch(
            "vllm.model_executor.models.qwen2_vl.Qwen2VisionBlock.__init__",
            return_value=None,
        )
        mocker_attn = mocker.patch(
            "vllm_ascend.models.qwen2_vl.AscendQwen2VisionAttention.__init__",
            return_value=None,
        )
        for target in (
                "torch.nn.Module.__setattr__",
                "torch.nn.Module.__getattr__",
                "torch.nn.Module.__delattr__",
        ):
            mocker.patch(target)

        vision_block = AscendQwen2VisionBlock(
            dim=dim,
            num_heads=num_heads,
            mlp_ratio=mlp_ratio,
        )

        parent_call = mocker_vit.call_args
        assert parent_call.args == (dim, num_heads, mlp_ratio, QuickGELU,
                                    None, None, "")
        assert not parent_call.kwargs

        attn_call = mocker_attn.call_args
        assert not attn_call.args
        assert attn_call.kwargs == {
            "embed_dim": dim,
            "num_heads": num_heads,
            "projection_size": dim,
            "quant_config": None,
            "prefix": ".attn",
        }
        return vision_block

    def test_init_vision_block_should_normal(
        self,
        mocker: MockerFixture,
    ):
        """The helper returns an instance of the class under test."""
        built = self.init_vision_block(mocker)
        assert isinstance(built, AscendQwen2VisionBlock)

    def test_vision_block_forward(self, mocker: MockerFixture):
        """With ``attn`` and ``mlp`` both returning ``x``, output equals 3 * x."""
        x = torch.randint(1, 100, (100, 3, 1280))  # s,b,d
        cu_seqlens = torch.tensor([10, 50, 100])
        cos = torch.rand((1, 100, 1, 128))
        sin = torch.rand((1, 100, 1, 128))

        vision_block = self.init_vision_block(mocker)
        mocker_attn = mocker.patch.object(vision_block,
                                          "attn",
                                          return_value=x)
        mocker_mlp = mocker.patch.object(vision_block, "mlp", return_value=x)
        vision_block.__dict__.update({"attn": mocker_attn, "mlp": mocker_mlp})

        output = vision_block.forward(x.clone(), cu_seqlens, cos, sin)

        attn_kwargs = mocker_attn.call_args.kwargs
        assert attn_kwargs == {
            "cu_seqlens": cu_seqlens,
            "cos": cos,
            "sin": sin,
        }
        assert torch.all(x * 3 == output)
|
||||
98
tests/ut/models/test_qwen3_moe.py
Normal file
98
tests/ut/models/test_qwen3_moe.py
Normal file
@@ -0,0 +1,98 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import math
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||
|
||||
from vllm_ascend.models.qwen3_moe import CustomQwen3MoeForCausalLM
|
||||
from vllm_ascend.torchair.models.qwen3_moe import CustomQwen3MoeAttention
|
||||
|
||||
|
||||
class TestCustomQwen3MoeForCausalLM:
    """Static checks on ``CustomQwen3MoeForCausalLM`` class attributes."""

    def test_class_inheritance(self):
        """The custom model must derive from ``Qwen3MoeForCausalLM``."""
        assert issubclass(CustomQwen3MoeForCausalLM, Qwen3MoeForCausalLM)

    @pytest.mark.parametrize(
        "key, expected",
        [
            ("qkv_proj", ["q_proj", "k_proj", "v_proj"]),
            ("gate_up_proj", ["gate_proj", "up_proj"]),
            (
                "experts",
                [
                    "experts.0.gate_proj",
                    "experts.0.up_proj",
                    "experts.0.down_proj",
                ],
            ),
        ],
    )
    def test_packed_modules_mapping(self, key, expected):
        """Each packed module key maps to the expected sub-module names."""
        mapping = CustomQwen3MoeForCausalLM.packed_modules_mapping
        assert mapping[key] == expected

    def test_packed_modules_mapping_structure(self):
        """The whole mapping equals the expected dictionary, nothing extra."""
        expected_mapping = dict(
            qkv_proj=["q_proj", "k_proj", "v_proj"],
            gate_up_proj=["gate_proj", "up_proj"],
            experts=[
                "experts.0.gate_proj",
                "experts.0.up_proj",
                "experts.0.down_proj",
            ],
        )
        actual_mapping = CustomQwen3MoeForCausalLM.packed_modules_mapping
        assert actual_mapping == expected_mapping
|
||||
|
||||
|
||||
class DummyRMSNorm:
    """Weight-free RMS normalisation stand-in used by the tests below.

    Calling an instance divides the input by the square root of
    ``mean(x ** 2) + eps`` taken over the last dimension.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        # ``dim`` is only stored; ``__call__`` does not read it.
        self.dim, self.eps = dim, eps

    def __call__(self, x):
        """Return ``x`` normalised by its root-mean-square over the last axis."""
        squared = x.pow(2)
        rms = (squared.mean(dim=-1, keepdim=True) + self.eps).sqrt()
        return x / rms
|
||||
|
||||
|
||||
class TestCustomQwen3MoeAttention(unittest.TestCase):
    """Tests for ``CustomQwen3MoeAttention.normalize_qkv``."""

    def setUp(self):
        """Define the sizes shared by the tests and a ramp-valued qkv tensor."""
        self.batch, self.seq_len = 2, 3
        self.q_size = 8
        self.kv_size = 8
        self.head_dim = 4
        self.rms_eps = 1e-6

        # One query slice plus key and value slices along the last dimension.
        total_dim = self.q_size + 2 * self.kv_size
        element_count = self.batch * self.seq_len * total_dim
        self.qkv = torch.arange(element_count, dtype=torch.float32).reshape(
            self.batch, self.seq_len, total_dim)

    def test_constant_input_normalization(self):
        """All-ones input: q/k become 1/sqrt(1 + eps); v is left untouched."""
        packed_width = self.q_size + 2 * self.kv_size
        ones_qkv = torch.ones((1, 1, packed_width), dtype=torch.float32)

        q_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
        k_norm = DummyRMSNorm(self.head_dim, self.rms_eps)
        q, k, v = CustomQwen3MoeAttention.normalize_qkv(
            ones_qkv, self.q_size, self.kv_size, self.head_dim, q_norm,
            k_norm)

        norm_val = 1.0 / math.sqrt(1.0 + self.rms_eps)
        expected_q = torch.full((1, 1, self.q_size), norm_val)
        expected_k = torch.full((1, 1, self.kv_size), norm_val)
        expected_v = torch.ones((1, 1, self.kv_size), dtype=torch.float32)

        self.assertTrue(torch.allclose(q, expected_q, atol=1e-6))
        self.assertTrue(torch.allclose(k, expected_k, atol=1e-6))
        self.assertTrue(torch.equal(v, expected_v))
|
||||
32
tests/ut/multistream/test_base.py
Normal file
32
tests/ut/multistream/test_base.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.multistream.base import (MSAttentionMetadataSplitConfig,
|
||||
MSEventKey)
|
||||
|
||||
|
||||
class Testbase(TestBase):
    """Tests for ``MSEventKey`` and ``MSAttentionMetadataSplitConfig``."""

    def test_ms_event_key(self):
        """Every ``MSEventKey`` member carries its expected integer value."""
        expected_values = (
            (MSEventKey.ATTN_COM_FINISH, 0),
            (MSEventKey.ATTN_AR_FINISH, 1),
            (MSEventKey.FFN_COM_FINISH, 2),
            (MSEventKey.FFN_AR_FINISH, 3),
            (MSEventKey.MOE_BEFORE_COMM, 4),
            (MSEventKey.MOE_AFTER_COMM, 5),
            (MSEventKey.MOE_SE_COMM_FINISH, 6),
            (MSEventKey.MOE_SE_COMP_FINISH, 7),
            (MSEventKey.MOE_GATE_FINISH, 8),
        )
        for member, value in expected_values:
            self.assertEqual(member.value, value)

    def test_ms_attention_metadata_split_config_default(self):
        """A config built without arguments exposes the default thresholds."""
        default_config = MSAttentionMetadataSplitConfig()
        self.assertEqual(default_config.num_micro_batches, 2)
        self.assertEqual(default_config.min_total_tokens_to_split, 256)
        self.assertEqual(default_config.min_prefill_tokens_to_split, 64)

    def test_ms_attention_metadata_split_config_custom(self):
        """Values passed to the constructor are stored unchanged."""
        overrides = {
            "num_micro_batches": 4,
            "min_total_tokens_to_split": 512,
            "min_prefill_tokens_to_split": 128,
        }
        custom_config = MSAttentionMetadataSplitConfig(**overrides)
        self.assertEqual(custom_config.num_micro_batches, 4)
        self.assertEqual(custom_config.min_total_tokens_to_split, 512)
        self.assertEqual(custom_config.min_prefill_tokens_to_split, 128)
|
||||
47
tests/ut/multistream/test_decorator.py
Normal file
47
tests/ut/multistream/test_decorator.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import pytest
|
||||
from pytest_mock import MockFixture
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.multistream.decorator import set_multistream_support
|
||||
|
||||
|
||||
class Context:
    """Minimal stand-in object that only carries ``attn_metadata``."""

    def __init__(self, attn_metadata=None):
        # Stored as given; no copy or validation is performed.
        setattr(self, "attn_metadata", attn_metadata)
|
||||
|
||||
|
||||
class TestDecorator(PytestBase):
    """Tests for the ``set_multistream_support`` decorator."""

    @pytest.mark.parametrize(
        'layer_context, microbatch_context, expected_metadata',
        [
            ((-1, None, None), -1, {"original": True}),
            ((-1, None, None), 0, {"original": True}),
            ((0, None, None), -1, {"original": True}),
            ((0, None, [{"new": True}]), 0, {"new": True}),
        ],
    )
    def test_decorator(self, mocker: MockFixture, layer_context,
                       microbatch_context, expected_metadata):
        """Check which ``attn_metadata`` the decorated function ends up with.

        Both multistream context getters are patched to return the
        parametrized values before the wrapped function is decorated and run.
        """

        def context_func():
            return Context(attn_metadata={"original": True})

        mocker.patch(
            'vllm_ascend.multistream.decorator.get_multistream_layer_context',
            return_value=layer_context)
        mocker.patch(
            'vllm_ascend.multistream.decorator.get_multistream_microbatch_context',
            return_value=microbatch_context)

        # Decorate only after the patches are installed, then invoke.
        wrapped = set_multistream_support()(context_func)
        context = wrapped()

        assert context.attn_metadata == expected_metadata
|
||||
198
tests/ut/multistream/test_layers.py
Normal file
198
tests/ut/multistream/test_layers.py
Normal file
@@ -0,0 +1,198 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
|
||||
MultiStreamPreTransformerLayer)
|
||||
from vllm_ascend.multistream.metadata import MultiStreamMetadata
|
||||
|
||||
|
||||
# === fixture: mock tensor input ===
|
||||
@pytest.fixture
def input_tensors():
    """Return two random ``(2, 128)`` tensors used as layer input."""
    return [torch.randn(2, 128), torch.randn(2, 128)]
|
||||
|
||||
|
||||
# === mock get_forward_context ===
|
||||
class DummyContext:
    """Stand-in for the object returned by the patched ``get_forward_context``."""

    def __init__(self, attn_metadata):
        # Only this attribute is set; the tests pass it through the layer.
        self.attn_metadata = attn_metadata
|
||||
|
||||
|
||||
class TestMultiStreamPreTransformerLayer(PytestBase):
    """Tests for ``MultiStreamPreTransformerLayer.forward``.

    In every test ``get_forward_context`` and
    ``set_multistream_layer_context`` are patched inside
    ``vllm_ascend.multistream.layers``.

    NOTE: ``input_out == input_tensors`` compares lists of tensors. List
    equality short-circuits on element identity, so these assertions rely on
    the layer handing back the very same tensor objects.
    """

    # === test when multistream_metadata is None ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_no_multistream_metadata(self, mock_set_ctx, mock_get_ctx,
                                             input_tensors):
        """Without multistream metadata the inputs pass through unchanged."""
        mock_get_ctx.return_value = DummyContext(attn_metadata="dummy_meta")
        layer = MultiStreamPreTransformerLayer(multistream_metadata=None)
        attn_out, input_out = layer.forward(input_tensors)

        assert attn_out == "dummy_meta"
        assert input_out == input_tensors
        mock_set_ctx.assert_called_once_with(-1, None, None)

    # === test when attn_metadata is None ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_no_attn_metadata(self, mock_set_ctx, mock_get_ctx,
                                      input_tensors):
        """With no attention metadata in the context nothing is split."""
        mock_get_ctx.return_value = DummyContext(attn_metadata=None)
        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        layer = MultiStreamPreTransformerLayer(
            multistream_metadata=dummy_metadata)

        attn_out, input_out = layer.forward(input_tensors)

        assert attn_out is None
        assert input_out == input_tensors
        mock_set_ctx.assert_called_once_with(-1, None, None)

    # === test when do_ms=False (no split needed) ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_no_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
        """When ``split_micro_batch`` reports no split, its outputs are returned."""
        dummy_attn = "original_attn"
        mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)

        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        # First tuple element is the "split happened" flag.
        dummy_metadata.split_micro_batch.return_value = (False, "same_attn",
                                                         input_tensors, None)

        layer = MultiStreamPreTransformerLayer(
            multistream_metadata=dummy_metadata)

        attn_out, input_out = layer.forward(input_tensors)

        assert attn_out == "same_attn"
        assert input_out == input_tensors
        mock_set_ctx.assert_called_once_with(-1, None, None)

    # === test when do_ms=True (split occurred) ===
    @patch("vllm_ascend.multistream.layers.get_forward_context")
    @patch("vllm_ascend.multistream.layers.set_multistream_layer_context")
    def test_forward_split(self, mock_set_ctx, mock_get_ctx, input_tensors):
        """When a split happens, the layer context is set from ``start_layer``."""
        dummy_attn = "original_attn"
        mock_get_ctx.return_value = DummyContext(attn_metadata=dummy_attn)

        # Each input tensor is cut into its first row and the remaining rows.
        split_inputs = [[t[:1], t[1:]] for t in input_tensors]

        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        dummy_metadata.start_layer = 2
        dummy_metadata.split_micro_batch.return_value = (True,
                                                         ["attn1", "attn2"],
                                                         split_inputs, None)

        layer = MultiStreamPreTransformerLayer(
            multistream_metadata=dummy_metadata)

        attn_out, input_out = layer.forward(input_tensors)

        assert attn_out == ["attn1", "attn2"]
        assert input_out == split_inputs
        mock_set_ctx.assert_called_once_with(2, dummy_metadata,
                                             ["attn1", "attn2"])
|
||||
|
||||
|
||||
class TestMultiStreamPostTransformerLayer(PytestBase):
    """Tests for ``MultiStreamPostTransformerLayer.forward``."""

    @staticmethod
    def _make_metadata_mock():
        """Build the ``MultiStreamMetadata`` mock shared by the flow tests.

        The mock is spec'd from a real instance rather than from the class,
        presumably so that attributes assigned in ``__init__`` (such as
        ``ms_config``) are accepted by the spec -- TODO confirm.
        """
        real_metadata = MultiStreamMetadata(
            calculate_stream=MagicMock(),
            communicate_stream=MagicMock(),
            start_layer=0,
            end_layer=1,
            event_keys=[],
            multistream_config=None,
        )
        dummy_metadata = MagicMock(spec=real_metadata)
        dummy_metadata.ms_config.num_micro_batches = 4
        dummy_metadata.end_layer = 10
        dummy_metadata.merge_micro_batches.return_value = "merged_result"
        return dummy_metadata

    def test_post_forward_metadata_none(self, input_tensors):
        """Inputs are returned as-is when metadata or its config is missing."""
        layer = MultiStreamPostTransformerLayer(multistream_metadata=None)
        output = layer.forward(input_tensors)
        assert output == input_tensors

        dummy_metadata = MagicMock(spec=MultiStreamMetadata)
        dummy_metadata.ms_config = None
        layer = MultiStreamPostTransformerLayer(
            multistream_metadata=dummy_metadata)
        output = layer.forward(input_tensors)
        assert output == input_tensors

    @patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
    @patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
    def test_post_forward_normal_flow(self, mock_reset_ctx, mock_get_ctx,
                                      input_tensors):
        """Default flow: wait on the last layer / last micro-batch, then merge."""
        dummy_metadata = self._make_metadata_mock()

        mock_get_ctx.return_value = (
            5,  # layer_index
            dummy_metadata,  # ms_metadata
            "dummy_attn_metadata"  # ms_attn_metadata
        )

        layer = MultiStreamPostTransformerLayer(
            multistream_metadata=dummy_metadata)
        output = layer.forward(input_tensors)

        # check wait_event
        dummy_metadata.try_wait_event.assert_called_once_with(
            9,  # end_layer - 1
            3,  # num_micro_batches - 1
            MSEventKey.FFN_AR_FINISH)
        mock_reset_ctx.assert_called_once()
        assert output == "merged_result"

    @patch("vllm_ascend.multistream.layers.get_multistream_layer_context")
    @patch("vllm_ascend.multistream.layers.reset_multistream_layer_context")
    def test_post_forward_with_custom_wait_layer(self, mock_reset_ctx,
                                                 mock_get_ctx, input_tensors):
        """An explicit ``wait_layer_index`` replaces the default layer to wait on."""
        dummy_metadata = self._make_metadata_mock()

        mock_get_ctx.return_value = (
            3,  # layer_index
            dummy_metadata,
            "dummy_attn_metadata")

        layer = MultiStreamPostTransformerLayer(
            multistream_metadata=dummy_metadata)
        output = layer.forward(input_tensors, wait_layer_index=7)

        dummy_metadata.try_wait_event.assert_called_once_with(
            7, 3, MSEventKey.FFN_AR_FINISH)
        mock_reset_ctx.assert_called_once()
        assert output == "merged_result"
|
||||
246
tests/ut/multistream/test_metadata.py
Normal file
246
tests/ut/multistream/test_metadata.py
Normal file
@@ -0,0 +1,246 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
||||
MultiStreamMetadata,
|
||||
MultiStreamStepMetadata,
|
||||
split_micro_batches_tensors)
|
||||
|
||||
|
||||
class TestMetaData(TestBase):
    """Unit tests for the multistream metadata containers and helpers."""

    @staticmethod
    def _make_metadata(stream, event_keys, multistream_config):
        """Build a ``MultiStreamMetadata`` covering layers 1..3.

        The same mock stream is used for both the calculate and the
        communicate stream, as every test in this class does.
        """
        return MultiStreamMetadata(calculate_stream=stream,
                                   communicate_stream=stream,
                                   start_layer=1,
                                   end_layer=3,
                                   event_keys=event_keys,
                                   multistream_config=multistream_config)

    def setUp(self):
        """Create sample tensors and a shared metadata object."""
        self.test_tensors_list = [torch.randn(100, 1024) for _ in range(3)]
        self.test_tensors = torch.randn(100, 1024)
        self.test_tensors_dict = {
            'query': torch.randn(100, 1024),
            'key': torch.randn(100, 1024),
            'value': torch.randn(100, 1024)
        }
        self.split_index = 50

        mock_stream = MagicMock(spec=torch.npu.Stream)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)

        self.metadata = self._make_metadata(mock_stream, event_keys,
                                            multistream_config)

    def test_split_micro_batches_tensors(self):
        """Lists, single tensors and dicts are all split at ``split_index``."""
        test_tensors_list_res = split_micro_batches_tensors(
            self.test_tensors_list, self.split_index)
        test_tensors_res = split_micro_batches_tensors(self.test_tensors,
                                                       self.split_index)
        keys = ['query', 'key', 'value']
        test_tensors_dict_res = split_micro_batches_tensors(
            self.test_tensors_dict, self.split_index, keys)

        # List input: one (first, second) pair per input tensor.
        for i in range(3):
            self.assertEqual(len(test_tensors_list_res[i][0]),
                             self.split_index)
            self.assertEqual(
                len(test_tensors_list_res[i][0]) +
                len(test_tensors_list_res[i][1]), 100)

        # Single tensor input: a (first, second) pair.
        self.assertEqual(len(test_tensors_res[0]), self.split_index)
        self.assertEqual(
            len(test_tensors_res[0]) + len(test_tensors_res[1]), 100)

        # Dict input: one dict per micro-batch, indexed by the given keys.
        for key in keys:
            self.assertEqual(len(test_tensors_dict_res[0][key]),
                             self.split_index)
            self.assertEqual(
                len(test_tensors_dict_res[0][key]) +
                len(test_tensors_dict_res[1][key]), 100)

    def test_default_init_multistream_step_metadata(self):
        """All step-metadata fields default to ``None``."""
        metadata = MultiStreamStepMetadata()
        self.assertIsNone(metadata.comm_stream)
        self.assertIsNone(metadata.before_comm_event)
        self.assertIsNone(metadata.after_comm_event)

    def test_custom_init_multistream_step_metadata(self):
        """Positional arguments map to stream, before-event, after-event."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event_before = MagicMock(spec=torch.npu.Event)
        mock_event_after = MagicMock(spec=torch.npu.Event)

        metadata = MultiStreamStepMetadata(mock_stream, mock_event_before,
                                           mock_event_after)
        self.assertEqual(metadata.comm_stream, mock_stream)
        self.assertEqual(metadata.before_comm_event, mock_event_before)
        self.assertEqual(metadata.after_comm_event, mock_event_after)

    def test_default_init_multistream_config(self):
        """A config built without arguments exposes the expected defaults."""
        config = MultiStreamConfig()
        self.assertEqual(config.min_total_tokens_to_split, 256)
        self.assertEqual(config.min_prefill_tokens_to_split, 64)
        self.assertEqual(config.num_micro_batches, 2)
        self.assertEqual(config.imbalance_ratio, 0.1)

    def test_custom_init_multistream_config(self):
        """Positional arguments are stored in declaration order."""
        config = MultiStreamConfig(512, 128, 1, 0.2)
        self.assertEqual(config.min_total_tokens_to_split, 512)
        self.assertEqual(config.min_prefill_tokens_to_split, 128)
        self.assertEqual(config.num_micro_batches, 1)
        self.assertEqual(config.imbalance_ratio, 0.2)

    def test_init_multistream_metadata(self):
        """Constructor arguments are exposed as attributes."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        event_keys = [MagicMock()]
        multistream_config = MagicMock(spec=MultiStreamConfig)

        metadata = self._make_metadata(mock_stream, event_keys,
                                       multistream_config)

        self.assertEqual(metadata.calculate_stream, mock_stream)
        self.assertEqual(metadata.communicate_stream, mock_stream)
        self.assertEqual(metadata.start_layer, 1)
        self.assertEqual(metadata.end_layer, 3)
        self.assertEqual(metadata.ms_config, multistream_config)
        self.assertTrue(metadata.causal_lm)

    def test_build_events(self):
        """``ms_events`` is nested as layer -> micro-batch -> event key."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event = MagicMock(spec=torch.npu.Event)
        # Every event created during construction is the same mock object.
        with patch('torch.npu.Event', return_value=mock_event):
            event_keys = [MagicMock(spec=MSEventKey)]
            multistream_config = MultiStreamConfig(
                num_micro_batches=2,
                min_total_tokens_to_split=256,
                min_prefill_tokens_to_split=64)

            metadata = self._make_metadata(mock_stream, event_keys,
                                           multistream_config)

            # Three outer entries, each holding two micro-batch entries.
            expected_events = {
                layer: {
                    micro_batch: {
                        event_keys[0]: mock_event
                    }
                    for micro_batch in range(2)
                }
                for layer in range(3)
            }
            self.assertEqual(metadata.ms_events, expected_events)

    def test_build_ms_split_config(self):
        """The split config mirrors the values of the multistream config."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        multistream_config.num_micro_batches = 2
        multistream_config.min_total_tokens_to_split = 256
        multistream_config.min_prefill_tokens_to_split = 64

        metadata = self._make_metadata(mock_stream, event_keys,
                                       multistream_config)

        self.assertIsNotNone(metadata.ms_split_config)
        self.assertEqual(metadata.ms_split_config.num_micro_batches,
                         multistream_config.num_micro_batches)
        self.assertEqual(metadata.ms_split_config.min_total_tokens_to_split,
                         multistream_config.min_total_tokens_to_split)
        self.assertEqual(metadata.ms_split_config.min_prefill_tokens_to_split,
                         multistream_config.min_prefill_tokens_to_split)

    def test_try_wait_event(self):
        """``try_wait_event`` waits on the stored event exactly once."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event = MagicMock(spec=torch.npu.Event)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        with patch('torch.npu.Event', return_value=mock_event):
            metadata = self._make_metadata(mock_stream, event_keys,
                                           multistream_config)

            metadata.try_wait_event(layer_index=1,
                                    micro_batch_index=0,
                                    event_key=event_keys[0])
            mock_event.wait.assert_called_once()

    def test_try_record_event(self):
        """``try_record_event`` records the stored event exactly once."""
        mock_stream = MagicMock(spec=torch.npu.Stream)
        mock_event = MagicMock(spec=torch.npu.Event)
        event_keys = [MagicMock(spec=MSEventKey)]
        multistream_config = MagicMock(spec=MultiStreamConfig)
        with patch('torch.npu.Event', return_value=mock_event):
            metadata = self._make_metadata(mock_stream, event_keys,
                                           multistream_config)

            metadata.try_record_event(layer_index=1,
                                      micro_batch_index=0,
                                      event_key=event_keys[0])
            mock_event.record.assert_called_once()

    def test_merge_batches_none_input(self):
        """``None`` in, ``None`` out."""
        input_tensors = None
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertIsNone(result)

    def test_merge_batches_single_tensor_input(self):
        """A one-element list of tensors comes back with the same content."""
        input_tensors = [torch.tensor([1, 2, 3])]
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertEqual(len(result), 1)
        self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3])))

    def test_merge_batches_list_of_tensors_input(self):
        """A flat list of tensors is returned without being merged."""
        input_tensors = [torch.tensor([1, 2]), torch.tensor([3, 4])]
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertEqual(len(result), 2)
        # Compare element-wise with torch.equal: comparing the lists directly
        # only works when the very same tensor objects are returned, because
        # the truth value of a multi-element tensor comparison is ambiguous.
        for merged, original in zip(result, input_tensors):
            self.assertTrue(torch.equal(merged, original))

    def test_merge_batches_nested_list_input(self):
        """Each inner list of micro-batch tensors is concatenated into one."""
        input_tensors = [[torch.tensor([1, 2]),
                          torch.tensor([3, 4])],
                         [torch.tensor([5, 6]),
                          torch.tensor([7, 8])]]
        result = self.metadata.merge_micro_batches(input_tensors)
        self.assertEqual(len(result), 2)
        self.assertTrue(torch.equal(result[0], torch.tensor([1, 2, 3, 4])))
        self.assertTrue(torch.equal(result[1], torch.tensor([5, 6, 7, 8])))
|
||||
147
tests/ut/multistream/test_ms_split.py
Normal file
147
tests/ut/multistream/test_ms_split.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||
from vllm_ascend.multistream.ms_split import (compute_split_seq_index,
|
||||
model_input_split_v1_mla_attn,
|
||||
split_attn_int_type,
|
||||
split_attn_tensor_type)
|
||||
|
||||
|
||||
class TestMsSplit(TestBase):
    """Unit tests for the helpers in ``vllm_ascend.multistream.ms_split``."""

    def _assert_tensor_split(self, input_tensor, index, expected_result):
        """Split ``input_tensor`` at ``index`` and compare both halves."""
        result = split_attn_tensor_type(input_tensor, index)
        self.assertEqual(len(result), 2)
        self.assertTrue(torch.equal(result[0], expected_result[0]))
        self.assertTrue(torch.equal(result[1], expected_result[1]))

    def _assert_int_split(self, var, index, expected_result):
        """Split the integer ``var`` at ``index`` and compare the pair."""
        result = split_attn_int_type(var, index)
        self.assertEqual(result, expected_result)

    # --- compute_split_seq_index -------------------------------------

    def test_decode_only(self):
        """Decode-only batches are split in the middle of the token count."""
        result = compute_split_seq_index(
            query_lens=None,
            attn_state=AscendAttentionState.DecodeOnly,
            num_tokens=10)
        self.assertEqual(result, [5, 5])

    def test_perfect_balance(self):
        """Query lengths [2, 3, 5] give token index 5 and sequence index 2."""
        query_lens = [2, 3, 5]
        result = compute_split_seq_index(
            query_lens=query_lens,
            attn_state=AscendAttentionState.PrefillNoCache,
            num_tokens=10)
        self.assertEqual(result, [5, 2])

    def test_imbalance(self):
        """Query lengths [1, 2, 3, 4] are expected to yield no split."""
        query_lens = [1, 2, 3, 4]
        result = compute_split_seq_index(
            query_lens=query_lens,
            attn_state=AscendAttentionState.PrefillNoCache,
            num_tokens=10)
        self.assertEqual(result, [0, 0])

    def test_query_lens_none(self):
        """Outside decode-only, ``query_lens=None`` trips an assertion."""
        with self.assertRaises(AssertionError):
            compute_split_seq_index(
                query_lens=None,
                attn_state=AscendAttentionState.PrefillNoCache,
                num_tokens=10)

    def test_empty_query_lens(self):
        """An empty ``query_lens`` yields no split."""
        query_lens: list[int] = []
        result = compute_split_seq_index(
            query_lens=query_lens,
            attn_state=AscendAttentionState.PrefillNoCache,
            num_tokens=10)
        self.assertEqual(result, [0, 0])

    def test_single_query_len(self):
        """A single sequence cannot be split."""
        query_lens = [10]
        result = compute_split_seq_index(
            query_lens=query_lens,
            attn_state=AscendAttentionState.PrefillNoCache,
            num_tokens=10)
        self.assertEqual(result, [0, 0])

    # --- split_attn_tensor_type --------------------------------------

    def test_split_attn_tensor_type_middle(self):
        """Splitting in the middle yields a prefix and a suffix."""
        self._assert_tensor_split(
            torch.tensor([1, 2, 3, 4, 5]), 3,
            [torch.tensor([1, 2, 3]),
             torch.tensor([4, 5])])

    def test_split_attn_tensor_type_start(self):
        """Index 0 yields an empty first part."""
        self._assert_tensor_split(
            torch.tensor([1, 2, 3, 4, 5]), 0,
            [torch.tensor([]),
             torch.tensor([1, 2, 3, 4, 5])])

    def test_split_attn_tensor_type_end(self):
        """An index equal to the length yields an empty second part."""
        self._assert_tensor_split(
            torch.tensor([1, 2, 3, 4, 5]), 5,
            [torch.tensor([1, 2, 3, 4, 5]),
             torch.tensor([])])

    def test_split_attn_tensor_type_empty_tensor(self):
        """An empty tensor splits into two empty tensors."""
        self._assert_tensor_split(torch.tensor([]), 0,
                                  [torch.tensor([]),
                                   torch.tensor([])])

    # --- split_attn_int_type -----------------------------------------

    def test_split_attn_int_type_index_greater_than_var(self):
        """An index above ``var`` is expected to give ``[var, 0]``."""
        self._assert_int_split(5, 10, [5, 0])

    def test_split_attn_int_type_index_equal_to_var(self):
        """An index equal to ``var`` is expected to give ``[var, 0]``."""
        self._assert_int_split(5, 5, [5, 0])

    def test_split_attn_int_type_index_less_than_var(self):
        """An index below ``var`` gives ``[index, var - index]``."""
        self._assert_int_split(10, 5, [5, 5])

    def test_split_attn_int_type_index_zero(self):
        """Index 0 puts everything in the second part."""
        self._assert_int_split(10, 0, [0, 10])

    def test_split_attn_int_type_var_zero(self):
        """A zero ``var`` gives ``[0, 0]`` for a positive index."""
        self._assert_int_split(0, 5, [0, 0])

    def test_split_attn_int_type_both_zero(self):
        """Zero ``var`` and zero index give ``[0, 0]``."""
        self._assert_int_split(0, 0, [0, 0])

    # --- model_input_split_v1_mla_attn -------------------------------

    def test_split_v1_mla_attn_input_none(self):
        """``None`` attention metadata is returned wrapped in a list."""
        attn_metadata = None
        ascend_mla_prefill_metadata = MagicMock()
        ms_split_config = MSAttentionMetadataSplitConfig(num_micro_batches=1)
        result = model_input_split_v1_mla_attn(attn_metadata,
                                               ascend_mla_prefill_metadata,
                                               ms_split_config)
        self.assertEqual(result, [None])
|
||||
17
tests/ut/ops/expert_map.json
Normal file
17
tests/ut/ops/expert_map.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
    "moe_layer_count": 1,
    "layer_list": [{
        "layer_id": 0,
        "device_count": 2,
        "device_list": [{
            "device_id": 0,
            "device_expert": [7, 2, 0, 3, 5]
        }, {
            "device_id": 1,
            "device_expert": [6, 1, 4, 7, 2]
        }]
    }]
}
|
||||
61
tests/ut/ops/test_activation.py
Normal file
61
tests/ut/ops/test_activation.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.activation import QuickGELU, SiluAndMul
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_tensor():
    """Return a random ``(4, 8)`` float16 tensor used as activation input."""
    return torch.randn(4, 8, dtype=torch.float16)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_fast_gelu", side_effect=lambda x: x + 1)
def test_QuickGELU_forward(mock_gelu, dummy_tensor):
    """``QuickGELU.forward`` must return what ``npu_fast_gelu`` produces.

    The NPU kernel is replaced by ``x + 1`` so its use can be recognised
    from the output.
    """
    layer = QuickGELU()
    out = layer.forward(dummy_tensor)

    expected_out = dummy_tensor + 1
    assert torch.allclose(out, expected_out)

    mock_gelu.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
@patch("torch_npu.npu_swiglu", side_effect=lambda x: x + 1)
def test_SiluAndMul_forward(mock_swiglu, is_310p_return, dummy_tensor):
    """``SiluAndMul.forward`` must go through ``npu_swiglu`` exactly once.

    ``is_310p_return`` is the value the patched ``vllm_ascend.utils.is_310p``
    reports; when it is true the kernel is expected to receive a float32
    copy of the input, otherwise the input itself.
    """
    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = SiluAndMul()
        out = layer.forward(dummy_tensor)

        if is_310p_return:
            expected_arg = dummy_tensor.to(torch.float32)
        else:
            expected_arg = dummy_tensor

        mock_swiglu.assert_called_once()

        # First positional argument of the single recorded call.
        actual_arg = mock_swiglu.call_args[0][0]
        assert torch.allclose(
            actual_arg,
            expected_arg), "npu_swiglu called with unexpected input"

        # The mocked kernel computes ``x + 1``.
        expected_out = dummy_tensor + 1
        assert torch.allclose(out, expected_out)
|
||||
69
tests/ut/ops/test_common_fused_moe.py
Normal file
69
tests/ut/ops/test_common_fused_moe.py
Normal file
@@ -0,0 +1,69 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.common_fused_moe import fused_experts_moge
|
||||
|
||||
|
||||
class TestFusedExpertsMoGE(TestBase):
    """Shape smoke test for ``fused_experts_moge`` with mocked NPU kernels."""

    def test_fused_experts_moge(self):
        """The output keeps the shape of ``hidden_states``."""
        with patch('torch_npu.npu_grouped_matmul') as mock_grouped_matmul, \
             patch('torch_npu.npu_swiglu') as mock_swiglu, \
             patch('vllm_ascend.utils.is_310p') as mock_is_310p:

            mock_is_310p.return_value = False

            # Stand-in for the grouped matmul: a one-element list holding a
            # random tensor shaped from the first input and first weight.
            mock_grouped_matmul.side_effect = lambda x, weight, **kwargs: [
                torch.randn(x[0].shape[0], weight[0].shape[1])
            ]

            # The activation is replaced by the identity.
            mock_swiglu.side_effect = lambda x: x

            hidden_states = torch.randn(4, 128)
            w1 = torch.randn(4, 256, 128)
            w2 = torch.randn(4, 128, 128)
            topk_weights = torch.rand(4, 1)
            topk_ids = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)
            top_k = 1
            global_num_experts = 4

            # Ad-hoc object exposing only the parallel-config attributes
            # set here.
            moe_parallel_config = type(
                'MockConfig', (), {
                    'ep_size': 1,
                    'tp_size': 1,
                    'dp_size': 1,
                    'tp_rank': 0,
                    'dp_rank': 0,
                    'ep_rank': 0,
                    'use_ep': True
                })()

            output = fused_experts_moge(
                hidden_states=hidden_states,
                w1=w1,
                w2=w2,
                moe_parallel_config=moe_parallel_config,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                global_num_experts=global_num_experts,
                apply_router_weight_on_input=True,
            )

            self.assertEqual(output.shape, (4, 128))
|
||||
141
tests/ut/ops/test_expert_load_balancer.py
Normal file
141
tests/ut/ops/test_expert_load_balancer.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import List, TypedDict
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
|
||||
|
||||
class Device(TypedDict):
    """One entry of ``device_list`` in the expert-map JSON fixture."""

    # Identifier of the device within its layer.
    device_id: int
    # Expert ids listed for this device.
    device_expert: List[int]
|
||||
|
||||
|
||||
class Layer(TypedDict):
    """One entry of ``layer_list`` in the expert-map JSON fixture."""

    # Identifier of the layer.
    layer_id: int
    # Number of devices declared for this layer.
    device_count: int
    # Per-device expert lists.
    device_list: List[Device]
|
||||
|
||||
|
||||
class MockData(TypedDict):
    """Top-level structure of the expert-map JSON fixture."""

    # Number of layers declared in the file.
    moe_layer_count: int
    # Per-layer placement descriptions.
    layer_list: List[Layer]
|
||||
|
||||
|
||||
class TestExpertLoadBalancer(TestBase):
|
||||
|
||||
def setUp(self):
    """Load the JSON fixture and build the balancer under test."""
    # Join instead of concatenating with "/": if the directory part of
    # __file__ is empty, plain concatenation would point at the filesystem
    # root instead of the current directory.
    json_file = os.path.join(os.path.dirname(__file__), "expert_map.json")
    with open(json_file, 'r', encoding="utf-8") as f:
        self.expert_map: MockData = json.load(f)

    self.expert_load_balancer = ExpertLoadBalancer(json_file,
                                                   global_expert_num=8)
|
||||
|
||||
def test_init(self):
    """The balancer exposes a tensor plus layer/rank counts from the JSON."""
    self.assertIsInstance(self.expert_load_balancer.expert_map_tensor,
                          torch.Tensor)
    self.assertEqual(self.expert_load_balancer.layers_num,
                     self.expert_map["moe_layer_count"])
    self.assertEqual(self.expert_load_balancer.ranks_num,
                     self.expert_map["layer_list"][0]["device_count"])
|
||||
|
||||
def test_generate_index_dicts(self):
    """Each row maps its values to a running index across all rows."""
    tensor_2d = torch.tensor([[7, 2, 0, 3, 5], [6, 1, 4, 7, 2]])
    result = self.expert_load_balancer.generate_index_dicts(tensor_2d)
    # The second row continues counting at 5, where the first row stopped.
    expected_result = [{
        7: 0,
        2: 1,
        0: 2,
        3: 3,
        5: 4
    }, {
        6: 5,
        1: 6,
        4: 7,
        7: 8,
        2: 9
    }]
    self.assertEqual(result, expected_result)
|
||||
|
||||
def test_generate_expert_placement_map(self):
    """The placement map is (layers, ranks, 8) with no value below -1."""
    expert_placement_map = (
        self.expert_load_balancer.generate_expert_placement_map())
    self.assertEqual(expert_placement_map.shape,
                     (self.expert_load_balancer.layers_num,
                      self.expert_load_balancer.ranks_num, 8))
    self.assertTrue(torch.all(expert_placement_map >= -1))
|
||||
|
||||
def test_generate_log2phy_expert_map(self):
|
||||
layer_id = 0
|
||||
log2phy_map = self.expert_load_balancer.generate_log2phy_expert_map(
|
||||
layer_id)
|
||||
self.assertEqual(log2phy_map.shape,
|
||||
(self.expert_load_balancer.ranks_num, 8))
|
||||
self.assertTrue(torch.all(log2phy_map >= -1))
|
||||
|
||||
@mock.patch("torch_npu.npu._lazy_init")
|
||||
@mock.patch("torch.npu.current_device", return_value="cpu")
|
||||
def test_get_rank_placement_map(self, mock_current_device, mock_lazy_init):
|
||||
layer_id = 0
|
||||
rank_id = 0
|
||||
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
||||
layer_id, rank_id)
|
||||
self.assertEqual(rank_local_expert_num, 5)
|
||||
expected_tensor = torch.tensor([2, -1, 1, 3, -1, 4, -1, 0],
|
||||
dtype=torch.int32).to(
|
||||
rank_expert_map.device)
|
||||
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
||||
|
||||
rank_id = 1
|
||||
rank_local_expert_num, rank_expert_map = self.expert_load_balancer.get_rank_placement_map(
|
||||
layer_id, rank_id)
|
||||
expected_tensor = torch.tensor([-1, 1, 4, -1, 2, -1, 0, 3],
|
||||
dtype=torch.int32).to(
|
||||
rank_expert_map.device)
|
||||
self.assertTrue(rank_expert_map.equal(expected_tensor))
|
||||
|
||||
def test_get_rank_log2phy_map(self):
|
||||
layer_id = 0
|
||||
rank_id = 0
|
||||
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
||||
layer_id, rank_id)
|
||||
expected_tensor = torch.tensor([2, 6, 1, 3, 7, 4, 5, 0],
|
||||
dtype=torch.int32).to(
|
||||
log2phy_map.device)
|
||||
self.assertTrue(log2phy_map.equal(expected_tensor))
|
||||
|
||||
rank_id = 1
|
||||
log2phy_map = self.expert_load_balancer.get_rank_log2phy_map(
|
||||
layer_id, rank_id)
|
||||
expected_tensor = torch.tensor([2, 6, 9, 3, 7, 4, 5, 8],
|
||||
dtype=torch.int32).to(
|
||||
log2phy_map.device)
|
||||
self.assertTrue(log2phy_map.equal(expected_tensor))
|
||||
|
||||
def test_get_global_redundant_expert_num(self):
|
||||
redundant_expert_num = self.expert_load_balancer.get_global_redundant_expert_num(
|
||||
)
|
||||
expected_redundant_expert_num = len(self.expert_map["layer_list"][0]["device_list"][0]["device_expert"]) * \
|
||||
self.expert_map["layer_list"][0]["device_count"] - 8
|
||||
self.assertEqual(redundant_expert_num, expected_redundant_expert_num)
|
||||
741
tests/ut/ops/test_fused_ops.py
Normal file
741
tests/ut/ops/test_fused_ops.py
Normal file
@@ -0,0 +1,741 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, TypedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_forward_context import (FusedMoEState,
|
||||
_get_fused_moe_state)
|
||||
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
|
||||
AscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch
|
||||
|
||||
adapt_patch(True)
|
||||
|
||||
|
||||
def mock_ep_and_mc2_group(mocker):
    """Return a mock communication group used for the EP and MC2 groups.

    The mock reports rank 0 of a 4-way group and its ``all_to_all`` returns a
    fixed random (8, 8) tensor.
    """
    group = mocker.MagicMock()
    group.rank_in_group = 0
    group.rank = 0
    group.world_size = 4
    group.device_group = "mock_group_ep"
    group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
    return group
|
||||
|
||||
|
||||
def mock_dp_and_tp_group(mocker):
    """Return a mock communication group used for the DP and TP groups.

    The mock reports rank 0 of a 2-way group and its ``all_gather`` returns a
    fixed random (10, 32) tensor.
    """
    group = mocker.MagicMock()
    group.rank_in_group = 0
    group.world_size = 2
    group.device_group = "mock_group"
    group.all_gather = MagicMock(return_value=torch.randn(10, 32))
    return group
|
||||
|
||||
|
||||
def mock_npu_format_cast(weight_data, format):
    """Stand-in for ``torch_npu.npu_format_cast``: return the input untouched.

    ``format`` is accepted only to mirror the patched call signature.
    """
    unchanged = weight_data
    return unchanged
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
    """Patch the distributed / config lookups used by ``vllm_ascend.ops.fused_moe``.

    While the fixture is active, group getters return mock groups, collective
    ops are replaced with mocks, and ``get_forward_context`` returns a mock
    context whose default token dispatcher is the all-gather mock.

    Yields:
        dict with the mock forward context and the three mock token
        dispatchers (all-gather, all2allv, mc2), each with canned
        ``token_dispatch`` / ``token_combine`` return values.
    """
    # NOTE: an earlier revision also built patchers for
    # `_register_token_dispatcher` / `get_token_dispatcher` but never started
    # them, so they had no effect (stop() on an unstarted patcher is a no-op).
    # They were removed; tests pick dispatchers from the yielded dict.
    mock_setup_token_dispatchers = MagicMock()

    # All-gather dispatcher: two expert groups.
    mock_token_dispatcher_with_allgather = MagicMock()
    mock_token_dispatcher_with_allgather.token_dispatch.return_value = {
        "hidden_states": torch.randn(16, 2),
        "group_list": torch.tensor([8, 16], dtype=torch.int64),
        "group_list_type": 0,
    }
    mock_token_dispatcher_with_allgather.token_combine.return_value = \
        torch.randn(16, 2)

    # All2AllV dispatcher: four expert groups, no dynamic scale.
    mock_token_dispatcher_with_all2allv = MagicMock()
    mock_token_dispatcher_with_all2allv.token_dispatch.return_value = {
        "hidden_states": torch.randn(16, 2),
        "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64),
        "group_list_type": 1,
        "dynamic_scale": None,
    }
    mock_token_dispatcher_with_all2allv.token_combine.return_value = \
        torch.randn(16, 2)

    # MC2 dispatcher: additionally carries combine-assist data and recv counts.
    mock_token_dispatcher_with_mc2 = MagicMock()
    mock_token_dispatcher_with_mc2.token_dispatch.return_value = {
        "hidden_states": torch.randn(16, 2),
        "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64),
        "group_list_type": 1,
        "dynamic_scale": None,
        "assist_info_for_combine": torch.randn(16, 2),
        "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32),
    }
    mock_token_dispatcher_with_mc2.token_combine.return_value = \
        torch.randn(16, 2)

    mock_forward_context_obj = MagicMock(
        fused_moe_state=FusedMoEState.AllGather,
        token_dispatcher=mock_token_dispatcher_with_allgather,
        max_tokens_across_dp=10,
        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]),
        mc2_mask=torch.zeros(16, dtype=torch.bool),
        padded_num_tokens=16,
        with_quant=False)

    with patch('torch.distributed.get_rank', return_value=0), \
        patch('torch.distributed.get_world_size', return_value=4), \
        patch('vllm_ascend.ops.fused_moe.get_ep_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_mc2_group', return_value=mock_ep_and_mc2_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
        patch('torch.distributed.all_gather'), \
        patch('torch.distributed.all_to_all_single'), \
        patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \
        patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \
        patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
              return_value=mock_dp_and_tp_group(mocker)), \
        patch('vllm_ascend.ops.fused_moe.get_ascend_config',
              return_value=MagicMock(
                  torchair_graph_config=MagicMock(enabled=False, enable_multistream_moe=False),
                  expert_map_path=None
              )), \
        patch('vllm_ascend.ops.fused_moe.determine_expert_map',
              return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
        patch('vllm_ascend.ops.fused_moe.get_forward_context',
              return_value=mock_forward_context_obj), \
        patch('vllm_ascend.ops.fused_moe.get_current_vllm_config',
              return_value=MagicMock(
                  parallel_config=MagicMock(tensor_parallel_size=2),
                  scheduler_config=MagicMock(max_num_seqs=4),
                  model_config=MagicMock(max_model_len=2048)
              )), \
        patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \
        patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers), \
        patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context',
              return_value=mock_forward_context_obj):

        yield {
            'mock_forward_context_obj': mock_forward_context_obj,
            'mock_token_dispatcher_with_allgather':
            mock_token_dispatcher_with_allgather,
            'mock_token_dispatcher_with_all2allv':
            mock_token_dispatcher_with_all2allv,
            'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2,
        }
|
||||
|
||||
|
||||
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
    """Replace the ``torch_npu`` MoE kernels with mocks returning canned tensors.

    The ``*_v2`` dispatch/combine kernels are patched only when the installed
    ``torch_npu`` exposes them.
    """
    with patch('torch_npu.npu_moe_gating_top_k', return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            None
        )), \
        patch('torch_npu.npu_moe_init_routing', return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
        )), \
        patch("torch_npu.npu_moe_compute_expert_tokens",
              return_value=torch.randn(8, 2)), \
        patch("torch_npu.npu_moe_distribute_dispatch",
              return_value=torch.randn(16, 2)), \
        patch("torch_npu.npu_moe_distribute_combine",
              return_value=torch.randn(16, 2)), \
        patch("torch_npu.npu_grouped_matmul",
              return_value=[torch.randn(16, 2)]), \
        patch("torch_npu.npu_swiglu",
              return_value=torch.randn(16, 2)), \
        patch("torch_npu.npu_moe_gating_top_k_softmax", return_value=(
            torch.randn(8, 2),
            torch.randint(0, 8, (8, 2)),
            torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])
        )), \
        patch("torch_npu.npu_moe_finalize_routing",
              return_value=torch.randn(16, 2)):
        # Without the v2 kernels there is nothing more to patch.
        if not hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
            yield
            return

        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=torch.randn(16, 2)), \
            patch("torch_npu.npu_moe_distribute_combine_v2",
                  return_value=torch.randn(16, 2)):
            yield
|
||||
|
||||
|
||||
@pytest.fixture
def default_moe_config():
    """Keyword arguments used to build an ``AscendFusedMoE`` layer in tests."""
    return dict(
        num_experts=8,
        top_k=2,
        hidden_size=512,
        intermediate_size=1024,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def moe_method(mock_dist_env):
    """Build an ``AscendUnquantizedFusedMoEMethod`` around a mocked MoE config.

    Depends on ``mock_dist_env`` so the distributed lookups are patched while
    the method object is constructed.
    """
    moe_config = MagicMock()
    # NOTE(review): this sets the value returned by *calling*
    # moe_parallel_config, not the attribute itself — confirm that is intended.
    moe_config.moe_parallel_config.return_value = MagicMock(ep_size=4)
    return AscendUnquantizedFusedMoEMethod(moe_config)
|
||||
|
||||
|
||||
class Device(TypedDict):
    """Typed dict describing one device entry of an expert map."""

    # Identifier of the device within its layer.
    device_id: int
    # Expert ids listed for this device.
    device_expert: List[int]
|
||||
|
||||
|
||||
class Layer(TypedDict):
    """Typed dict describing one layer entry of an expert map."""

    # Identifier of the layer.
    layer_id: int
    # Number of devices recorded for the layer.
    device_count: int
    # Per-device entries belonging to the layer.
    device_list: List[Device]
|
||||
|
||||
|
||||
class MockData(TypedDict):
    """Typed dict describing the top-level object of an expert map."""

    # Number of MoE layers recorded.
    moe_layer_count: int
    # Per-layer entries.
    layer_list: List[Layer]
|
||||
|
||||
|
||||
class MockQuantMethod(nn.Module):
    """Quant-method stand-in whose ``apply`` is a mock with a canned result.

    With truthy ``shared_experts`` the mocked ``apply`` returns a pair of
    tensors shaped (num_tokens, 32) and (num_tokens, 10); otherwise it returns
    a single (num_tokens, 32) tensor.
    """

    def __init__(self, shared_experts, num_tokens):
        super().__init__()
        routed_out = torch.randn(num_tokens, 32)
        if shared_experts:
            canned = (routed_out, torch.randn(num_tokens, 10))
        else:
            canned = routed_out
        self.apply = MagicMock(return_value=canned)
|
||||
|
||||
|
||||
class MockFusedMoEMethod(FusedMoEMethodBase):
    """Minimal ``FusedMoEMethodBase`` subclass whose methods do nothing."""

    # Placeholder MoE config handed to the base-class constructor.
    moe = MagicMock()

    def __init__(self):
        super().__init__(self.moe)

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Do nothing; no weights are created for the mock."""

    def apply(self, hidden_states: torch.Tensor,
              expert_weights: torch.Tensor) -> torch.Tensor:
        """Do nothing and implicitly return ``None``."""
|
||||
|
||||
|
||||
class TestAscendFusedMoe:
    """Construction and forward-path tests for ``AscendFusedMoE``."""

    def test_init_no_quant(self, mock_dist_env, default_moe_config):
        """Unquantized construction works; invalid options are rejected."""
        cfg = default_moe_config
        layer = AscendFusedMoE(**cfg)

        layer.w13_weight = nn.Parameter(
            torch.randn(cfg['num_experts'], cfg['intermediate_size'] * 2,
                        cfg['hidden_size']))
        layer.w2_weight = nn.Parameter(
            torch.randn(cfg['num_experts'], cfg['hidden_size'],
                        cfg['intermediate_size']))

        assert layer.num_experts == cfg['num_experts']
        assert layer.top_k == cfg['top_k']
        assert hasattr(layer, 'w13_weight')
        assert hasattr(layer, 'w2_weight')

        # Grouped top-k without its extra arguments trips an assertion.
        grouped_cfg = cfg.copy()
        grouped_cfg['use_grouped_topk'] = True
        with pytest.raises(AssertionError):
            layer = AscendFusedMoE(**grouped_cfg)

        # An unknown scoring function is rejected with ValueError.
        scoring_cfg = cfg.copy()
        scoring_cfg['scoring_func'] = "random"
        with pytest.raises(ValueError):
            layer = AscendFusedMoE(**scoring_cfg)

    def test_init_with_quant(self, mock_dist_env, default_moe_config):
        """The quant method supplied by the quant config is installed on the layer."""
        mock_quant_method = MockFusedMoEMethod()
        mock_quant_config = MagicMock()
        mock_quant_config.get_quant_method.return_value = mock_quant_method

        moe = AscendFusedMoE(**default_moe_config,
                             quant_config=mock_quant_config)

        assert moe.quant_method is not None
        assert moe.quant_method == mock_quant_method

    @pytest.mark.parametrize(
        "others_param",
        [[None,
          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
         [2, None, False, 5, None], [None, None, True, 5, None],
         [None, None, False, 1, None], [None, None, True, 5, 1],
         [None, None, False, 5, 1]])
    def test_forward(self, mock_dist_env, default_moe_config, others_param):
        """``forward`` calls the quant method once and returns its output shape(s).

        ``others_param`` is
        ``[top_k, shared_experts, is_prefill, num_tokens, ep_size]``.
        """
        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
        inputs = torch.randn(num_tokens, 32)
        router_logits = torch.randn(num_tokens, 8)
        moe = AscendFusedMoE(**default_moe_config)

        if ep_size == 1:
            moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
        mask = torch.zeros(num_tokens, dtype=torch.bool)
        forward_context = MagicMock(mc2_mask=mask,
                                    padded_num_tokens=num_tokens)
        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
            output = moe.forward(inputs,
                                 router_logits,
                                 is_prefill=is_prefill,
                                 top_k=top_k,
                                 shared_experts=shared_experts)

        moe.quant_method.apply.assert_called_once()

        if not shared_experts:
            assert output.shape == (num_tokens, 32)
        else:
            assert output[0].shape == (num_tokens, 32)
            assert output[1].shape == (num_tokens, 10)

    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
                                       default_moe_config):
        """``_forward_ms_fused_moe_comp`` delegates to the quant method once."""
        num_tokens = 5
        inputs = torch.randn(num_tokens, 32)
        router_logits = torch.randn(num_tokens, 8)
        moe = AscendFusedMoE(**default_moe_config)

        moe.quant_method = MockQuantMethod(None, num_tokens)
        output = moe._forward_ms_fused_moe_comp(inputs,
                                                router_logits,
                                                is_prefill=False,
                                                real_top_k=1)

        moe.quant_method.apply.assert_called_once()

        assert output.shape == (num_tokens, 32)
|
||||
|
||||
|
||||
class TestAscendUnquantizedFusedMoEMethod:
    """Tests for ``AscendUnquantizedFusedMoEMethod`` with mocked NPU kernels."""

    @staticmethod
    def _select_dispatcher(mock_dist_env, ep_size):
        """Pick the mock token dispatcher the tests pair with ``ep_size``."""
        if ep_size == 1:
            key = 'mock_token_dispatcher_with_allgather'
        elif ep_size < 16:
            key = 'mock_token_dispatcher_with_all2allv'
        else:
            key = 'mock_token_dispatcher_with_mc2'
        return mock_dist_env[key]

    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
        """Weights end up as non-trainable ``nn.Parameter`` objects."""
        layer = MagicMock()
        layer.w13_weight.data = torch.randn(16, 32)
        layer.w2_weight.data = torch.randn(16, 32)

        with patch('torch_npu.npu_format_cast', mock_npu_format_cast), \
            patch('vllm_ascend.utils.is_310p', return_value=False):
            moe_method.process_weights_after_loading(layer)

            for weight in (layer.w13_weight, layer.w2_weight):
                assert isinstance(weight, torch.nn.Parameter)
            assert not layer.w13_weight.requires_grad
            assert not layer.w2_weight.requires_grad

    # NOTE(review): [128, 1] appears twice in the parameter list; kept as-is.
    @pytest.mark.parametrize("others_param",
                             [[256, 4], [128, 1], [128, 1], [128, 4]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
        """``apply`` without an expert map returns the dispatcher's (16, 2) output.

        ``others_param`` is ``[global_num_experts, ep_size]``.
        """
        global_num_experts, ep_size = others_param
        is_prefill = False
        is_deepseek_v3_r1 = global_num_experts == 256

        selected_token_dispatcher = self._select_dispatcher(
            mock_dist_env, ep_size)
        moe_state = _get_fused_moe_state(ep_size, is_prefill,
                                         is_deepseek_v3_r1)
        forward_context = MagicMock(fused_moe_state=moe_state,
                                    with_quant=False,
                                    token_dispatcher=selected_token_dispatcher)

        with patch("vllm_ascend.ops.fused_moe.get_forward_context",
                   return_value=forward_context):
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            router_logits = torch.randn(8, 8)
            layer = MagicMock()
            local_num_experts = 2
            hidden_size = 2
            intermediate_size_per_partition = 4

            layer.w13_weight = torch.randn(local_num_experts,
                                           intermediate_size_per_partition * 2,
                                           hidden_size)
            layer.w2_weight = torch.randn(local_num_experts, hidden_size,
                                          intermediate_size_per_partition)

            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=global_num_experts,
                                      is_prefill=is_prefill)

            assert result.shape == (16, 2)

    @pytest.mark.parametrize("others_param", [16, 1, 4])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):
        """``apply`` with an expert map returns the dispatcher's (16, 2) output.

        ``others_param`` is the expert-parallel size.
        """
        ep_size = others_param
        is_prefill = False

        selected_token_dispatcher = self._select_dispatcher(
            mock_dist_env, ep_size)
        moe_state = _get_fused_moe_state(ep_size, is_prefill, True)
        forward_context = MagicMock(fused_moe_state=moe_state,
                                    with_quant=False,
                                    token_dispatcher=selected_token_dispatcher)

        with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \
            patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3):

            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            if ep_size == 1:
                # The ep_size == 1 case feeds a flattened 2-D input.
                x = x.view(-1, 2)
            router_logits = torch.randn(8, 8)
            layer = MagicMock()

            local_num_experts = 2
            hidden_size = 2
            intermediate_size_per_partition = 4
            layer.w13_weight = torch.randn(local_num_experts,
                                           intermediate_size_per_partition * 2,
                                           hidden_size)
            layer.w2_weight = torch.randn(local_num_experts, hidden_size,
                                          intermediate_size_per_partition)

            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=router_logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=128,
                                      expert_map=expert_map,
                                      is_prefill=is_prefill)

            assert result.shape == (16, 2)
|
||||
|
||||
|
||||
class TestExpertsSelector:
    """Tests for ``select_experts`` with the NPU gating kernels mocked."""

    # With a single argname each list item IS the argument value, so the values
    # must be plain ints; ``[[256], [128]]`` passed one-element lists instead.
    @pytest.mark.parametrize("global_num_experts", [256, 128])
    def test_select_experts(self, mock_dist_env, mock_moe_env,
                            global_num_experts):
        """``select_experts`` returns (8, 2) weights and ids for both expert counts."""
        x = torch.randn(8, 2)
        router_logits = torch.randn(8, 2)
        topk_weights, topk_ids, _ = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=False,
            renormalize=True,
            topk_group=None,
            num_expert_group=None,
            custom_routing_function=None,
            scoring_func="softmax",
            e_score_correction_bias=None,
            global_num_experts=global_num_experts)

        assert topk_weights.shape == (8, 2)
        assert topk_ids.shape == (8, 2)
|
||||
|
||||
|
||||
class TestUnifiedApplyMLP(TestBase):
    """Tests for ``unified_apply_mlp`` with every ``torch_npu`` kernel mocked."""

    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_dynamic_quant')
    @patch('torch_npu.npu_dequant_swiglu_quant')
    def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant,
                                                     mock_npu_dynamic_quant,
                                                     mock_npu_grouped_matmul,
                                                     mock_is_310p,
                                                     mock_get_forward_context):
        """Quantized path in MC2 state: quant, two matmuls, one dequant-swiglu."""
        forward_ctx = MagicMock()
        forward_ctx.fused_moe_state = FusedMoEState.MC2
        mock_get_forward_context.return_value = forward_ctx

        mock_is_310p.return_value = False

        quantized = torch.randint(-128, 127, (10, 20), dtype=torch.int8)
        quant_scale = torch.rand(10, 1, dtype=torch.float32)
        mock_npu_dynamic_quant.return_value = (quantized, quant_scale)

        # First call feeds the dequant step, second is the final projection.
        gmm1_out = torch.randint(-2147483648,
                                 2147483647, (10, 40),
                                 dtype=torch.int32)
        gmm2_out = torch.randn(10, 20, dtype=torch.bfloat16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]

        dequant_out = torch.randn(10, 40, dtype=torch.bfloat16)
        dequant_scale = torch.randn(10, 1, dtype=torch.float32)
        mock_npu_dequant.return_value = (dequant_out, dequant_scale)

        hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
        w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
        w1_scale = torch.randn(5, 40, dtype=torch.float32)
        w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
        w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=w1_scale,
                                   w2=w2,
                                   w2_scale=w2_scale,
                                   group_list=group_list,
                                   dynamic_scale=None,
                                   group_list_type=1,
                                   w1_scale_bias=None,
                                   w2_scale_bias=None,
                                   topk_scales=None,
                                   with_quant=True)

        mock_get_forward_context.assert_called()
        self.assertEqual(forward_ctx.fused_moe_state, FusedMoEState.MC2)

        mock_npu_dynamic_quant.assert_called()
        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_dequant.assert_called_once()

        self.assertEqual(result.dtype, torch.bfloat16)

    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_without_quantization(self,
                                                    mock_npu_dynamic_quant,
                                                    mock_npu_swiglu,
                                                    mock_npu_grouped_matmul,
                                                    mock_is_310p):
        """Unquantized path: two grouped matmuls around a single swiglu."""
        mock_is_310p.return_value = False

        gmm1_out = torch.randn(10, 40, dtype=torch.float16)
        gmm2_out = torch.randn(10, 20, dtype=torch.float16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]
        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
        mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())

        hidden_states = torch.randn(10, 20, dtype=torch.float16)
        w1 = torch.randn(5, 20, 40, dtype=torch.float16)
        w2 = torch.randn(5, 40, 20, dtype=torch.float16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        topk_scales = torch.randn(10, 1, dtype=torch.float16)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=None,
                                   w2=w2,
                                   w2_scale=None,
                                   group_list=group_list,
                                   dynamic_scale=None,
                                   group_list_type=1,
                                   w1_scale_bias=None,
                                   w2_scale_bias=None,
                                   topk_scales=topk_scales,
                                   with_quant=False)

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()

        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.float16)

    @patch('vllm_ascend.ops.layers.moe_mlp.get_forward_context')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
            self, mock_npu_dynamic_quant, mock_npu_swiglu,
            mock_npu_grouped_matmul, mock_get_forward_context):
        """Quantized, non-MC2 path with a caller-provided dynamic scale."""
        forward_ctx = MagicMock()
        forward_ctx.with_quant = True
        forward_ctx.fused_moe_state = "NOT_MC2"
        mock_get_forward_context.return_value = forward_ctx

        gmm1_out = torch.randn(10, 40, dtype=torch.bfloat16)
        gmm2_out = torch.randn(10, 20, dtype=torch.bfloat16)
        mock_npu_grouped_matmul.side_effect = [[gmm1_out], [gmm2_out]]

        mock_npu_swiglu.return_value = torch.randn(10,
                                                   40,
                                                   dtype=torch.bfloat16)

        quantized = torch.randint(-128, 127, (10, 40), dtype=torch.int8)
        quant_scale = torch.rand(10, 1, dtype=torch.float32)
        mock_npu_dynamic_quant.return_value = (quantized, quant_scale)

        hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
        w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
        w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
        w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
        w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
        w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
        w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=w1_scale,
                                   w2=w2,
                                   w2_scale=w2_scale,
                                   group_list=group_list,
                                   dynamic_scale=provided_dynamic_scale,
                                   group_list_type=1,
                                   w1_scale_bias=w1_scale_bias,
                                   w2_scale_bias=w2_scale_bias,
                                   topk_scales=None,
                                   with_quant=True)

        mock_get_forward_context.assert_called()

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()
        mock_npu_dynamic_quant.assert_called_once()

        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.bfloat16)

    @patch('vllm_ascend.ops.layers.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_without_quantization_310p(
            self, mock_npu_dynamic_quant, mock_npu_swiglu,
            mock_npu_grouped_matmul, mock_is_310p):
        """Unquantized path with ``is_310p`` reporting True."""
        mock_is_310p.return_value = True

        mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
        mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
        mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
                                               [mock_gmm2_out]]

        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)

        mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())

        hidden_states = torch.randn(10, 20, dtype=torch.float16)
        w1 = torch.randn(5, 20, 40, dtype=torch.float16)
        w2 = torch.randn(5, 40, 20, dtype=torch.float16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        topk_scales = torch.randn(10, 1, dtype=torch.float16)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=None,
                                   w2=w2,
                                   w2_scale=None,
                                   group_list=group_list,
                                   dynamic_scale=None,
                                   group_list_type=1,
                                   w1_scale_bias=None,
                                   w2_scale_bias=None,
                                   topk_scales=topk_scales,
                                   with_quant=False)

        mock_is_310p.assert_called_once()

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()

        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.float16)
|
||||
53
tests/ut/ops/test_layernorm.py
Normal file
53
tests/ut/ops/test_layernorm.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
|
||||
@pytest.fixture
def dummy_tensor():
    """Random float16 tensor of shape (4, 8) used as the layer input."""
    shape = (4, 8)
    return torch.randn(*shape, dtype=torch.float16)
|
||||
|
||||
|
||||
def mock_rms_norm(x, weight, eps):
    """Fake ``npu_rms_norm``: return ``(x + 1, None)``, ignoring weight and eps."""
    shifted = x + 1
    return shifted, None
|
||||
|
||||
|
||||
def mock_add_rms_norm(x, residual, weight, eps):
    """Fake ``npu_add_rms_norm``: return ``(2 * x, None, 2 * residual)``.

    ``weight`` and ``eps`` are accepted only to match the patched signature.
    """
    doubled_x = 2 * x
    doubled_residual = 2 * residual
    return doubled_x, None, doubled_residual
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_310p_return", [True, False])
@pytest.mark.parametrize("residual",
                         [None, torch.randn(4, 8, dtype=torch.float32)])
@patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
@patch("torch_npu.npu_add_rms_norm", side_effect=mock_add_rms_norm)
def test_RMSNorm_forward(mock_add_rmsnorm, mock_rmsnorm, is_310p_return,
                         residual, dummy_tensor):
    """Check ``RMSNorm`` against the fake kernels, with and without a residual.

    Without a residual the plain rms-norm fake is expected. With a residual
    the 310P case adds it first and then uses the plain fake; otherwise the
    fused add-rms-norm fake is expected.
    """
    with patch("vllm_ascend.utils.is_310p", return_value=is_310p_return):
        layer = RMSNorm(hidden_size=32, eps=1e-05)

        if residual is None:
            out_x = layer.forward(dummy_tensor, residual)
            expected_out_x = dummy_tensor + 1

            mock_rmsnorm.assert_called_once()
            assert torch.allclose(out_x, expected_out_x)
            return

        out_x, out_residual = layer.forward_oot(dummy_tensor, residual)

        if is_310p_return:
            summed = dummy_tensor + residual.to(dummy_tensor.dtype)
            expected_out_x = summed + 1
            expected_out_residual = summed.to(residual.dtype)
            expected_kernel = mock_rmsnorm
        else:
            expected_out_x = 2 * dummy_tensor
            expected_out_residual = 2 * residual
            expected_kernel = mock_add_rmsnorm

        expected_kernel.assert_called_once()
        assert torch.allclose(out_x, expected_out_x)
        assert torch.allclose(out_residual, expected_out_residual)
|
||||
363
tests/ut/ops/test_linear.py
Normal file
363
tests/ut/ops/test_linear.py
Normal file
@@ -0,0 +1,363 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.linear import (AscendMlpColumnParallelLinear,
|
||||
AscendMlpMergedColumnParallelLinear,
|
||||
AscendMlpRowParallelLinear, LinearBase,
|
||||
QuantizationConfig)
|
||||
|
||||
|
||||
class TestAscendMlpRowParallelLinear(unittest.TestCase):
    """Tests for ``AscendMlpRowParallelLinear`` with the parallel-state
    helpers of ``vllm_ascend.ops.linear`` replaced by mocks."""

    def _start(self, patcher):
        """Start *patcher* and register its ``stop`` as a test cleanup.

        Using ``addCleanup`` guarantees the patch is undone even when a later
        statement in ``setUp`` raises (``tearDown`` is skipped in that case).
        """
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    def setUp(self):
        # Scope the feature flag to the running test; writing to os.environ
        # directly would leak the value into every test executed afterwards.
        self._start(
            mock.patch.dict(os.environ,
                            {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}))
        self.tensor_parallel_world_size = 2
        self.tensor_parallel_rank = 0
        self.mlp_tensor_parallel_world_size = 2
        self.mlp_tensor_parallel_rank = 1

        self.get_tensor_model_parallel_world_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size',
            return_value=self.tensor_parallel_world_size)
        self.get_tensor_model_parallel_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_rank',
            return_value=self.tensor_parallel_rank)
        self.get_mlp_tensor_model_parallel_world_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size',
            return_value=self.mlp_tensor_parallel_world_size)
        self.get_mlp_tensor_model_parallel_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank',
            return_value=self.mlp_tensor_parallel_rank)

        self.get_tensor_model_parallel_world_size_mock = self._start(
            self.get_tensor_model_parallel_world_size_patch)
        self.get_tensor_model_parallel_rank_mock = self._start(
            self.get_tensor_model_parallel_rank_patch)
        self.get_mlp_tensor_model_parallel_world_size_mock = self._start(
            self.get_mlp_tensor_model_parallel_world_size_patch)
        self.get_mlp_tensor_model_parallel_rank_mock = self._start(
            self.get_mlp_tensor_model_parallel_rank_patch)

        self.split_tensor_along_last_dim_patch = mock.patch(
            'vllm_ascend.ops.linear.split_tensor_along_last_dim',
            return_value=(torch.randn(10, 8), torch.randn(10, 8)))
        self.tensor_model_parallel_all_reduce_patch = mock.patch(
            'vllm_ascend.ops.linear.tensor_model_parallel_all_reduce',
            return_value=torch.randn(10, 8))
        self.tensor_model_parallel_all_reduce_mock = self._start(
            self.tensor_model_parallel_all_reduce_patch)
        self.split_tensor_along_last_dim_mock = self._start(
            self.split_tensor_along_last_dim_patch)

        self.get_mlp_tp_group_patch = \
            mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
        self.get_mlp_tp_group_mock = self._start(self.get_mlp_tp_group_patch)
        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
        self.get_mlp_tp_group_mock.return_value.reduce_scatter = \
            mock.MagicMock()

    def test_init_with_down_proj_prefix(self):
        """A ``down_proj`` prefix selects the MLP tensor-parallel group."""
        layer = AscendMlpRowParallelLinear(input_size=16,
                                           output_size=8,
                                           prefix="down_proj")
        self.assertEqual(layer.tp_size, self.mlp_tensor_parallel_world_size)
        self.assertEqual(layer.tp_rank, self.mlp_tensor_parallel_rank)
        self.assertTrue(layer.enable_mlp_optimze)

    def test_forward_with_mlp_optimize(self):
        """Non-parallel input is split across ``layer.tp_size`` partitions."""
        layer = AscendMlpRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="down_proj",
            input_is_parallel=False,
        )
        input_tensor = torch.randn(16, 8)  # (batch_size, input_size)
        layer(input_tensor)

        self.split_tensor_along_last_dim_mock.assert_called_once_with(
            input_tensor, num_partitions=layer.tp_size)

    def test_forward_without_mlp_optimize(self):
        """Any other prefix splits the input and all-reduces the result."""
        layer = AscendMlpRowParallelLinear(
            input_size=16,
            output_size=8,
            prefix="other",
            input_is_parallel=False,
        )
        input_tensor = torch.randn(16, 8)
        layer(input_tensor)

        self.split_tensor_along_last_dim_mock.assert_called_once_with(
            input_tensor, num_partitions=layer.tp_size)
        self.tensor_model_parallel_all_reduce_mock.assert_called_once()

    def test_skip_bias_add(self):
        """With ``skip_bias_add=True`` the bias is returned to the caller."""
        layer = AscendMlpRowParallelLinear(
            input_size=16,
            output_size=8,
            skip_bias_add=True,
        )
        input_tensor = torch.randn(16, 8)
        _, bias = layer(input_tensor)

        self.assertIsNotNone(bias)

    def test_no_reduce_results(self):
        """``reduce_results=False`` must not trigger an all-reduce."""
        layer = AscendMlpRowParallelLinear(input_size=16,
                                           output_size=8,
                                           reduce_results=False,
                                           bias=False)
        input_tensor = torch.randn(16, 8)
        layer(input_tensor)

        self.tensor_model_parallel_all_reduce_mock.assert_not_called()

    def test_input_not_parallel(self):
        """``input_is_parallel=False`` splits the input exactly once."""
        layer = AscendMlpRowParallelLinear(input_size=16,
                                           output_size=8,
                                           input_is_parallel=False)
        input_tensor = torch.randn(16, 8)
        layer(input_tensor)

        self.split_tensor_along_last_dim_mock.assert_called_once()

    def test_exception_when_reduce_false_and_bias(self):
        """Bias without reduction (and without skipping it) is rejected."""
        with self.assertRaises(ValueError):
            AscendMlpRowParallelLinear(input_size=16,
                                       output_size=8,
                                       reduce_results=False,
                                       bias=True,
                                       skip_bias_add=False)
|
||||
|
||||
|
||||
class TestAscendMlpColumnParallelLinear(unittest.TestCase):
    """Tests for ``AscendMlpColumnParallelLinear`` construction with the
    distributed helpers and ``LinearBase.__init__`` replaced by mocks."""

    def _start(self, patcher):
        """Start *patcher* and register its ``stop`` as a test cleanup.

        Using ``addCleanup`` guarantees the patch is undone even when a later
        statement in ``setUp`` raises (``tearDown`` is skipped in that case).
        """
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    def setUp(self):
        # Scope the feature flag to the running test; writing to os.environ
        # directly would leak the value into every test executed afterwards.
        self._start(
            mock.patch.dict(os.environ,
                            {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}))

        # Mock distributed functions
        self.mlp_tp_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size')
        self.mlp_tp_size_mock = self._start(self.mlp_tp_size_patch)
        self.mlp_tp_size_mock.return_value = 2  # MLP TP group of size 2

        self.mlp_tp_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank')
        self.mlp_tp_rank_mock = self._start(self.mlp_tp_rank_patch)
        self.mlp_tp_rank_mock.return_value = 0  # Rank in the MLP TP group

        self.tp_size_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_world_size')
        self.tp_size_mock = self._start(self.tp_size_patch)
        self.tp_size_mock.return_value = 4  # Regular TP group of size 4

        self.tp_rank_patch = mock.patch(
            'vllm_ascend.ops.linear.get_tensor_model_parallel_rank')
        self.tp_rank_mock = self._start(self.tp_rank_patch)
        self.tp_rank_mock.return_value = 1  # Rank in the regular TP group

        # Mock divide with plain floor division
        self.divide_patch = mock.patch('vllm_ascend.ops.linear.divide')
        self.divide_mock = self._start(self.divide_patch)
        self.divide_mock.side_effect = lambda x, y: x // y

        # Mock QuantizationConfig and QuantMethod
        self.quant_config_mock = mock.MagicMock(spec=QuantizationConfig)
        self.quant_method_mock = mock.MagicMock()

        # Mock LinearBase initialization
        self.linear_base_init_patch = mock.patch.object(
            LinearBase, "__init__", side_effect=self.mock_linear_base_init)
        self._start(self.linear_base_init_patch)

    def mock_linear_base_init(self, instance, *args, **kwargs):
        """Populate the attributes the tests expect from ``LinearBase``."""
        instance.quant_method = self.quant_method_mock

        instance.input_size = 16
        instance.output_size = 8
        instance.output_size_per_partition = 4
        instance.params_dtype = torch.float32

    def test_mlp_optimize_initialization(self):
        # Test when prefix contains "gate_up_proj"
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            layer = AscendMlpColumnParallelLinear(
                input_size=16,
                output_size=8,
                prefix="model.layers.0.gate_up_proj",
                bias=False,
            )

        # Verify MLP optimization flags
        self.assertTrue(layer.enable_mlp_optimze)
        self.assertEqual(layer.tp_size, 2)
        self.assertEqual(layer.tp_rank, 0)
        self.assertEqual(layer.input_size_per_partition, 16)
        self.assertEqual(layer.output_size_per_partition, 4)

        # Check quant_method.create_weights was called
        self.quant_method_mock.create_weights.assert_called_once()

    def test_regular_parallel_initialization(self):
        # Test when prefix does NOT contain "gate_up_proj"
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            layer = AscendMlpColumnParallelLinear(
                input_size=16,
                output_size=8,
                prefix="model.layers.0.q_proj",
                quant_config=self.quant_config_mock,
                bias=False,
            )

        # Verify regular TP flags
        self.assertFalse(layer.enable_mlp_optimze)
        self.assertEqual(layer.tp_size, 4)
        self.assertEqual(layer.tp_rank, 1)
        self.assertEqual(layer.input_size_per_partition, 16)
        self.assertEqual(layer.output_size_per_partition, 4)
        # Check quant_method.create_weights was called
        self.quant_method_mock.create_weights.assert_called_once()

    def test_output_sizes_handling(self):
        # Test when output_sizes is provided
        with mock.patch.object(torch.nn.Module, 'register_parameter'):
            layer = AscendMlpColumnParallelLinear(
                input_size=16,
                output_size=8,
                output_sizes=[4, 4],
                prefix="model.layers.0.qkv_proj",
                quant_config=self.quant_config_mock,
                bias=False,
            )

        # Verify output_partition_sizes
        self.assertEqual(layer.output_partition_sizes, [2])
|
||||
|
||||
|
||||
class TestAscendMlpMergedColumnParallelLinear(unittest.TestCase):
    """Tests for ``AscendMlpMergedColumnParallelLinear.forward`` with the
    parent ``__init__`` and the communication helpers replaced by mocks."""

    def _start(self, patcher):
        """Start *patcher* and register its ``stop`` as a test cleanup.

        Cleanups always stop the *patcher*.  Calling ``stop()`` on the mock a
        patcher returns is a silent no-op on a ``MagicMock`` and leaves the
        patch active for every following test.
        """
        started = patcher.start()
        self.addCleanup(patcher.stop)
        return started

    def setUp(self):
        # Scope the feature flag to the running test; writing to os.environ
        # directly would leak the value into every test executed afterwards.
        self._start(
            mock.patch.dict(os.environ,
                            {"VLLM_ASCEND_ENABLE_MLP_OPTIMIZE": "1"}))

        # Mock get_mlp_tensor_model_parallel_world_size and
        # get_tensor_model_parallel_world_size
        self.mlp_world_size_patch = mock.patch(
            "vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_world_size",
            return_value=2)
        self.tensor_world_size_patch = mock.patch(
            "vllm_ascend.ops.linear.get_tensor_model_parallel_world_size",
            return_value=2)
        self._start(self.mlp_world_size_patch)
        self._start(self.tensor_world_size_patch)

        # Mock get_mlp_tensor_model_parallel_rank and
        # get_tensor_model_parallel_rank
        self.mlp_rank_patch = mock.patch(
            "vllm_ascend.ops.linear.get_mlp_tensor_model_parallel_rank",
            return_value=0)
        self.tensor_rank_patch = mock.patch(
            "vllm_ascend.ops.linear.get_tensor_model_parallel_rank",
            return_value=0)
        self._start(self.mlp_rank_patch)
        self._start(self.tensor_rank_patch)

        # Mock all_gather methods
        self.get_mlp_tp_group_patch = \
            mock.patch('vllm_ascend.ops.linear.get_mlp_tp_group')
        self.get_mlp_tp_group_mock = self._start(self.get_mlp_tp_group_patch)
        self.get_mlp_tp_group_mock.return_value = mock.MagicMock()
        self.get_mlp_tp_group_mock.return_value.all_gather = mock.MagicMock()
        self.tensor_model_parallel_all_gather_patch = mock.patch(
            'vllm_ascend.ops.linear.tensor_model_parallel_all_gather',
            return_value=torch.randn(10, 8))
        self.tensor_model_parallel_all_gather_mock = self._start(
            self.tensor_model_parallel_all_gather_patch)

        # Mock AscendMlpColumnParallelLinear's __init__
        self.linear_init_patch = mock.patch.object(
            AscendMlpColumnParallelLinear,
            "__init__",
            side_effect=self.mock_linear_init)
        self._start(self.linear_init_patch)

        # Create mock objects
        self.quant_method_mock = mock.MagicMock()
        self.apply_output = torch.randn(2, 8)

        self.quant_method_mock.apply.return_value = self.apply_output

    def mock_linear_init(self, instance, *args, **kwargs):
        """Minimal replacement for the parent initializer."""
        torch.nn.Module.__init__(instance)
        # Set quant_method and other attributes
        instance.quant_method = self.quant_method_mock
        instance.bias = torch.nn.Parameter(torch.randn(8))  # Example bias
        instance.input_size = 16
        instance.output_size = 8
        instance.gather_output = False
        instance.skip_bias_add = False
        instance.return_bias = True

    def test_forward_with_enable_mlp_optimze(self):
        # Setup input
        input_tensor = torch.randn(1, 16)

        # NOTE(review): the test name suggests the MLP-optimized path, but
        # the prefix below is "other_proj" (it does not contain
        # "gate_up_proj"), the same prefix the "without" test uses.  Confirm
        # which path is intended before changing the prefix.
        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
                                                    output_sizes=[8],
                                                    bias=True,
                                                    gather_output=False,
                                                    skip_bias_add=False,
                                                    params_dtype=torch.float32,
                                                    quant_config=None,
                                                    prefix="other_proj")

        # Call forward
        output, _ = layer(input_tensor)

        # Validate calls
        self.assertEqual(output.shape, self.apply_output.shape)

    def test_forward_without_enable_mlp_optimze(self):
        # Setup input
        input_tensor = torch.randn(1, 16)

        # Create instance with prefix not containing "gate_up_proj"
        layer = AscendMlpMergedColumnParallelLinear(input_size=16,
                                                    output_sizes=[8],
                                                    bias=True,
                                                    gather_output=False,
                                                    skip_bias_add=False,
                                                    params_dtype=torch.float32,
                                                    quant_config=None,
                                                    prefix="other_proj")

        # Call forward
        output, _ = layer(input_tensor)

        # Validate calls
        self.quant_method_mock.apply.assert_called_once_with(
            layer, input_tensor, layer.bias)
        self.tensor_model_parallel_all_gather_mock.assert_not_called()
        self.assertEqual(output.shape, self.apply_output.shape)
|
||||
318
tests/ut/ops/test_rotary_embedding.py
Normal file
318
tests/ut/ops/test_rotary_embedding.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import math
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.model_executor.layers.rotary_embedding import (
|
||||
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.rotary_embedding import _custom_rotary_embedding_enabled
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(unittest.TestCase):
    """Checks the gating logic of ``_custom_rotary_embedding_enabled``."""

    def setUp(self):
        # Inputs shared by the checks below.
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 4, dtype=torch.float16)
        self.key = torch.randn(3, 4, dtype=torch.float16)
        self.head_size = 32
        self.cos_sin_cache = torch.randn(3, 4)

        # Stand-in for the ``self`` argument of rope_forward_oot.
        stub = MagicMock()
        stub.head_size = self.head_size
        stub.cos_sin_cache = self.cos_sin_cache
        stub.is_neox_style = True
        stub.forward_native.return_value = (self.query, self.key)
        self.mock_self = stub

    def _enabled(self, query, neox_style, head_size, custom_op):
        """Evaluate the predicate with ``enable_custom_op`` patched."""
        with patch('vllm_ascend.ops.rotary_embedding.enable_custom_op',
                   return_value=custom_op):
            return _custom_rotary_embedding_enabled(query, neox_style,
                                                    head_size)

    def test_custom_rotary_embedding_enabled(self):
        # Every condition satisfied.
        self.assertTrue(
            self._enabled(self.query, True, self.head_size, True))

        # Query dtype is float32 instead of float16.
        fp32_query = self.query.to(torch.float32)
        self.assertFalse(
            self._enabled(fp32_query, True, self.head_size, True))

        # neox_style is False.
        self.assertFalse(
            self._enabled(self.query, False, self.head_size, True))

        # head_size is not divisible by 32.
        self.assertFalse(
            self._enabled(self.query, True, self.head_size + 1, True))

        # Custom op is disabled.
        self.assertFalse(
            self._enabled(self.query, True, self.head_size, False))
|
||||
|
||||
|
||||
class TestAscendRotaryEmbedding(unittest.TestCase):
    """Tests for ``RotaryEmbedding.forward`` with the NPU kernels mocked."""

    def setUp(self):
        # Common setup for tests
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
        self.head_size = 32
        self.rotary_dim = self.head_size
        self.max_position = 16
        self.rope_theta = 10000
        self.is_neox_style = True
        self.cos_sin_cache = torch.randn(3, 1, 32)
        self.layer = RotaryEmbedding(self.head_size, self.rotary_dim,
                                     self.max_position, self.rope_theta,
                                     self.is_neox_style, torch.float16)

        # Mock self object for rope_forward_oot
        self.mock_self = MagicMock()
        self.mock_self.head_size = self.head_size
        self.mock_self.cos_sin_cache = self.cos_sin_cache
        self.mock_self.is_neox_style = self.is_neox_style

    @patch('torch.ops._C')
    @patch('vllm_ascend.ops.rotary_embedding.is_310p', return_value=False)
    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=True)
    @patch('torch.ops._npu_rotary_embedding')
    def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
                                            mock_custom_enabled, mock_is_310p,
                                            mock__c):
        """The custom ``torch.ops._C`` kernel is used when it is enabled."""
        # Setup mock for custom kernel path
        mock__c.rotary_embedding.return_value = self.query, self.key

        result_q, result_k = self.layer.forward(self.positions, self.query,
                                                self.key)

        mock__c.rotary_embedding.assert_called_once()
        self.assertEqual(result_q.shape, self.query.shape)
        self.assertEqual(result_k.shape, self.key.shape)

    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
                                         mock_custom_enabled):
        """Transposed (non-contiguous) inputs go through the NPU kernel."""
        # Test contiguous path when custom is disabled
        non_contig_query = self.query.transpose(0, 1)
        non_contig_key = self.key.transpose(0, 1)

        result_q, result_k = self.layer.forward(self.positions,
                                                non_contig_query,
                                                non_contig_key)

        mock_npu_rotary.assert_called_once()
        self.assertEqual(result_q.shape, non_contig_query.shape)
        self.assertEqual(result_k.shape, non_contig_key.shape)

    def test_rope_forward_oot_with_offsets(self):
        """Passing ``offsets`` is expected to raise NotImplementedError."""
        offsets = torch.tensor([1, 2, 3])
        with self.assertRaises(NotImplementedError):
            self.layer.forward(self.positions, self.query, self.key, offsets)

    @patch('vllm_ascend.ops.rotary_embedding._custom_rotary_embedding_enabled',
           return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
                                                  mock_custom_enabled):
        """``is_neox_style_override`` is forwarded to the NPU kernel."""
        # Test neox_style override
        result_q, result_k = self.layer.forward(self.positions,
                                                self.query,
                                                self.key,
                                                is_neox_style_override=False)

        # Check that neox_style=False was passed to the NPU function
        # (it is the last positional argument of the recorded call).
        args, _ = mock_npu_rotary.call_args
        self.assertFalse(args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
    """Plain attribute holder standing in for a rotary-embedding module."""

    def __init__(self, max_seq_len=2048, is_neox_style=True):
        # Caller-controlled settings.
        self.max_seq_len, self.is_neox_style = max_seq_len, is_neox_style
        # Caches start out empty.
        self.cos_cached = self.sin_cached = None
        # Fixed placeholder values.
        self.rotary_dim = self.base = 1
|
||||
|
||||
|
||||
class TestAscendDeepseekScalingRotaryEmbedding(TestBase):
    """Tests for ``DeepseekScalingRotaryEmbedding`` on a CPU device with
    ``_rope_forward_oot`` mocked out."""

    def setUp(self):
        # Inputs and constructor arguments shared by all tests.
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 1, 32, dtype=torch.float16)
        self.key = torch.randn(3, 1, 32, dtype=torch.float16)
        self.head_size = 32
        self.rotary_dim = self.head_size
        self.max_position = 16
        self.rope_theta = 10000
        self.is_neox_style = True
        self.scaling_factor = 1
        self.layer = None

    def _create_layer(self):
        """Build the layer from the setUp values, store and return it."""
        layer = DeepseekScalingRotaryEmbedding(
            self.head_size, self.rotary_dim, self.max_position,
            self.rope_theta, self.is_neox_style, self.scaling_factor,
            torch.float16)
        self.layer = layer
        return layer

    @patch("vllm.platforms.current_platform.device_type",
           new=torch.device("cpu"))
    @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
           new_callable=PropertyMock)
    def test_native_rope_deepseek_forward_base(self, mock_npuplatform):
        mock_npuplatform.device_type = torch.device("cpu")
        self._create_layer()
        rope_patch = patch("vllm_ascend.ops.rotary_embedding._rope_forward_oot",
                           return_value=(self.query, self.key))
        with rope_patch as mock_rope_forward_oot:
            q_out, k_out = self.layer.forward(self.positions, self.query,
                                              self.key)
            mock_rope_forward_oot.assert_called_once()
            assert q_out.shape == self.query.shape
            assert k_out.shape == self.key.shape

    @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
    @patch("vllm.platforms.current_platform.device_type",
           new=torch.device("cpu"))
    @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
           new_callable=PropertyMock)
    def test_native_rope_deepseek_forward_cache_handling(
            self, mock_npuplatform, mock_rope_forward_oot):
        mock_npuplatform.device_type = torch.device("cpu")
        self._create_layer()
        self.layer.max_seq_len = 1024
        # Request a max_seq_len larger than the stored one; the cache setter
        # is expected to be invoked exactly once.
        with patch.object(self.layer, "_set_cos_sin_cache") as mock_set_cache:
            mock_rope_forward_oot.return_value = (self.query, self.key)

            q_out, k_out = self.layer.forward(self.positions,
                                              self.query,
                                              self.key,
                                              max_seq_len=2048)
            mock_set_cache.assert_called_once()
            assert q_out.shape == self.query.shape
            assert k_out.shape == self.key.shape

    @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
    @patch("vllm.platforms.current_platform.device_type",
           new=torch.device("cpu"))
    @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
           new_callable=PropertyMock)
    def test_native_rope_deepseek_forward_key_reshaping(
            self, mock_npuplatform, mock_rope_forward_oot):
        mock_npuplatform.device_type = torch.device("cpu")
        self._create_layer()

        # A rank-2 key instead of the rank-3 key from setUp.
        flat_key = torch.randn(1, 32)

        mock_rope_forward_oot.return_value = (self.query, flat_key)

        q_out, k_out = self.layer.forward(self.positions, self.query,
                                          flat_key)
        mock_rope_forward_oot.assert_called_once()
        assert q_out.shape == self.query.shape
        assert k_out.shape == flat_key.shape

    @patch('vllm_ascend.ops.rotary_embedding._rope_forward_oot')
    @patch("vllm.platforms.current_platform.device_type",
           new=torch.device("cpu"))
    @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
           new_callable=PropertyMock)
    def test_native_rope_deepseek_forward_non_neox_style(
            self, mock_npuplatform, mock_rope_forward_oot):
        # NOTE(review): setUp leaves is_neox_style = True and nothing here
        # changes it, so the layer is built neox-style despite the test
        # name - confirm whether is_neox_style=False was intended.
        mock_npuplatform.device_type = torch.device("cpu")
        self._create_layer()

        mock_rope_forward_oot.return_value = (self.query, self.key)

        q_out, k_out = self.layer.forward(self.positions, self.query,
                                          self.key)

        mock_rope_forward_oot.assert_called_once()
        assert q_out.shape == self.query.shape
        assert k_out.shape == self.key.shape

    @patch("vllm.platforms.current_platform.device_type",
           new=torch.device("cpu"))
    @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
           new_callable=PropertyMock)
    def test_basic_case(self, mock_npuplatform):
        # Compare _yarn_find_correction_dim with a hand-computed value.
        mock_npuplatform.device_type = torch.device("cpu")
        self._create_layer()
        num_rotations = 100
        dim = 512
        base = 10000
        max_position_embeddings = 2048

        result = self.layer._yarn_find_correction_dim(num_rotations, dim, base,
                                                      max_position_embeddings)

        # dim * ln(max_pos / (2 * pi * rotations)) / (2 * ln(base))
        numerator = dim * torch.log(
            torch.tensor(max_position_embeddings) /
            (num_rotations * 2 * torch.pi))
        denominator = 2 * torch.log(torch.tensor(base))
        expected = numerator / denominator

        self.assertTrue(torch.allclose(result, expected))

    @patch("vllm.platforms.current_platform.device_type",
           new=torch.device("cpu"))
    @patch("vllm_ascend.ops.rotary_embedding.NPUPlatform",
           new_callable=PropertyMock)
    def test_yarn_get_mscale(self, mock_npuplatform):
        mock_npuplatform.device_type = torch.device("cpu")
        self._create_layer()

        # Scales of at most 1 are expected to yield exactly 1.0.
        for small_scale in (0.5, 1.0, 0.999):
            self.assertEqual(self.layer._yarn_get_mscale(scale=small_scale),
                             1.0)

        # Scales above 1: (scale, mscale, expected result).
        test_cases = [(2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
                      (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
                      (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
                      (math.e, 1.0, 1.0 + 0.1)]

        for scale, mscale, expected in test_cases:
            actual = self.layer._yarn_get_mscale(scale, mscale)
            self.assertAlmostEqual(
                actual,
                expected,
                places=6,
                msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
606
tests/ut/ops/test_token_dispatcher.py
Normal file
606
tests/ut/ops/test_token_dispatcher.py
Normal file
@@ -0,0 +1,606 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
AscendSocVersion, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2, _Dispatchers,
|
||||
_register_token_dispatcher, get_token_dispatcher, setup_token_dispatchers)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithMC2(TestBase):
|
||||
|
||||
def setUp(self):
    """Patch the dispatcher's collaborators and build the dispatcher."""
    # Fake MC2 communication group: rank 0 of 8.
    group = MagicMock()
    backend = group.device_group.return_value._get_backend.return_value
    backend.get_hccl_comm_name.return_value = "hccl_123"
    group.rank_in_group = 0
    group.world_size = 8
    self.mc2_group = group
    self.mc2_group_patch = patch(
        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_mc2_group",
        return_value=group)
    self.mc2_group_patch.start()

    self.rank_group_patch = patch("torch.distributed.get_rank",
                                  return_value=0)
    self.rank_group_patch.start()

    # Mock get_forward_context().mc2_mask
    context = MagicMock()
    context.mc2_mask = torch.tensor([1, 0, 1])
    self.forward_context = context
    self.forward_context_patch = patch(
        "vllm.forward_context.get_forward_context", return_value=context)
    self.forward_context_patch.start()

    # Mock get_ascend_soc_version()
    self.ascend_soc_version_patch = patch(
        "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_soc_version",
        return_value=AscendSocVersion.A3)
    self.ascend_soc_version_patch.start()

    self.dispatcher = TokenDispatcherWithMC2(with_quant=False,
                                             top_k=8,
                                             num_experts=128)
    self.row_idx = torch.arange(10, dtype=torch.int32)
|
||||
|
||||
def tearDown(self):
    """Stop every patcher that ``setUp`` started."""
    self.mc2_group_patch.stop()
    # This patcher is started in setUp as well; without stopping it,
    # torch.distributed.get_rank stays patched after the test finishes.
    self.rank_group_patch.stop()
    self.forward_context_patch.stop()
    self.ascend_soc_version_patch.stop()
|
||||
|
||||
def test_init(self):
    """The dispatcher mirrors the mocked group and A3-specific flags."""
    dispatcher = self.dispatcher
    self.assertEqual(dispatcher.ep_rank_id, 0)
    self.assertEqual(dispatcher.ep_world_size, 8)
    self.assertFalse(dispatcher.with_quant)
    for flag in ("enable_dispatch_v2", "need_extra_args",
                 "a3_need_extra_args"):
        self.assertTrue(getattr(dispatcher, flag))
|
||||
|
||||
def test_get_dispatch_mc2_kwargs_without_quant(self):
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
topk_weights = torch.randn(10, 1)
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map)
|
||||
self.assertIn("x", kwargs)
|
||||
self.assertIn("expert_ids", kwargs)
|
||||
self.assertEqual(kwargs["moe_expert_num"], 8)
|
||||
|
||||
def test_token_permutation_dispatch(self):
|
||||
hidden_states = torch.randn(10, 128)
|
||||
topk_weights = torch.randn(10, 1)
|
||||
topk_ids = torch.randint(0, 8, (10, 1))
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
|
||||
output = self.dispatcher.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, expert_map)
|
||||
mock_dispatch.assert_called_once()
|
||||
self.assertEqual(output["group_list_type"],
|
||||
1) # group_list_type == 1
|
||||
|
||||
def test_token_dispatch_with_shared_experts_and_quant(self):
|
||||
self.shared_experts = MagicMock()
|
||||
self.shared_experts.gate_up_proj.return_value = (torch.randn(10, 128),
|
||||
torch.tensor(1.0))
|
||||
self.shared_experts.act_fn.return_value = torch.randn(10, 128)
|
||||
self.dispatcher.with_quant = False
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
self.dispatcher.swiglu_out_scale = torch.tensor(1.0)
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
self.topk_weights = torch.randn(10, 1)
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5):
|
||||
self.dispatcher.token_dispatch(self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
self.row_idx,
|
||||
torch.tensor(
|
||||
[0, 1, 2, 3, 4, 5, 6, 7]),
|
||||
shared_experts=self.shared_experts)
|
||||
|
||||
def test_get_combine_mc_kwargs_with_quant(self):
|
||||
self.dispatcher.with_quant = True
|
||||
hidden_states = torch.randn(10, 128)
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
|
||||
kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states)
|
||||
self.assertIn("tp_send_counts", kwargs)
|
||||
|
||||
def test_token_combine_with_shared_experts(self):
|
||||
self.dispatcher.shared_experts = MagicMock()
|
||||
self.dispatcher.shared_experts.down_proj.return_value = (torch.randn(
|
||||
10, 128), torch.tensor(1.0))
|
||||
self.dispatcher.shared_act = torch.randn(10, 128)
|
||||
self.dispatcher.with_quant = True
|
||||
self.dispatcher.topk_ids = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.topk_weights = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
self.dispatcher.need_extra_args = True
|
||||
self.dispatcher.enable_dispatch_v2 = True
|
||||
self.dispatcher.swiglu_out_scale = torch.randint(0, 8, (10, 1))
|
||||
self.dispatcher.output = torch.randint(0, 8, (10, 1))
|
||||
self.hidden_states = torch.randn(10, 128)
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_combine_v2",
|
||||
return_value=torch.randn(10, 128)):
|
||||
self.dispatcher.token_combine(self.hidden_states)
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAllGather(TestBase):
    """Unit tests for ``TokenDispatcherWithAllGather``.

    The three ``torch_npu`` MoE routing kernels used by the tests are replaced
    with mocks that return fixed-shape tensors.
    """

    def setUp(self):
        """Build a non-quantised dispatcher and mock the NPU routing kernels."""
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
            "with_quant": False,
        }
        self.dispatcher = TokenDispatcherWithAllGather(**kwargs)

        # Mock NPU functions
        self.patcher_moe_init_routing = patch('torch_npu.npu_moe_init_routing')
        self.mock_moe_init_routing = self.patcher_moe_init_routing.start()
        self.mock_moe_init_routing.return_value = (
            torch.randn(6, 128),  # sorted_hidden_states
            torch.tensor([0, 1, 2, 3, 4, 5]),  # expanded_row_idx
            torch.tensor([0, 1, 0, 1, 0, 1])  # expanded_expert_idx
        )

        self.patcher_moe_compute_expert_tokens = patch(
            'torch_npu.npu_moe_compute_expert_tokens')
        self.mock_moe_compute_expert_tokens = (
            self.patcher_moe_compute_expert_tokens.start())
        self.mock_moe_compute_expert_tokens.return_value = torch.tensor(
            [3, 3])  # expert_tokens

        self.patcher_moe_finalize_routing = patch(
            'torch_npu.npu_moe_finalize_routing')
        self.mock_moe_finalize_routing = (
            self.patcher_moe_finalize_routing.start())
        self.mock_moe_finalize_routing.return_value = torch.randn(3, 128)
        self.row_idx = torch.arange(10, dtype=torch.int32)

    def tearDown(self):
        """Stop the three kernel patchers started in ``setUp``."""
        self.patcher_moe_init_routing.stop()
        self.patcher_moe_compute_expert_tokens.stop()
        self.patcher_moe_finalize_routing.stop()

    def test_token_dispatch_without_expert_map(self):
        """Without an expert map the init-routing kernel is called once."""
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, self.row_idx, None)

        # Verify npu_moe_init_routing is called
        self.mock_moe_init_routing.assert_called_once()

        self.assertEqual(results["group_list_type"], 0)

    def test_token_dispatch_with_quant(self):
        """Dispatch on a dispatcher built without an explicit ``with_quant``."""
        kwargs = {
            "apply_router_weight_on_input": False,
            "top_k": 2,
            "max_num_tokens": 100,
            "ep_size": 2,
            "num_experts": 128,
        }
        self.dispatcher_quant = TokenDispatcherWithAllGather(**kwargs)

        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
        topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])

        results = self.dispatcher_quant.token_dispatch(hidden_states,
                                                       topk_weights, topk_ids,
                                                       self.row_idx, None)

        self.assertEqual(results["group_list_type"], 0)

    def test_token_combine_with_expert_map(self):
        """Combine with an expert map returns a tensor of the original shape."""
        self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3])
        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.sorted_weights = torch.tensor(
            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
        self.dispatcher.original_shape = (3, 128)
        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
        hidden_states = torch.randn(6, 128)

        final_hidden_states = self.dispatcher.token_combine(hidden_states)

        # Only the shape of the combined output is checked here.
        self.assertEqual(final_hidden_states.shape, (3, 128))

    def test_token_combine_without_expert_map(self):
        """Without an expert map the finalize-routing kernel is called once."""
        self.dispatcher.with_quant = False
        self.dispatcher.expanded_row_idx = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
        self.dispatcher.sorted_token_indices = torch.tensor([0, 1, 1, 1, 1, 1])
        self.dispatcher.sorted_weights = torch.tensor(
            [0.5, 0.5, 0.5, 0.5, 0.5, 0.5])
        self.dispatcher.original_shape = (3, 128)
        self.dispatcher.mask = torch.tensor([0, 1, 1, 0])
        hidden_states = torch.randn(6, 128)

        final_hidden_states = self.dispatcher.token_combine(hidden_states)

        # Verify npu_moe_finalize_routing is called
        self.mock_moe_finalize_routing.assert_called_once()

        self.assertEqual(final_hidden_states.shape, (3, 128))

    def test_token_dispatch_with_router_weight(self):
        """With router weights applied on input the mocked routing output
        (6 rows) is what ends up in ``hidden_states``."""
        self.dispatcher.apply_router_weight_on_input = True
        hidden_states = torch.randn(3, 128)
        topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
        topk_ids = torch.tensor([[0], [1], [2]])

        # NOTE(review): the other tests in this class pass ``self.row_idx`` as
        # the fourth positional argument and the expert map as the fifth; here
        # only ``None`` is passed in fourth position — confirm this is intended.
        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
                                                 topk_ids, None)
        self.assertEqual(results["hidden_states"].shape, (6, 128))
|
||||
|
||||
|
||||
class TestTokenDispatcherWithAll2AllV(TestBase):
    """Tests for ``TokenDispatcherWithAll2AllV`` with its expert-parallel
    properties, NPU kernels and collective helpers replaced by mocks."""

    def setUp(self):
        """Install every mock the dispatcher needs, then build a dispatcher."""

        def activate(patcher):
            # Start the patcher and guarantee it is undone after the test.
            mock_obj = patcher.start()
            self.addCleanup(patcher.stop)
            return mock_obj

        def property_patch(attr_name, value):
            # Replace a class-level property with a PropertyMock.
            return patch.object(TokenDispatcherWithAll2AllV,
                                attr_name,
                                new_callable=PropertyMock,
                                return_value=value)

        # Patch properties
        self.mock_ep_group_prop = activate(
            property_patch('ep_group', MagicMock()))
        self.mock_ep_rank_prop = activate(property_patch('ep_rank', 0))
        self.mock_ep_size_prop = activate(property_patch('ep_size', 2))

        # Mock torch_npu.npu_moe_token_permute
        self.mock_npu_moe_token_permute = activate(
            patch('torch_npu.npu_moe_token_permute'))
        permuted = torch.randn(16, 16)
        self.mock_npu_moe_token_permute.return_value = (permuted,
                                                        torch.arange(16))

        # Mock torch_npu.npu_moe_token_unpermute
        self.mock_npu_moe_token_unpermute = activate(
            patch('torch_npu.npu_moe_token_unpermute'))
        self.mock_npu_moe_token_unpermute.return_value = torch.randn(8, 16)

        # Mock async_all_to_all
        self.mock_async_all_to_all = activate(
            patch('vllm_ascend.ops.comm_utils.async_all_to_all'))
        exchanged = torch.randn(16, 16)
        self.mock_async_all_to_all.return_value = (None, exchanged,
                                                   MagicMock())

        # Mock gather_from_sequence_parallel_region
        self.mock_gather_from_sequence_parallel_region = activate(
            patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.gather_from_sequence_parallel_region'
            ))
        gathered_counts = torch.tensor([[2, 2, 2, 2], [2, 2, 2, 2]],
                                       dtype=torch.int64)
        self.mock_gather_from_sequence_parallel_region.return_value = gathered_counts

        # Mock torch.histc
        self.mock_histc = activate(patch('torch.histc'))
        self.mock_histc.return_value = torch.tensor([2, 2, 2, 2],
                                                    dtype=torch.int64)

        # Mock torch.npu.current_device
        self.mock_current_device = activate(patch('torch.npu.current_device'))
        self.mock_current_device.return_value = 'cpu'

        # Mock torch_npu.npu_dynamic_quant
        self.mock_npu_dynamic_quant = activate(
            patch('torch_npu.npu_dynamic_quant'))
        quantized = torch.randn(16, 16)
        scales = torch.randn(16)
        self.mock_npu_dynamic_quant.return_value = (quantized, scales)

        # Mock torch_npu.npu_moe_init_routing_v2
        self.mock_npu_moe_init_routing_v2 = activate(
            patch('torch_npu.npu_moe_init_routing_v2'))
        routed = torch.randn(16, 16)
        routed_index = torch.arange(16)
        routed_scale = torch.randn(16)
        self.mock_npu_moe_init_routing_v2.return_value = (routed,
                                                          routed_index, None,
                                                          routed_scale)

        # Mock torch.repeat_interleave
        self.mock_repeat_interleave = activate(patch('torch.repeat_interleave'))
        self.mock_repeat_interleave.return_value = torch.arange(16)

        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2,
                                                      with_quant=False)
        self.row_idx = torch.arange(10, dtype=torch.int32)

    def _routing_inputs(self):
        """Return (hidden_states, topk_weights, topk_ids, expert_map)."""
        hidden = torch.randn(8, 16)
        weights = torch.rand(8, 4)
        ids = torch.randint(0, 4, (8, 2)).long()
        mapping = torch.tensor([0, 1, 2, 3])
        return hidden, weights, ids, mapping

    def _assign_local_experts(self):
        """Give the current dispatcher two local experts (ids 0 and 1)."""
        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

    def _check_dispatch_result(self, result, expect_scale=False):
        """Assert the keys every dispatch result in these tests must carry."""
        self.assertIsNotNone(result["hidden_states"])
        self.assertIsNotNone(result["group_list"])
        if expect_scale:
            self.assertIsNotNone(result["dynamic_scale"])
        self.assertEqual(result["group_list_type"], 1)

    def test_token_dispatch(self):
        """Plain dispatch yields hidden states, a group list and type 1."""
        hidden_states, topk_weights, topk_ids, expert_map = (
            self._routing_inputs())
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=self.row_idx,
            expert_map=expert_map)

        self._check_dispatch_result(dispatched)

    def test_token_combine(self):
        """Combine returns the (8, 16) tensor produced by the unpermute mock."""
        dispatcher = self.dispatcher
        dispatcher.hidden_shape = (8, 16)
        dispatcher.hidden_shape_before_permute = (8, 16)
        dispatcher.reversed_local_input_permutation_mapping = torch.arange(8)
        dispatcher.topk_weights = torch.rand(8, 4)
        dispatcher.input_splits = [4, 4]
        dispatcher.output_splits = [4, 4]
        dispatcher.reversed_global_input_permutation_mapping = torch.arange(16)

        self._assign_local_experts()
        dispatcher.num_global_tokens_per_local_expert = torch.tensor(
            [[2, 2], [2, 2]], dtype=torch.int64)

        expert_output = torch.randn(16, 16)
        combined = dispatcher.token_combine(expert_output)

        self.assertIsNotNone(combined)
        self.assertEqual(combined.shape, (8, 16))

    def test_token_dispatch_with_quant(self):
        """Dispatch with ``with_quant=True`` also returns a dynamic scale."""
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)

        hidden_states, topk_weights, topk_ids, expert_map = (
            self._routing_inputs())
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=self.row_idx,
            expert_map=expert_map,
            with_quant=True)

        self._check_dispatch_result(dispatched, expect_scale=True)

    def test_token_dispatch_with_quant_no_active_tokens(self):
        """Quantised dispatch when ``repeat_interleave`` yields no entries."""
        self.dispatcher = TokenDispatcherWithAll2AllV(top_k=2,
                                                      num_experts=4,
                                                      num_local_experts=2)

        empty_indices = torch.tensor([], dtype=torch.long)
        self.mock_repeat_interleave.return_value = empty_indices

        hidden_states, topk_weights, topk_ids, expert_map = (
            self._routing_inputs())
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=self.row_idx,
            expert_map=expert_map,
            with_quant=True)

        self._check_dispatch_result(dispatched, expect_scale=True)

    def test_token_dispatch_with_log2phy(self):
        """Dispatch accepts a ``log2phy`` tensor alongside the expert map."""
        hidden_states, topk_weights, topk_ids, expert_map = (
            self._routing_inputs())
        log2phy = torch.tensor([1, 0, 3, 2])
        self._assign_local_experts()

        dispatched = self.dispatcher.token_dispatch(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=self.row_idx,
            expert_map=expert_map,
            log2phy=log2phy)

        self._check_dispatch_result(dispatched)
|
||||
|
||||
|
||||
class TestDispatcherRegistry(TestBase):
    """Tests for the dispatcher registry (``_Dispatchers``) and for
    ``setup_token_dispatchers`` at several expert-parallel sizes."""

    def setUp(self):
        """Start every test from an empty registry."""
        _Dispatchers.clear()

    def tearDown(self):
        """Leave the registry empty for whatever runs next."""
        _Dispatchers.clear()

    def test_register_and_get_token_dispatcher(self):
        """A dispatcher is stored under its class name and can be fetched."""
        fake_dispatcher = MagicMock()
        fake_dispatcher.__class__.__name__ = "MockDispatcher"

        _register_token_dispatcher(fake_dispatcher)

        self.assertIn("MockDispatcher", _Dispatchers)
        self.assertIs(_Dispatchers["MockDispatcher"], fake_dispatcher)

        fetched = get_token_dispatcher("MockDispatcher")
        self.assertIs(fetched, fake_dispatcher)

        # Unknown names resolve to None.
        self.assertIsNone(get_token_dispatcher("NonExistentDispatcher"))

    def test_setup_token_dispatchers_ep_size_1_creates_allgather(self):
        """ep_size=1 builds and registers one all-gather dispatcher."""
        dispatcher_kwargs = {"top_k": 2, "num_experts": 8}
        with patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
        ) as register_mock, patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather'
        ) as allgather_cls:
            built = MagicMock()
            allgather_cls.return_value = built

            self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers)

            setup_token_dispatchers(ep_size=1, **dispatcher_kwargs)

            allgather_cls.assert_called_once_with(**dispatcher_kwargs)
            register_mock.assert_called_once_with(built)

    def test_setup_token_dispatchers_ep_size_2_creates_all2allv(self):
        """ep_size=2 builds and registers one all2all-v dispatcher."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 16,
            "num_local_experts": 2
        }
        with patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
        ) as register_mock, patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
        ) as all2allv_cls:
            built = MagicMock()
            all2allv_cls.return_value = built

            self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)

            setup_token_dispatchers(ep_size=2, **dispatcher_kwargs)

            all2allv_cls.assert_called_once_with(**dispatcher_kwargs)
            register_mock.assert_called_once_with(built)

    def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2(
            self):
        """ep_size=16 builds and registers both all2all-v and MC2."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 32,
            "num_local_experts": 2
        }
        with patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
        ) as register_mock, patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
        ) as mc2_cls, patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
        ) as all2allv_cls:
            built_all2allv = MagicMock()
            built_mc2 = MagicMock()
            all2allv_cls.return_value = built_all2allv
            mc2_cls.return_value = built_mc2

            self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers)
            self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers)

            setup_token_dispatchers(ep_size=16, **dispatcher_kwargs)

            all2allv_cls.assert_called_once_with(**dispatcher_kwargs)
            mc2_cls.assert_called_once_with(**dispatcher_kwargs)
            self.assertEqual(register_mock.call_count, 2)
            register_mock.assert_any_call(built_all2allv)
            register_mock.assert_any_call(built_mc2)

    def test_setup_token_dispatchers_ep_size_16_skips_if_exist(self):
        """Already-registered dispatchers are neither rebuilt nor replaced."""
        dispatcher_kwargs = {
            "top_k": 2,
            "num_experts": 32,
            "num_local_experts": 2
        }
        with patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher'
        ) as register_mock, patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2'
        ) as mc2_cls, patch(
                'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV'
        ) as all2allv_cls:
            existing_all2allv = MagicMock()
            existing_mc2 = MagicMock()
            _Dispatchers["TokenDispatcherWithAll2AllV"] = existing_all2allv
            _Dispatchers["TokenDispatcherWithMC2"] = existing_mc2

            setup_token_dispatchers(ep_size=16, **dispatcher_kwargs)

            all2allv_cls.assert_not_called()
            mc2_cls.assert_not_called()
            register_mock.assert_not_called()
            self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"],
                          existing_all2allv)
            self.assertIs(_Dispatchers["TokenDispatcherWithMC2"],
                          existing_mc2)
|
||||
232
tests/ut/ops/test_vocab_parallel_embedding.py
Normal file
232
tests/ut/ops/test_vocab_parallel_embedding.py
Normal file
@@ -0,0 +1,232 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/lora/test_layers.py
|
||||
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ops.vocab_parallel_embedding import (
|
||||
AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding)
|
||||
|
||||
# NOTE(review): this constant is not referenced anywhere else in this test
# module — confirm whether it is still needed.
VOCAB_PARALLEL_EMBEDDING_TEST_NUM_RANDOM_SEEDS = 128
|
||||
|
||||
|
||||
class TestCustomVocabParallelEmbedding(unittest.TestCase):
    """Tests for ``AscendVocabParallelEmbedding`` built with a mocked
    tensor-parallel environment (world size 2, rank 0)."""

    def setUp(self):
        """Record the layer dimensions shared by every test."""
        self.num_embeddings = 50
        self.embedding_dim = 10
        self.org_num_embeddings = 40
        self.padding_size = 8

    def _create_layer(self):
        """Build a layer under patched TP helpers and mock its internals.

        The returned layer has fixed shard indices and a ``quant_method``
        whose ``embedding`` returns a random tensor with one row per input id.
        """
        # Patch methods and dependencies for VocabParallelEmbedding
        tp_group = MagicMock()
        tp_group.world_size = 2
        tp_group.rank_in_group = 0
        with patch("vllm_ascend.ops.vocab_parallel_embedding.get_tp_group", return_value=tp_group), \
            patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", return_value=0), \
            patch("vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", return_value=2), \
            patch("vllm.model_executor.layers.vocab_parallel_embedding.pad_vocab_size", side_effect=lambda x, y: x + y), \
            patch("vllm.model_executor.layers.vocab_parallel_embedding.divide", side_effect=lambda x, y: x // y):

            # Create an instance of VocabParallelEmbedding
            embedding_layer = AscendVocabParallelEmbedding(
                num_embeddings=self.num_embeddings,
                embedding_dim=self.embedding_dim,
                org_num_embeddings=self.org_num_embeddings,
                padding_size=self.padding_size,
                quant_config=None,  # Mock quantization config
                prefix="")

            embedding_layer.shard_indices = MagicMock(
                org_vocab_start_index=10,
                org_vocab_end_index=20,
                num_org_vocab_padding=5,
                added_vocab_start_index=30,
                added_vocab_end_index=40)

            def fake_embedding(_, x):
                # One random embedding row per input id.
                return torch.randn(x.shape[0], self.embedding_dim)

            # Mock the quantization method
            embedding_layer.quant_method.embedding = MagicMock(
                side_effect=fake_embedding)
        return embedding_layer

    def test_get_masked_input_and_mask(self):
        """Test the mask and offset calculation helper function."""
        layer = self._create_layer()

        token_ids = torch.tensor([5, 15, 25, 35, 45])
        shard_bounds = {
            "org_vocab_start_index": 10,
            "org_vocab_end_index": 20,
            "num_org_vocab_padding": 5,
            "added_vocab_start_index": 30,
            "added_vocab_end_index": 40,
        }

        masked_input, mask = layer._get_masked_input_and_mask(
            token_ids, **shard_bounds)

        expected_mask = torch.tensor([True, False, True, False, True])
        mask_matches = torch.equal(mask, expected_mask)
        self.assertTrue(
            mask_matches,
            f"Mask mismatch. Expected {expected_mask}, got {mask}")

        expected_masked = torch.tensor([0, 5, 0, 20, 0])
        masked_matches = torch.equal(masked_input, expected_masked)
        self.assertTrue(
            masked_matches,
            f"Masked input mismatch. Expected {expected_masked}, got {masked_input}"
        )

    def test_forward_with_tp_size_1(self):
        """Test forward pass without tensor parallelism."""
        # Create a fresh mock embedding with tp_size=1
        layer = self._create_layer()
        layer.tp_size = 1
        canned_rows = torch.randn(3, layer.embedding_dim)
        layer.quant_method.embedding = MagicMock(return_value=canned_rows)

        token_ids = torch.tensor([1, 2, 3])

        def passthrough(x):
            return x

        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=passthrough) as mock_reduce_tp1:
            result = layer.forward(token_ids)

        # Should just pass through without masking
        layer.quant_method.embedding.assert_called_once_with(
            layer, token_ids.long())
        self.assertEqual(result.shape, (3, layer.embedding_dim))

        # Verify all_reduce was called once
        mock_reduce_tp1.assert_called_once()

    def test_forward_with_tp(self):
        """Forward with tp_size=2 offsets the ids and calls all-reduce once."""
        layer = self._create_layer()
        layer.tp_size = 2

        token_ids = torch.tensor([15, 35])  # one org vocab, one added vocab

        def passthrough(x):
            return x

        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=passthrough) as mock_reduce_tp:
            # Call the forward method
            result = layer.forward(token_ids)

        # Check that masking was applied correctly
        embedding_mock = layer.quant_method.embedding
        embedding_mock.assert_called_once()
        ids_seen = embedding_mock.call_args[0][1]
        ids_expected = torch.tensor([5, 20])  # after offset calculation
        self.assertTrue(torch.all(ids_seen == ids_expected))

        # Check that all reduce was called
        mock_reduce_tp.assert_called_once()
        self.assertEqual(result.shape, (2, self.embedding_dim))

    def test_forward_with_invalid_vocab(self):
        """Test that invalid vocab indices are properly masked out."""
        # Create a fresh embedding layer
        layer = self._create_layer()
        token_ids = torch.tensor([5, 15, 25, 35, 45])  # includes invalid cases
        # Create predictable mock output
        reference_rows = torch.randn(5, self.embedding_dim)
        layer.quant_method.embedding = MagicMock(
            return_value=reference_rows.clone())

        def passthrough(x):
            return x

        # Patch tensor_model_parallel_all_reduce to mock its behavior
        with patch(
                "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                side_effect=passthrough):
            # Call the forward method
            result = layer.forward(token_ids)

        # Check that invalid positions (0, 2, 4) were zeroed out
        self.assertTrue(torch.all(result[0] == 0))
        self.assertTrue(torch.all(result[2] == 0))
        self.assertTrue(torch.all(result[4] == 0))
        # Valid positions keep the rows produced by the mocked embedding
        self.assertTrue(torch.all(result[1] == reference_rows[1]))
        self.assertTrue(torch.all(result[3] == reference_rows[3]))
        self.assertEqual(result.shape, (5, self.embedding_dim))

    def test_output_shape(self):
        """Test that output shape is correct."""
        # Create a fresh embedding layer
        layer = self._create_layer()

        cases = [
            (torch.tensor([15]), (1, self.embedding_dim)),
            (torch.tensor([15, 35]), (2, self.embedding_dim)),
            (torch.tensor([15, 35, 16, 36]), (4, self.embedding_dim)),
        ]

        def passthrough(x):
            return x

        for token_ids, shape_expected in cases:
            with self.subTest(input=token_ids):
                with patch(
                        "vllm_ascend.ops.vocab_parallel_embedding.tensor_model_parallel_all_reduce",
                        side_effect=passthrough):
                    # Call the forward method
                    result = layer.forward(token_ids)
                    self.assertEqual(result.shape, shape_expected)
|
||||
|
||||
|
||||
class TestAscendLogitsProcessor(unittest.TestCase):
    """Tests for ``AscendLogitsProcessor`` with the lm-head TP group, the
    Ascend config and the quantisation method replaced by mocks."""

    def setUp(self):
        """Create the shared mocks and start the module-level patches.

        Each patch is registered with ``addCleanup`` as soon as it starts.
        Cleanups run in reverse order, so the nested
        ``get_lmhead_tp_group.all_to_all`` patch is undone before the
        ``get_lmhead_tp_group`` patch it was applied on top of, and patches
        that already started are still undone if a later one fails to start.
        """
        self.vocab_size = 50
        self.num_embeddings = 50
        self.embedding_dim = 10
        self.org_num_embeddings = 40
        self.padding_size = 8

        self.mock_group = MagicMock()
        self.mock_group.world_size = 2
        self.mock_group.rank_in_group = 0
        self.mock_ascend_config = MagicMock()
        self.mock_quant_method = MagicMock()
        self.mock_quant_method.apply = MagicMock(
            return_value=torch.randn(1, self.vocab_size))
        self.patches = [
            patch("vllm_ascend.ascend_config.get_ascend_config",
                  return_value=self.mock_ascend_config),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group",
                return_value=self.mock_group),
            patch("vllm_ascend.ops.vocab_parallel_embedding.lmhead_tp_enable",
                  return_value=True),
            patch(
                "vllm_ascend.ops.vocab_parallel_embedding.get_lmhead_tp_group.all_to_all",
                return_value=torch.randn(1, self.vocab_size))
        ]

        for p in self.patches:
            p.start()
            self.addCleanup(p.stop)

    def test_create_processor(self):
        """The processor stores the vocabulary size it was built with."""
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        self.assertEqual(processor.vocab_size, self.vocab_size)

    def test_get_logits(self):
        """``_get_logits`` calls the lm-head's ``quant_method.apply`` once."""
        processor = AscendLogitsProcessor(vocab_size=self.vocab_size)
        lmhead = AscendParallelLMHead(num_embeddings=self.num_embeddings,
                                      embedding_dim=self.embedding_dim,
                                      prefix="lm_head")
        # Replacing the whole quant method also replaces its ``apply``.
        lmhead.quant_method = self.mock_quant_method
        hidden_state = torch.randn(1, self.org_num_embeddings)
        processor._get_logits(hidden_state, lmhead)
        self.mock_quant_method.apply.assert_called_once()
|
||||
112
tests/ut/patch/worker/patch_common/test_patch_distributed.py
Normal file
112
tests/ut/patch/worker/patch_common/test_patch_distributed.py
Normal file
@@ -0,0 +1,112 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.patch.worker.patch_common.patch_distributed import \
|
||||
GroupCoordinatorPatch
|
||||
|
||||
|
||||
class TestPatchDistributed(TestBase):
    """Tests for ``GroupCoordinatorPatch`` and its ``all_to_all`` method."""

    def _start_patch(self, target, **patch_kwargs):
        """Start ``patch(target, **patch_kwargs)`` and return the started mock.

        The matching ``stop`` is registered with ``addCleanup`` immediately,
        so a later failure in ``setUp`` cannot leave this patch active.
        """
        patcher = patch(target, **patch_kwargs)
        started_mock = patcher.start()
        self.addCleanup(patcher.stop)
        return started_mock

    def setUp(self):
        """Patch the distributed helpers and build the coordinator under test."""
        self.mock_group_ranks = [[0, 1]]
        self.mock_local_rank = 0
        self.mock_backend = "hccl"
        self.mock_use_device_comm = True

        self.mock_get_rank = self._start_patch("torch.distributed.get_rank",
                                               return_value=0)
        self.mock_new_group = self._start_patch("torch.distributed.new_group",
                                                return_value=MagicMock())
        self.mock_is_cuda_alike = self._start_patch(
            "vllm.platforms.current_platform.is_cuda_alike",
            return_value=True)
        self.mock_resolve_obj = self._start_patch(
            "vllm.distributed.parallel_state.resolve_obj_by_qualname",
            return_value=MagicMock())

        self.group_coordinator = GroupCoordinatorPatch(
            group_ranks=self.mock_group_ranks,
            local_rank=self.mock_local_rank,
            torch_distributed_backend=self.mock_backend,
            use_device_communicator=self.mock_use_device_comm)

    def test_GroupCoordinator_patched(self):
        """vLLM's ``GroupCoordinator`` name must resolve to the patch class."""
        self.assertIs(GroupCoordinator, GroupCoordinatorPatch)

    def test_all_to_all_returns_input_when_world_size_1(self):
        """With a world size of 1 the output equals the input tensor."""
        self.group_coordinator.world_size = 1
        input_tensor = torch.randn(2, 3)
        output = self.group_coordinator.all_to_all(input_tensor)
        self.assertTrue(torch.equal(output, input_tensor))

    def test_all_to_all_raises_assertion_on_invalid_scatter_dim(self):
        """``scatter_dim=2`` on a 2-D tensor raises an ``AssertionError``."""
        input_tensor = torch.randn(2, 3)
        with self.assertRaises(AssertionError) as cm:
            self.group_coordinator.all_to_all(input_tensor, scatter_dim=2)
        self.assertIn("Invalid scatter dim", str(cm.exception))

    def test_all_to_all_raises_assertion_on_invalid_gather_dim(self):
        """``gather_dim=2`` on a 2-D tensor raises an ``AssertionError``."""
        input_tensor = torch.randn(2, 3)
        with self.assertRaises(AssertionError) as cm:
            self.group_coordinator.all_to_all(input_tensor, gather_dim=2)
        self.assertIn("Invalid gather dim", str(cm.exception))

    def test_all_to_all_calls_device_communicator_with_correct_args(self):
        """Explicit sizes are forwarded positionally to the communicator."""
        mock_communicator = MagicMock()
        self.group_coordinator.device_communicator = mock_communicator

        input_tensor = torch.randn(2, 3)
        scatter_dim = 0
        gather_dim = 1
        scatter_sizes = [1, 1]
        gather_sizes = [1, 1]

        self.group_coordinator.all_to_all(input_tensor,
                                          scatter_dim=scatter_dim,
                                          gather_dim=gather_dim,
                                          scatter_sizes=scatter_sizes,
                                          gather_sizes=gather_sizes)

        mock_communicator.all_to_all.assert_called_once_with(
            input_tensor, scatter_dim, gather_dim, scatter_sizes, gather_sizes)

    def test_all_to_all_calls_device_communicator_without_sizes(self):
        """Omitted sizes reach the communicator as ``None``."""
        mock_communicator = MagicMock()
        self.group_coordinator.device_communicator = mock_communicator

        input_tensor = torch.randn(2, 3)
        scatter_dim = 0
        gather_dim = 1

        self.group_coordinator.all_to_all(input_tensor,
                                          scatter_dim=scatter_dim,
                                          gather_dim=gather_dim)

        mock_communicator.all_to_all.assert_called_once_with(
            input_tensor, scatter_dim, gather_dim, None, None)
|
||||
167
tests/ut/patch/worker/patch_common/test_patch_linear.py
Normal file
167
tests/ut/patch/worker/patch_common/test_patch_linear.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from importlib import reload
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import vllm
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.patch.worker.patch_common import patch_linear
|
||||
|
||||
|
||||
class TestAscendRowParallelLinear(PytestBase):
    """Tests for ``patch_linear.AscendRowParallelLinear``.

    Instances are created with ``__init__`` and the ``torch.nn.Module``
    attribute hooks mocked out, so each test seeds the attributes it needs
    directly through ``__dict__``.
    """

    def init_row_parallel_linear(self, mocker: MockerFixture):
        """Return an ``AscendRowParallelLinear`` whose ``__init__`` is mocked."""
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.AscendRowParallelLinear.__init__",
            return_value=None,
        )
        mocker.patch("torch.nn.Module.__setattr__")
        mocker.patch("torch.nn.Module.__getattr__")
        mocker.patch("torch.nn.Module.__delattr__")
        return patch_linear.AscendRowParallelLinear(
            input_size=128,
            output_size=256,
        )

    @pytest.mark.parametrize(
        "version, expected",
        [
            ("1.0.0", 1),
            ("2.1.0", 1),
        ],
    )
    def test_get_hcomm_info(self, version, expected, mocker: MockerFixture):
        """``get_hcomm_info`` returns the expected value for both versions."""
        mock_group = mocker.MagicMock()
        backend = mocker.MagicMock()
        backend.get_hccl_comm_name = lambda x: x
        mock_group._get_backend = lambda x: backend
        mock_group.get_hccl_comm_name = lambda x: x
        mocker.patch("torch.distributed.get_rank", return_value=1)
        mocker.patch(
            "torch.distributed.get_global_rank",
            return_value=0,
        )
        mocker.patch("torch.__version__", new=version)
        hcomm_info = patch_linear.AscendRowParallelLinear.get_hcomm_info(
            mock_group)
        assert hcomm_info == expected

    @pytest.mark.parametrize(
        "skip_bias_add, return_bias, bias, expected",
        [
            (True, False, torch.tensor(1.0), torch.tensor(14.0)),
            (False, True, torch.tensor(1.0), (torch.tensor(14.0), None)),
            (
                True,
                True,
                torch.tensor(1.0),
                (torch.tensor(14.0), torch.tensor(1.0)),
            ),
        ],
    )
    def test_forward(
        self,
        skip_bias_add,
        return_bias,
        bias,
        expected,
        mocker: MockerFixture,
    ):
        """``forward`` returns the output, plus the bias when requested."""
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        state = row_parallel_linear.__dict__
        state["tp_rank"] = 0
        state["skip_bias_add"] = skip_bias_add
        state["return_bias"] = return_bias
        state["bias"] = bias
        # The key was previously misspelled "qyuant_method".
        state["quant_method"] = mocker.MagicMock()
        # Stub both stages so only forward()'s own bias/return handling runs.
        state["calc_input"] = lambda x: x  # noqa
        state["calc_output"] = lambda x: x.matmul(  # noqa
            torch.tensor([1.0, 2.0]))

        ret = row_parallel_linear.forward(torch.tensor([10.0, 2.0]))

        if not isinstance(ret, tuple):
            assert torch.allclose(ret, expected)
            return

        assert torch.allclose(ret[0], expected[0])
        if ret[1] is None:
            assert expected[1] is None
        else:
            assert torch.allclose(ret[1], expected[1])

    @pytest.mark.parametrize(
        "input_is_parallel, expected",
        [
            (True, torch.tensor([10.0, 2.0])),
            (False, torch.tensor([10.0])),
        ],
    )
    def test_calc_input(
        self,
        input_is_parallel,
        expected,
        mocker: MockerFixture,
    ):
        """``calc_input`` keeps a parallel input and otherwise takes rank 0's split."""
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        row_parallel_linear.__dict__["input_is_parallel"] = input_is_parallel
        input_tensor = torch.Tensor([10, 2])
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tensor_model_parallel_rank",  # noqa
            return_value=0,
        )
        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.split_tensor_along_last_dim",  # noqa
            return_value=[torch.Tensor([10]),
                          torch.Tensor([2])],
        )
        input_parallel = row_parallel_linear.calc_input(input_tensor)
        assert torch.allclose(input_parallel, expected)

    @pytest.mark.parametrize(
        "reduce_results, tp_size, expected",
        [
            (True, 2, torch.tensor(56.0)),
            (True, 1, torch.tensor(14.0)),
            (False, 2, torch.tensor(14.0)),
        ],
    )
    def test_calc_output(
        self,
        reduce_results,
        tp_size,
        expected,
        mocker: MockerFixture,
    ):
        """``calc_output`` matches the expected value for each parametrised case.

        The mocked quant method multiplies by ``[1, 2]`` and the mocked
        ``npu_mm_all_reduce_base`` by ``[4, 8]``, so the result shows which
        of the two produced the output.
        """
        quant_method = mocker.MagicMock()
        quant_method.apply = lambda self, x, bias=None: x.matmul(  # noqa
            torch.tensor([1.0, 2.0]))
        row_parallel_linear = self.init_row_parallel_linear(mocker)
        state = row_parallel_linear.__dict__
        state["reduce_results"] = reduce_results
        state["tp_size"] = tp_size
        state["quant_method"] = quant_method
        state["tp_rank"] = 0
        state["get_hcomm_info"] = lambda x: None  # noqa

        mocker.patch(
            "vllm_ascend.patch.worker.patch_common.patch_linear.get_tp_group",
            return_value=mocker.MagicMock(device_group=mocker.MagicMock()),
        )
        mocker.patch(
            "torch_npu.npu_mm_all_reduce_base",
            side_effect=lambda input_, weight, hccl_info, bias: input_.matmul(  # noqa
                torch.tensor([4.0, 8.0])),
        )
        ret = row_parallel_linear.calc_output(torch.tensor([10.0, 2.0]))
        assert torch.allclose(ret, expected)

    def test_enable_allreduce_matmul(self, mocker: MockerFixture):
        """Reloading with the env flag on installs the patched linear class."""
        mocker.patch.object(envs_ascend,
                            "VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE",
                            new=True)
        reload(patch_linear)
        assert envs_ascend.VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE
        assert (vllm.model_executor.layers.linear.RowParallelLinear
                is patch_linear.AscendRowParallelLinear)
|
||||
77
tests/ut/patch/worker/patch_common/test_patch_minicpm.py
Normal file
77
tests/ut/patch/worker/patch_common/test_patch_minicpm.py
Normal file
@@ -0,0 +1,77 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.patch.worker.patch_common.patch_minicpm import forward
|
||||
|
||||
|
||||
class TestPatchMiniCPM(TestBase):
    """Tests for the patched ``forward`` of ``MiniCPMAttention``."""

    def setUp(self):
        """Create a mocked attention ``self`` with fixed sub-module outputs."""
        self.positions = torch.tensor([1, 2, 3])
        self.hidden_states = torch.randn(3, 256)

        # The fused qkv tensor and its three 128-wide slices.
        self.mock_qkv = torch.randn(3, 384)
        self.mock_q = self.mock_qkv[:, :128]
        self.mock_k = self.mock_qkv[:, 128:256]
        self.mock_v = self.mock_qkv[:, 256:]

        attention = MagicMock()
        attention.q_size = 128
        attention.kv_size = 128
        attention.qkv_proj = MagicMock()
        attention.rotary_emb = MagicMock()
        attention.attn = MagicMock()
        attention.o_proj = MagicMock()
        attention.qkv_proj.return_value = (self.mock_qkv, None)
        attention.rotary_emb.return_value = (self.mock_q, self.mock_k)
        attention.attn.return_value = torch.randn(3, 256)
        attention.o_proj.return_value = (torch.randn(3, 256), None)
        self.mock_self = attention

    def test_forward_patched(self):
        """``MiniCPMAttention.forward`` is the patched function object."""
        from vllm.model_executor.models.minicpm import MiniCPMAttention

        self.assertIs(MiniCPMAttention.forward, forward)

    def _assert_positional_tensors(self, mocked_callable, expected_tensors):
        """Check the positional tensors of ``mocked_callable``'s last call."""
        positional, _ = mocked_callable.call_args
        self.assertEqual(len(positional), 3)
        for index, wanted in enumerate(expected_tensors):
            self.assertTrue(torch.equal(positional[index], wanted))

    def test_forward_function(self):
        """``forward`` chains qkv_proj, rotary_emb, attn and o_proj."""
        result = forward(self.mock_self, self.positions, self.hidden_states)

        self.mock_self.qkv_proj.assert_called_once_with(self.hidden_states)

        self._assert_positional_tensors(
            self.mock_self.rotary_emb,
            (self.positions, self.mock_q, self.mock_k))
        self._assert_positional_tensors(
            self.mock_self.attn,
            (self.mock_q, self.mock_k, self.mock_v))

        self.mock_self.o_proj.assert_called_once_with(
            self.mock_self.attn.return_value)

        projected, _ = self.mock_self.o_proj.return_value
        self.assertEqual(result.shape, (3, 256))
        self.assertTrue(torch.equal(result, projected))
|
||||
134
tests/ut/quantization/test_func_wrapper.py
Normal file
134
tests/ut/quantization/test_func_wrapper.py
Normal file
@@ -0,0 +1,134 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.func_wrapper import (wrapper_rmsnorm_forward_oot,
|
||||
wrapper_rmsnorm_init)
|
||||
|
||||
|
||||
class MockRMSNorm:
    """Lightweight stand-in for an RMSNorm layer used by the wrapper tests.

    It carries plain attributes only: a ones ``weight``, a frozen zero
    ``bias``, fixed quantisation scalars, and an ``ignore_anti`` flag that
    defaults to ``True`` unless overridden by keyword.
    """

    def __init__(self, hidden_size: int, **extra_args):
        frozen_zero_bias = torch.nn.Parameter(torch.zeros(hidden_size),
                                              requires_grad=False)

        self.hidden_size = hidden_size
        self.weight = torch.ones(hidden_size)
        self.input_scale = 1.0
        self.input_offset = 0.0
        self.variance_epsilon = 1e-6
        self.bias = frozen_zero_bias
        self.ignore_anti = extra_args.get('ignore_anti', True)
|
||||
|
||||
|
||||
class TestFuncWrapper(TestBase):
    """Tests for ``wrapper_rmsnorm_init`` and ``wrapper_rmsnorm_forward_oot``."""

    def test_wrapper_rmsnorm_init(self):
        """After the wrapped init, ``ignore_anti`` and a frozen bias exist."""

        @wrapper_rmsnorm_init
        def init(self, hidden_size: int, **extra_args) -> None:
            self.hidden_size = hidden_size

        hidden_size = 128
        extra_args = {'arg1': 'value1'}

        rms_norm = MockRMSNorm(hidden_size, **extra_args)
        init(rms_norm, hidden_size, **extra_args)

        self.assertTrue(hasattr(rms_norm, 'ignore_anti'))
        self.assertTrue(rms_norm.ignore_anti)

        self.assertTrue(hasattr(rms_norm, 'bias'))
        self.assertIsInstance(rms_norm.bias, torch.nn.Parameter)
        self.assertEqual(rms_norm.bias.shape, torch.Size([hidden_size]))
        self.assertFalse(rms_norm.bias.requires_grad)

    @patch('torch_npu._npu_quant_rms_norm')
    def test_wrapper_rmsnorm_forward_oot_with_residual(
            self, mock_npu_quant_rms_norm):
        """With ``ignore_anti`` off, the NPU op gets the layer's parameters."""
        hidden_size = 128
        x = torch.randn(hidden_size)
        residual = torch.randn(hidden_size)
        expected_out = torch.randn(hidden_size)

        mock_npu_quant_rms_norm.return_value = (expected_out, residual)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x, residual

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = False

        output, res = forward_oot(rms_norm, x, residual)

        mock_npu_quant_rms_norm.assert_called_once()

        args, _ = mock_npu_quant_rms_norm.call_args
        self.assertTrue(torch.equal(args[1], rms_norm.weight))
        self.assertTrue(torch.equal(args[2], rms_norm.bias))
        self.assertEqual(args[3], rms_norm.input_scale)
        self.assertEqual(args[4], rms_norm.input_offset)
        self.assertEqual(args[5], rms_norm.variance_epsilon)
        self.assertTrue(torch.equal(res, residual))

    @patch('torch_npu._npu_quant_rms_norm')
    def test_wrapper_rmsnorm_forward_oot_without_residual(
            self, mock_npu_quant_rms_norm):
        """Without a residual, the NPU op receives ``x`` and its result is returned."""
        hidden_size = 128
        x = torch.randn(hidden_size)
        expected_out = torch.randn(hidden_size)

        mock_npu_quant_rms_norm.return_value = expected_out

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = False

        output = forward_oot(rms_norm, x)

        mock_npu_quant_rms_norm.assert_called_once()

        args, _ = mock_npu_quant_rms_norm.call_args
        self.assertTrue(torch.equal(args[0], x))
        self.assertTrue(torch.equal(args[1], rms_norm.weight))
        self.assertTrue(torch.equal(args[2], rms_norm.bias))
        self.assertEqual(args[3], rms_norm.input_scale)
        self.assertEqual(args[4], rms_norm.input_offset)
        self.assertEqual(args[5], rms_norm.variance_epsilon)

        self.assertTrue(torch.equal(output, expected_out))

    def test_wrapper_rmsnorm_forward_oot_ignore_anti_with_residual(self):
        """With ``ignore_anti`` on, the output is the input plus the bias."""
        hidden_size = 128
        x = torch.randn(hidden_size)
        residual = torch.randn(hidden_size)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x, residual

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = True

        # Compute the expectation from a copy *before* the call. Using
        # x.add_() afterwards would mutate x in place and, if output
        # aliases x, compare the tensor with itself.
        expected_output = x.clone() + rms_norm.bias

        output, res = forward_oot(rms_norm, x, residual)

        self.assertTrue(torch.equal(output, expected_output))
        self.assertTrue(torch.equal(res, residual))

    def test_wrapper_rmsnorm_forward_oot_ignore_anti_no_residual(self):
        """With ``ignore_anti`` on and no residual, output is input plus bias."""
        hidden_size = 128
        x = torch.randn(hidden_size)

        @wrapper_rmsnorm_forward_oot
        def forward_oot(self, x: torch.Tensor, residual: torch.Tensor = None):
            return x

        rms_norm = MockRMSNorm(hidden_size)
        rms_norm.ignore_anti = True

        # See the with-residual variant: avoid an in-place, aliasing check.
        expected_output = x.clone() + rms_norm.bias

        output = forward_oot(rms_norm, x)

        self.assertTrue(torch.equal(output, expected_output))
|
||||
232
tests/ut/quantization/test_quant_config.py
Normal file
232
tests/ut/quantization/test_quant_config.py
Normal file
@@ -0,0 +1,232 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
UnquantizedLinearMethod)
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
|
||||
AscendQuantConfig)
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
|
||||
class TestAscendQuantConfig(TestBase):
    """Tests for ``AscendQuantConfig``: metadata, dispatch and layer skipping."""

    def setUp(self):
        """Build a config from a description mixing INT8 and FLOAT entries."""
        description = {
            "weight": "INT8",
            "fa_quant_type": "C8",
            "kv_quant_type": "C8",
            "layer1.weight": "INT8",
            "layer2.weight": "FLOAT",
            "fused_layer.weight": "FLOAT",
            "fused_layer.shard1.weight": "FLOAT",
            "fused_layer.shard2.weight": "FLOAT",
            "shard1.weight": "FLOAT",
            "shard2.weight": "FLOAT",
        }
        self.sample_config = description
        self.ascend_config = AscendQuantConfig(description)
        self.ascend_config.packed_modules_mapping = None

    def test_init(self):
        """The description passed to the constructor is kept as-is."""
        stored = self.ascend_config.quant_description
        self.assertEqual(stored, self.sample_config)

    def test_repr(self):
        """``repr`` begins with the class-name header line."""
        text = repr(self.ascend_config)
        self.assertTrue(text.startswith("AscendQuantConfig:\n"))

    def test_get_name(self):
        """``get_name`` returns the Ascend quantization method constant."""
        name = AscendQuantConfig.get_name()
        self.assertEqual(name, ASCEND_QUANTIZATION_METHOD)

    def test_get_supported_act_dtypes(self):
        """Three activation dtypes are reported as supported."""
        dtypes = AscendQuantConfig.get_supported_act_dtypes()
        self.assertEqual(len(dtypes), 3)

    def test_get_min_capability(self):
        """``get_min_capability`` raises ``NotImplementedError``."""
        with self.assertRaises(NotImplementedError):
            AscendQuantConfig.get_min_capability()

    def test_get_config_filenames(self):
        """The only config filename is ``quant_model_description.json``."""
        names = AscendQuantConfig.get_config_filenames()
        self.assertEqual(names, ["quant_model_description.json"])

    def test_from_config(self):
        """``from_config`` builds an instance holding the given description."""
        built = AscendQuantConfig.from_config(self.sample_config)
        self.assertIsInstance(built, AscendQuantConfig)
        self.assertEqual(built.quant_description, self.sample_config)

    @patch('torch.npu.is_available')
    def test_override_quantization_method(self, mock_is_available):
        """The override depends on whether an NPU is reported available."""
        # NPU available -> the Ascend method name is returned.
        mock_is_available.return_value = True
        with_npu = AscendQuantConfig.override_quantization_method(None, None)
        self.assertEqual(with_npu, ASCEND_QUANTIZATION_METHOD)

        # NPU unavailable -> no override.
        mock_is_available.return_value = False
        without_npu = AscendQuantConfig.override_quantization_method(
            None, None)
        self.assertIsNone(without_npu)

    def test_get_quant_method_for_linear(self):
        """Linear layers get the unquantized or the Ascend linear method."""
        linear_layer = MagicMock(spec=LinearBase)

        # Skipped layer -> plain unquantized method.
        with patch.object(self.ascend_config,
                          'is_layer_skipped_ascend',
                          return_value=True):
            skipped_method = self.ascend_config.get_quant_method(
                linear_layer, ".attn")
            self.assertIsInstance(skipped_method, UnquantizedLinearMethod)

        # Quantized layer -> AscendLinearMethod built from config and prefix.
        with patch.object(self.ascend_config,
                          'is_layer_skipped_ascend',
                          return_value=False):
            with patch(
                    'vllm_ascend.quantization.quant_config.AscendLinearMethod',
                    return_value=MagicMock()) as mock_ascend_linear:
                quant_method = self.ascend_config.get_quant_method(
                    linear_layer, ".attn")
                self.assertIs(quant_method, mock_ascend_linear.return_value)
                mock_ascend_linear.assert_called_once_with(
                    self.ascend_config, ".attn",
                    self.ascend_config.packed_modules_mapping)

    def test_get_quant_method_for_attention(self):
        """Attention layers are handled by ``AscendKVCacheMethod``."""
        attention_layer = MagicMock(spec=Attention)
        kvcache_target = 'vllm_ascend.quantization.quant_config.AscendKVCacheMethod'

        # Description containing fa_quant_type (the setUp sample).
        with patch(kvcache_target,
                   return_value=MagicMock()) as mock_ascend_kvcache:
            fa_method = self.ascend_config.get_quant_method(
                attention_layer, ".attn")
            self.assertIs(fa_method, mock_ascend_kvcache.return_value)

        # Description containing only kv_quant_type.
        with patch(kvcache_target,
                   return_value=MagicMock()) as mock_ascend_kvcache:
            kv_only_config = AscendQuantConfig({"kv_quant_type": "C8"})
            kv_only_config.packed_modules_mapping = None
            kv_method = kv_only_config.get_quant_method(
                attention_layer, "attn")
            self.assertIs(kv_method, mock_ascend_kvcache.return_value)

    def test_get_quant_method_for_fused_moe(self):
        """Fused-MoE layers get the unquantized or quantized MoE method."""
        fused_moe_layer = MagicMock(spec=FusedMoE)
        fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
        fused_moe_layer.moe_config = MagicMock(spec=FusedMoEConfig)

        # (is_layer_skipped_ascend result, class expected to build the method)
        scenarios = (
            (True,
             'vllm_ascend.quantization.quant_config.AscendUnquantizedFusedMoEMethod'
             ),
            (False,
             'vllm_ascend.quantization.quant_config.AscendFusedMoEMethod'),
        )
        for skipped, method_target in scenarios:
            with patch.object(self.ascend_config,
                              'is_layer_skipped_ascend',
                              return_value=skipped):
                with patch(method_target,
                           return_value=MagicMock()) as mock_ascend_moe:
                    method = self.ascend_config.get_quant_method(
                        fused_moe_layer, "moe_layer")
                    self.assertIs(method, mock_ascend_moe.return_value)

    def test_is_layer_skipped_ascend(self):
        """FLOAT layers are skipped; mixed fused shards raise ``ValueError``."""
        config = self.ascend_config

        # Non-fused INT8 layer is quantized, i.e. not skipped.
        self.assertFalse(config.is_layer_skipped_ascend("layer1"))

        # Non-fused FLOAT layer is skipped.
        self.assertTrue(config.is_layer_skipped_ascend("layer2"))

        # Fused layer whose shards are all FLOAT is skipped.
        fused_mapping = {"fused_layer": ["shard1", "shard2"]}
        fused_skipped = config.is_layer_skipped_ascend("fused_layer",
                                                       fused_mapping)
        self.assertTrue(fused_skipped)

        # Shards that disagree are rejected.
        inconsistent = AscendQuantConfig({
            "shard1.weight": "FLOAT",
            "shard2.weight": "INT8"
        })
        with self.assertRaises(ValueError):
            inconsistent.is_layer_skipped_ascend("fused_layer", fused_mapping)

    def test_get_scaled_act_names(self):
        """No scaled activation names are reported."""
        names = self.ascend_config.get_scaled_act_names()
        self.assertEqual(names, [])
|
||||
|
||||
|
||||
class TestAscendKVCacheMethod(TestBase):
    """Tests that ``AscendKVCacheMethod`` delegates to its quant method."""

    def setUp(self):
        """Create the method under test with a mocked quantizer lookup."""
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"some_config": "value"}
        self.prefix = "attention_layer"

        # The quantizer hands out this mock as the attention quant method.
        self.mock_quant_method = MagicMock()
        self.mock_quantizer = MagicMock()
        self.mock_quantizer.build_attention_method.return_value = self.mock_quant_method

        # Replace the quantizer lookup for the lifetime of the test.
        self.quantizer_patcher = patch(
            'vllm_ascend.quantization.quant_config.AscendQuantizer.get_quantizer',
            return_value=self.mock_quantizer)
        self.mock_get_quantizer = self.quantizer_patcher.start()

        self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
                                                   self.prefix)

    def tearDown(self):
        """Undo the quantizer-lookup patch started in ``setUp``."""
        self.quantizer_patcher.stop()

    def test_init(self):
        """Test initialization with proper quantizer setup."""
        description = self.mock_quant_config.quant_description
        self.mock_get_quantizer.assert_called_once_with(
            description, self.prefix)
        self.mock_quantizer.build_attention_method.assert_called_once()

    def test_create_weights(self):
        """Test create_weights delegates to quant_method."""
        layer = MagicMock()
        self.kv_cache_method.create_weights(layer)
        self.mock_quant_method.create_weights.assert_called_once_with(layer)

    def test_process_weights_after_loading_with_method(self):
        """Test process_weights when quant_method has the method."""
        layer = MagicMock()
        self.kv_cache_method.process_weights_after_loading(layer)
        delegated = self.mock_quant_method.process_weights_after_loading
        delegated.assert_called_once_with(layer)

    def test_process_weights_after_loading_without_method(self):
        """Test process_weights when quant_method lacks the method."""
        # Deleting the attribute makes the mock raise AttributeError for it.
        del self.mock_quant_method.process_weights_after_loading
        layer = MagicMock()

        # Must complete without raising.
        self.kv_cache_method.process_weights_after_loading(layer)

    def test_apply_delegation(self):
        """Test apply properly delegates to quant_method."""
        mock_layer = MagicMock()
        mock_query = torch.randn(1, 32, 128)
        mock_key = torch.randn(1, 32, 128)
        mock_value = torch.randn(1, 32, 128)
        mock_kv_cache = MagicMock()
        mock_attn_metadata = MagicMock()
        mock_scale = 1.0
        mock_output = torch.zeros(1, 32, 128)
        mock_attn_type = MagicMock()
        expected_result = torch.randn(1, 32, 128)
        self.mock_quant_method.apply.return_value = expected_result

        # The same positional tuple is used for the call and the check.
        call_args = (mock_layer, mock_query, mock_key, mock_value,
                     mock_kv_cache, mock_attn_metadata, mock_attn_type,
                     mock_scale, mock_output)

        result = self.kv_cache_method.apply(*call_args)

        self.mock_quant_method.apply.assert_called_once_with(*call_args)
        self.assertTrue(torch.equal(result, expected_result))
|
||||
145
tests/ut/quantization/test_quantizer.py
Normal file
145
tests/ut/quantization/test_quantizer.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
from vllm_ascend.quantization.quantizer import (VLLMAscendQuantizer,
|
||||
W4A8DYNAMICQuantizer,
|
||||
W8A8Quantizer)
|
||||
|
||||
# Test-local registry. TestGetQuantizer.setUp adds mock quantizer classes to
# it and tearDown restores this initial content; the tests hand it to
# patch.dict to overlay the quantizer module's SUPPORT_ASCEND_QUANTIZER_TYPE.
SUPPORT_ASCEND_QUANTIZER_TYPE = {"test": "1"}
|
||||
|
||||
|
||||
class TestGetQuantizer(TestBase):
    """Tests for ``VLLMAscendQuantizer.get_quantizer`` type resolution."""

    def setUp(self):
        """Register mock quantizer classes in the test-local registry."""
        self.supported_types = {
            type_name: MagicMock(_instance=None)
            for type_name in ('INT8', 'FP16', 'C8')
        }
        self.original_supported_types = SUPPORT_ASCEND_QUANTIZER_TYPE.copy()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.supported_types)
        self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
        self.mock_quant_config.quant_description = {"some_config": "value"}

    def tearDown(self):
        """Put the test-local registry back to its pre-test content."""
        SUPPORT_ASCEND_QUANTIZER_TYPE.clear()
        SUPPORT_ASCEND_QUANTIZER_TYPE.update(self.original_supported_types)

    def _get_quantizer(self, quant_description, prefix):
        """Call ``get_quantizer`` with the module registry overlaid by ours."""
        with patch.dict(
                'vllm_ascend.quantization.quantizer.SUPPORT_ASCEND_QUANTIZER_TYPE',
                SUPPORT_ASCEND_QUANTIZER_TYPE):
            return VLLMAscendQuantizer.get_quantizer(
                quant_description,
                prefix,
                packed_modules_mapping={"some": "mapping"})

    def _verify(self, result, quant_description, expected_type):
        """Check ``result`` came from the mock class for ``expected_type``."""
        quantizer_cls = self.supported_types[expected_type]
        self.assertIsNotNone(result)
        self.assertEqual(result, quantizer_cls._instance)
        quantizer_cls.assert_called_once_with(quant_description)

    def test_get_quantizer_fa(self):
        """An ``fa_quant_type`` entry selects the quantizer for ``.attn``."""
        quant_description = {'fa_quant_type': 'C8'}
        result = self._get_quantizer(quant_description, '.attn')
        self._verify(result, quant_description, 'C8')

    def test_get_quantizer_kv(self):
        """A ``kv_quant_type`` entry selects the quantizer for ``.attn``."""
        quant_description = {'kv_quant_type': 'C8'}
        result = self._get_quantizer(quant_description, '.attn')
        self._verify(result, quant_description, 'C8')

    def test_get_quantizer_linear(self):
        """Other prefixes use the type from ``get_linear_quant_type``."""
        quant_description = {'linear_type': 'INT8'}
        expected_type = 'INT8'
        with patch(
                'vllm_ascend.quantization.quantizer.VLLMAscendQuantizer.get_linear_quant_type',
                return_value=expected_type):
            result = self._get_quantizer(quant_description, 'nothing')
        self._verify(result, quant_description, expected_type)
|
||||
|
||||
|
||||
class TestW8A8Quantizer(TestBase):
    """Tests that ``W8A8Quantizer`` builds each method with no arguments."""

    def setUp(self):
        """Create a quantizer from an empty description."""
        self.quantizer = W8A8Quantizer(quant_description={})

    def test_build_linear_method(self):
        """``build_linear_method`` instantiates ``AscendW8A8LinearMethod``."""
        target = 'vllm_ascend.quantization.quantizer.AscendW8A8LinearMethod'
        with patch(target, return_value=MagicMock()) as mock_method_cls:
            built = self.quantizer.build_linear_method()
            mock_method_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_moe_method(self):
        """``build_moe_method`` instantiates ``AscendW8A8FusedMoEMethod``."""
        target = 'vllm_ascend.quantization.quantizer.AscendW8A8FusedMoEMethod'
        with patch(target, return_value=MagicMock()) as mock_method_cls:
            built = self.quantizer.build_moe_method()
            mock_method_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_attention_method(self):
        """``build_attention_method`` instantiates ``AscendC8KVCacheMethod``."""
        target = 'vllm_ascend.quantization.quantizer.AscendC8KVCacheMethod'
        with patch(target, return_value=MagicMock()) as mock_method_cls:
            built = self.quantizer.build_attention_method()
            mock_method_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)
|
||||
|
||||
|
||||
class TestW4A8DYNAMICQuantizer(TestBase):
    """Checks which method class each W4A8DYNAMICQuantizer builder instantiates."""

    def setUp(self):
        # These tests only need a quantizer built from an empty description.
        self.quantizer = W4A8DYNAMICQuantizer(quant_description={})

    def test_build_linear_method(self):
        """build_linear_method() calls AscendW4A8DynamicLinearMethod with no args."""
        target = (
            'vllm_ascend.quantization.quantizer.AscendW4A8DynamicLinearMethod')
        with patch(target, return_value=MagicMock()) as linear_cls:
            built = self.quantizer.build_linear_method()
            linear_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)

    def test_build_moe_method(self):
        """build_moe_method() calls AscendW4A8DynamicFusedMoEMethod with no args."""
        target = (
            'vllm_ascend.quantization.quantizer.AscendW4A8DynamicFusedMoEMethod'
        )
        with patch(target, return_value=MagicMock()) as moe_cls:
            built = self.quantizer.build_moe_method()
            moe_cls.assert_called_once_with()
            self.assertIsInstance(built, MagicMock)
|
||||
166
tests/ut/quantization/test_w4a8_dynamic.py
Normal file
166
tests/ut/quantization/test_w4a8_dynamic.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.w4a8_dynamic import (
|
||||
AscendW4A8DynamicFusedMoEMethod, AscendW4A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
    """Dtype and shape checks for AscendW4A8DynamicLinearMethod parameters."""

    def setUp(self):
        linear_method = AscendW4A8DynamicLinearMethod()
        linear_method.group_size = 8
        self.method = linear_method

    def test_get_weight(self):
        """get_weight(8, 32, bfloat16) yields an int8 'weight' of shape (32, 8)."""
        tensor = self.method.get_weight(8, 32, torch.bfloat16)["weight"]
        self.assertEqual(tensor.dtype, torch.int8)
        self.assertEqual(tensor.shape, (32, 8))

    def test_get_pergroup_param(self):
        """Every per-group parameter is bfloat16 with shape (32, 1)."""
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        names = ("weight_scale", "weight_offset", "weight_scale_second",
                 "weight_offset_second")
        for name in names:
            self.assertEqual(params[name].dtype, torch.bfloat16)
            self.assertEqual(params[name].shape, (32, 1))
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
    """Dtype and shape checks for AscendW4A8DynamicFusedMoEMethod.

    Each test first exercises the method as constructed in setUp, then sets
    ``new_quant_version = True`` on it and repeats the call.
    """

    # Dimensions shared by every test in this class.
    experts = 8
    input_size = 16
    output_size = 56
    group_size = 2

    @patch('vllm_ascend.quantization.w4a8_dynamic.get_current_vllm_config')
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_ep_group')
    @patch('vllm_ascend.quantization.w4a8_dynamic.get_mc2_group')
    @patch('torch.distributed.get_rank', return_value=0)
    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ep_group,
              get_current_vllm_config):
        """Construct the method under test against a mocked vLLM config."""
        description = {"group_size": self.group_size, "version": "0.0.0"}
        vllm_config = Mock()
        vllm_config.quant_config = Mock(quant_description=description)
        vllm_config.parallel_config = Mock(enable_expert_parallel=True)
        get_current_vllm_config.return_value = vllm_config
        self.quant_method = AscendW4A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        """w13_weight is int8; its expected second dim differs per version."""
        call_args = (self.experts, self.input_size, self.output_size,
                     torch.bfloat16)

        # Old quant version: second dimension is 2 * input_size.
        old_w13 = self.quant_method.get_weight(*call_args)["w13_weight"]
        self.assertEqual(old_w13.dtype, torch.int8)
        self.assertEqual(
            old_w13.shape,
            (self.experts, 2 * self.input_size, self.output_size))

        # New quant version: second dimension is input_size.
        self.quant_method.new_quant_version = True
        new_w13 = self.quant_method.get_weight(*call_args)["w13_weight"]
        self.assertEqual(new_w13.dtype, torch.int8)
        self.assertEqual(new_w13.shape,
                         (self.experts, self.input_size, self.output_size))

    def test_get_dynamic_quant_param(self):
        """Scale parameters are bfloat16; w2_scale_bias (new version) is float32."""
        call_args = (self.experts, self.input_size, self.output_size,
                     torch.bfloat16)

        # Old quant version.
        params = self.quant_method.get_dynamic_quant_param(*call_args)
        expected_shapes = {
            "w13_weight_scale": (self.experts, 2 * self.input_size, 1),
            "w13_weight_scale_second":
            (self.experts, 2 * self.input_size,
             self.output_size // self.group_size),
            "w2_weight_scale": (self.experts, self.output_size, 1),
            "w2_weight_scale_second":
            (self.experts, self.output_size,
             self.input_size // self.group_size),
        }
        for name, shape in expected_shapes.items():
            self.assertEqual(params[name].dtype, torch.bfloat16)
            self.assertEqual(params[name].shape, shape)

        # New quant version.
        self.quant_method.new_quant_version = True
        params = self.quant_method.get_dynamic_quant_param(*call_args)
        scale_bias = params["w2_scale_bias"]
        self.assertEqual(scale_bias.dtype, torch.float32)
        self.assertEqual(
            scale_bias.shape,
            (self.experts, self.output_size, 16 // self.quant_method.tp_size))

    @patch('torch_npu.npu_quantize')
    @patch('torch.Tensor.npu')
    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):
        """After processing, both scale-bias tensors are 2-D for either version."""

        def frozen(tensor):
            # All layer tensors are registered as non-trainable parameters.
            return torch.nn.Parameter(tensor, requires_grad=False)

        n_experts = self.experts
        in_size, out_size = self.input_size, self.output_size
        in_groups = in_size // self.group_size
        out_groups = out_size // self.group_size

        # Old quant version layer.
        layer = torch.nn.Module()
        layer.w13_weight = frozen(
            torch.zeros((n_experts, 2 * in_size, out_size), dtype=torch.int8))
        layer.w2_weight = frozen(
            torch.zeros((n_experts, out_size, in_size), dtype=torch.int8))
        layer.w13_weight_scale = frozen(
            torch.ones((n_experts, 2 * in_size, 1), dtype=torch.bfloat16))
        layer.w13_weight_scale_second = frozen(
            torch.ones((n_experts, 2 * in_size, out_groups),
                       dtype=torch.bfloat16))
        layer.w2_weight_scale = frozen(
            torch.ones((n_experts, out_size, 1), dtype=torch.bfloat16))
        layer.w2_weight_scale_second = frozen(
            torch.ones((n_experts, out_size, in_groups),
                       dtype=torch.bfloat16))
        # Snapshot the untouched layer for the new-version pass below.
        new_layer = copy.deepcopy(layer)

        mock_npu.return_value = torch.Tensor()
        mock_npu_quantize.return_value = torch.Tensor()
        self.quant_method.process_weights_after_loading(layer)

        self.assertTrue(hasattr(layer, "w13_scale_bias"))
        self.assertEqual(layer.w13_scale_bias.data.shape,
                         (n_experts, 2 * in_size))
        self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
        self.assertTrue(hasattr(layer, "w2_scale_bias"))
        self.assertEqual(layer.w2_scale_bias.data.shape,
                         (n_experts, out_size))
        self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)

        # New quant version layer: different weight shapes and scale-bias
        # parameters supplied up front.
        self.quant_method.new_quant_version = True
        new_layer.w13_weight.data = torch.zeros(
            (n_experts, in_size, out_size), dtype=torch.int8)
        new_layer.w2_weight.data = torch.zeros(
            (n_experts, out_size // 2, in_size), dtype=torch.int8)
        new_layer.w13_scale_bias = frozen(
            torch.zeros((n_experts, 2 * in_size, 1), dtype=torch.float32))
        new_layer.w2_scale_bias = frozen(
            torch.zeros((n_experts, out_size, 16 // self.quant_method.tp_size),
                        dtype=torch.float32))

        self.quant_method.process_weights_after_loading(new_layer)

        self.assertEqual(new_layer.w13_scale_bias.data.shape,
                         (n_experts, 2 * in_size))
        self.assertEqual(new_layer.w2_scale_bias.data.shape,
                         (n_experts, out_size))
|
||||
930
tests/ut/quantization/test_w8a8.py
Normal file
930
tests/ut/quantization/test_w8a8.py
Normal file
@@ -0,0 +1,930 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
|
||||
select_experts)
|
||||
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
|
||||
AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod,
|
||||
fused_experts, fused_experts_310p,
|
||||
quant_per_tensor)
|
||||
|
||||
|
||||
class TestQuantPerTensor(TestBase):
    """Checks the call quant_per_tensor makes to torch_npu.npu_quantize."""

    @patch("torch_npu.npu_quantize")
    def test_quant_per_tensor(self, mock_npu_quantize):
        """Arguments are forwarded and the mocked result is returned."""
        scale = torch.tensor(0.1)
        offset = torch.tensor(0)
        activations = torch.randn(32, 128)
        quantized = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
        mock_npu_quantize.return_value = quantized

        actual = quant_per_tensor(activations, scale, offset)

        mock_npu_quantize.assert_called_once_with(activations, scale, offset,
                                                  torch.qint8, -1, False)
        self.assertTrue(torch.equal(actual, quantized))
|
||||
|
||||
|
||||
class TestAscendW8A8LinearMethod(TestBase):
    """Unit tests for AscendW8A8LinearMethod parameters, apply() and loading."""

    def setUp(self):
        self.method = AscendW8A8LinearMethod()

    @staticmethod
    def _make_layer():
        """Return a mocked layer with the attributes the apply() tests set."""
        layer = MagicMock()
        layer.aclnn_input_scale = 0.1
        layer.aclnn_input_offset = 0.2
        layer.weight = torch.randn(128, 256)
        layer.deq_scale = 0.3
        return layer

    def _check_apply(self, mock_npu_quant_matmul, x):
        """Call apply() on ``x`` and compare with the mocked matmul result plus bias."""
        layer = self._make_layer()
        bias = torch.randn(256)
        expected_y_output = torch.randn(32, 256)
        mock_npu_quant_matmul.return_value = expected_y_output

        output = self.method.apply(layer, x, bias)

        # NOTE(review): `+=` mutates the tensor the mock returns; if apply()
        # hands that same object back, `output` sees this update too.
        # Confirm that is the intended comparison.
        expected_y_output += bias
        self.assertTrue(torch.equal(output, expected_y_output))

    def test_get_weight(self):
        """get_weight(10, 20) yields an int8 'weight' of shape (20, 10)."""
        weight = self.method.get_weight(10, 20)
        self.assertEqual(weight['weight'].dtype, torch.int8)
        self.assertEqual(weight['weight'].shape, (20, 10))

    def test_get_pertensor_param(self):
        """Per-tensor scale follows the requested dtype; offset is int8."""
        params = self.method.get_pertensor_param(torch.bfloat16)
        self.assertEqual(params['input_scale'].dtype, torch.bfloat16)
        self.assertEqual(params['input_offset'].dtype, torch.int8)
        self.assertEqual(params['input_scale'].shape, (1, ))
        self.assertEqual(params['input_offset'].shape, (1, ))

    def test_get_perchannel_param(self):
        """Per-channel parameters have the expected dtypes and shapes."""
        params = self.method.get_perchannel_param(10, torch.bfloat16)

        self.assertEqual(params['quant_bias'].dtype, torch.int32)
        self.assertEqual(params['deq_scale'].dtype, torch.float32)
        self.assertEqual(params['weight_scale'].dtype, torch.bfloat16)
        self.assertEqual(params['weight_offset'].dtype, torch.bfloat16)
        self.assertEqual(params['quant_bias'].shape, (10, ))
        self.assertEqual(params['deq_scale'].shape, (10, ))
        self.assertEqual(params['weight_scale'].shape, (10, 1))
        self.assertEqual(params['weight_offset'].shape, (10, 1))

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
                                   mock_quant_per_tensor):
        """apply() on a float input, with quant_per_tensor mocked out."""
        x = torch.randn(32, 128)
        mock_quant_per_tensor.return_value = torch.randint(-128,
                                                           127,
                                                           x.shape,
                                                           dtype=torch.int8)
        self._check_apply(mock_npu_quant_matmul, x)

    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_is_int8(self, mock_npu_quant_matmul):
        """apply() on an input that is already int8."""
        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
        self._check_apply(mock_npu_quant_matmul, x)

    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_is_310p(self, mock_npu_quant_matmul, mock_is_310p):
        """apply() on an int8 input while is_310p() is patched to True."""
        x = torch.randint(-128, 127, (32, 128), dtype=torch.int8)
        self._check_apply(mock_npu_quant_matmul, x)

    @patch('torch_npu.npu_format_cast')
    def test_process_weights_after_loading(self, mock_npu_format_cast):
        """Loading derives aclnn_input_offset and flattens the weight scale/offset."""
        layer = MagicMock()

        layer.weight.data = torch.randn(128, 256)
        layer.input_scale.data = torch.tensor([0.1])
        layer.input_offset.data = torch.tensor([0])
        layer.deq_scale = torch.tensor([0.5])
        layer.weight_scale.data = torch.randn(128, 1)
        layer.weight_offset.data = torch.randn(128, 1)

        # Return a mock *instance*; the original assigned the MagicMock class
        # itself, which was almost certainly unintended.
        mock_npu_format_cast.return_value = MagicMock()
        self.method.process_weights_after_loading(layer)

        expected_offset = torch.tensor([0]).repeat(256).to(torch.int8)
        self.assertTrue(
            torch.equal(layer.aclnn_input_offset.data, expected_offset))
        self.assertFalse(layer.aclnn_input_offset.requires_grad)

        self.assertFalse(layer.deq_scale.requires_grad)

        self.assertEqual(layer.weight_scale.data.shape, (128, ))
        self.assertEqual(layer.weight_offset.data.shape, (128, ))
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
    """Unit tests for AscendW8A8FusedMoEMethod parameters and apply()."""

    def setUp(self):
        self.moe_method = AscendW8A8FusedMoEMethod()
        self.num_experts = 4
        self.intermediate_size = 64
        self.hidden_size = 128
        self.dtype = torch.float32

    def _check_apply(self, mock_select_experts, mock_fused):
        """Call apply() with mocked expert selection and a mocked fused kernel.

        ``mock_fused`` is whichever fused-experts mock the calling test
        expects apply() to invoke exactly once.
        """
        mock_layer = MagicMock()
        x = torch.randn(32, self.hidden_size)
        router_logits = torch.randn(32, 128)  # 128 experts
        top_k = 2

        mock_select_experts.return_value = (torch.randn(32, top_k),
                                            torch.randint(0, 128, (32, top_k)))
        mock_fused.return_value = torch.randn(32, self.hidden_size)

        result = self.moe_method.apply(layer=mock_layer,
                                       x=x,
                                       router_logits=router_logits,
                                       top_k=top_k,
                                       renormalize=True,
                                       global_num_experts=128)

        mock_select_experts.assert_called_once()
        mock_fused.assert_called_once()
        self.assertEqual(result.shape, (32, self.hidden_size))

    def test_init(self):
        """A freshly constructed method has transpose_weight set."""
        self.assertTrue(self.moe_method.transpose_weight)

    def test_get_weight(self):
        """get_weight() returns frozen int8 w13/w2 weights of the expected shapes."""
        weights = self.moe_method.get_weight(
            num_experts=self.num_experts,
            intermediate_size_per_partition=self.intermediate_size,
            hidden_sizes=self.hidden_size,
            params_dtype=self.dtype)

        # assertIn instead of bare `assert`: plain asserts are removed when
        # Python runs with -O, which would silently skip these checks.
        self.assertIn("w13_weight", weights)
        self.assertIn("w2_weight", weights)
        self.assertEqual(
            weights["w13_weight"].shape,
            (self.num_experts, 2 * self.intermediate_size, self.hidden_size))
        self.assertEqual(
            weights["w2_weight"].shape,
            (self.num_experts, self.hidden_size, self.intermediate_size))
        self.assertEqual(weights["w13_weight"].dtype, torch.int8)
        self.assertEqual(weights["w2_weight"].dtype, torch.int8)
        self.assertFalse(weights["w13_weight"].requires_grad)
        self.assertFalse(weights["w2_weight"].requires_grad)

    def test_get_dynamic_quant_param(self):
        """get_dynamic_quant_param() returns every expected key; spot-check shapes."""
        quant_params = self.moe_method.get_dynamic_quant_param(
            num_experts=self.num_experts,
            intermediate_size_per_partition=self.intermediate_size,
            hidden_sizes=self.hidden_size,
            params_dtype=self.dtype)

        expected_params = [
            "w13_weight_scale", "w13_weight_offset", "w2_weight_scale",
            "w2_weight_offset", "w2_deq_scale", "w13_deq_scale",
            "w2_input_scale", "w13_input_scale", "w2_input_offset",
            "w13_input_offset", "quant_bias"
        ]

        for param in expected_params:
            # assertIn survives `python -O`, unlike a bare assert.
            self.assertIn(param, quant_params)

        # Check some sample shapes
        self.assertEqual(quant_params["w13_weight_scale"].shape,
                         (self.num_experts, 2 * self.intermediate_size, 1))
        self.assertEqual(quant_params["w2_input_offset"].shape,
                         (self.num_experts, 1))
        self.assertEqual(quant_params["quant_bias"].shape,
                         (self.num_experts, self.hidden_size))

    @patch('vllm_ascend.quantization.w8a8.select_experts')
    @patch('vllm_ascend.quantization.w8a8.fused_experts')
    def test_apply_with_other_expert_count(self, mock_fused_experts,
                                           mock_select_experts):
        """apply() calls select_experts and fused_experts once each."""
        self._check_apply(mock_select_experts, mock_fused_experts)

    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
    @patch('vllm_ascend.quantization.w8a8.select_experts')
    @patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
    def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
                           mock_is_310p):
        """With is_310p() patched to True, apply() calls fused_experts_310p once."""
        self._check_apply(mock_select_experts, mock_fused_experts_310p)
|
||||
|
||||
|
||||
class TestAscendC8KVCacheMethod(TestBase):
    """Unit tests for AscendC8KVCacheMethod weight handling and apply()."""

    def setUp(self):
        self.layer = MagicMock()
        self.layer.num_kv_heads = 4
        self.layer.head_size = 64
        self.layer.num_heads = 8
        self.layer._k_scale_float = 1.0
        self.layer._v_scale_float = 1.0
        self.method = AscendC8KVCacheMethod()

        # Stand-in for the attention-type constants passed to apply().
        self.attention_type = MagicMock()
        self.attention_type.DECODER = "decoder"
        self.attention_type.ENCODER = "encoder"

    def _make_qkv(self, num_tokens):
        """Return random query/key/value tensors plus an empty output buffer."""
        kv_width = self.layer.num_kv_heads * self.layer.head_size
        query = torch.randn(num_tokens,
                            self.layer.num_heads * self.layer.head_size)
        key = torch.randn(num_tokens, kv_width)
        value = torch.randn(num_tokens, kv_width)
        return query, key, value, torch.empty_like(query)

    def _check_antiquant_scale_comb(self):
        """Process key scale 1 / value scale 2 and check the combined tensor."""
        self.layer.key_antiquant_scale.data = torch.ones(4 * 64)
        self.layer.value_antiquant_scale.data = torch.ones(4 * 64) * 2

        self.method.process_weights_after_loading(self.layer)

        self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
        self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
        self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))

    def _check_decode_only(self, attn_metadata, mock_quant, mock_scatter):
        """Run apply() in DecodeOnly state against the given metadata mock."""
        num_tokens = 2
        query, key, value, output = self._make_qkv(num_tokens)

        attn_metadata.attn_state = AscendAttentionState.DecodeOnly
        attn_metadata.seq_lens = [10, 10]
        attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
        attn_metadata.slot_mapping = torch.tensor([0, 1])
        attn_metadata.attn_mask = None

        block_size = 16
        key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
                                self.layer.head_size)
        value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
                                  self.layer.head_size)
        kv_cache = (key_cache, value_cache)

        # First quantization call returns `key`, the second returns `value`.
        mock_quant.side_effect = [key, value]

        self.layer.key_antiquant_scale.data = torch.ones(
            self.layer.num_kv_heads * self.layer.head_size)
        self.layer.value_antiquant_scale.data = torch.ones(
            self.layer.num_kv_heads * self.layer.head_size)
        self.method.process_weights_after_loading(self.layer)

        expected_output = torch.randn(
            num_tokens, self.layer.num_heads * self.layer.head_size)
        with patch('torch_npu.npu_incre_flash_attention',
                   return_value=expected_output):
            result = self.method.apply(self.layer, query, key, value, kv_cache,
                                       attn_metadata,
                                       self.attention_type.DECODER, 1.0,
                                       output)

        self.assertEqual(mock_quant.call_count, 2)
        self.assertEqual(mock_scatter.call_count, 2)
        self.assertTrue(torch.equal(result, expected_output))

    def test_create_weights(self):
        """Test that create_weights registers the expected parameters."""
        AscendC8KVCacheMethod.create_weights(self.layer)

        self.layer.register_parameter.assert_any_call("key_antiquant_scale",
                                                      unittest.mock.ANY)
        self.layer.register_parameter.assert_any_call("value_antiquant_scale",
                                                      unittest.mock.ANY)

        expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
        for recorded in self.layer.register_parameter.call_args_list:
            args, kwargs = recorded
            param = kwargs.get('parameter', args[1] if len(args) > 1 else None)
            # Fail with a clear assertion rather than an AttributeError when a
            # call carried no parameter at all.
            self.assertIsNotNone(param)
            self.assertEqual(param.shape, expected_shape)

    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=False)
    def test_process_weights_after_loading_not_310p(self, mock_is_310p):
        """Combined antiquant scale is built with is_310p() patched to False."""
        self._check_antiquant_scale_comb()

    @patch("vllm_ascend.quantization.w8a8.is_310p", return_value=True)
    def test_process_weights_after_loading_is_310p(self, mock_is_310p):
        """Combined antiquant scale is built with is_310p() patched to True."""
        self._check_antiquant_scale_comb()

    @patch('torch_npu.npu_scatter_nd_update_')
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_decode_only(self, mock_quant, mock_scatter):
        """DecodeOnly apply() with an unrestricted MagicMock as metadata."""
        self._check_decode_only(MagicMock(), mock_quant, mock_scatter)

    @patch('torch_npu.npu_scatter_nd_update_')
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_attn_metadata_without_decode(self, mock_quant,
                                                mock_scatter):
        """DecodeOnly apply() with metadata restricted to the listed attributes."""
        attn_metadata = MagicMock(spec=[
            'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
            'attn_mask'
        ])
        self._check_decode_only(attn_metadata, mock_quant, mock_scatter)

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch('torch_npu._npu_flash_attention')
    def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
        """Test apply method in prefill no-cache mode"""
        num_tokens = 2
        query, key, value, output = self._make_qkv(num_tokens)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
        attn_metadata.seq_lens = [10, 10]
        attn_metadata.attn_mask = torch.ones(2, 2)

        kv_cache = (torch.tensor([]), torch.tensor([]))
        mock_quant.return_value = key

        result = self.method.apply(self.layer, query, key, value, kv_cache,
                                   attn_metadata, self.attention_type.DECODER,
                                   1.0, output)

        # Check that flash attention was called
        mock_flash.assert_called_once()

        # Check output shape
        self.assertEqual(
            result.shape,
            (num_tokens, self.layer.num_heads * self.layer.head_size))

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_unsupported_attention_type(self, mock_quant):
        """apply() with the ENCODER type raises NotImplementedError before quantizing."""
        query, key, value, output = self._make_qkv(1)

        mock_quant.return_value = key

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.PrefillNoCache

        with self.assertRaises(NotImplementedError) as cm:
            self.method.apply(self.layer, query, key, value, (None, None),
                              attn_metadata, self.attention_type.ENCODER, 1.0,
                              output)

        # assertIn instead of bare `assert`: plain asserts are removed when
        # Python runs with -O, which would silently skip these checks.
        message = str(cm.exception)
        self.assertIn("Encoder self-attention", message)
        self.assertIn("not implemented", message)

        mock_quant.assert_not_called()

    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    def test_apply_unsupported_attention_state(self, mock_quant):
        """Test apply with unsupported attention state"""
        query, key, value, output = self._make_qkv(1)

        attn_metadata = MagicMock()
        attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit
        mock_quant.return_value = key
        kv_cache = (torch.tensor([]), torch.tensor([]))

        with self.assertRaises(NotImplementedError):
            self.method.apply(self.layer, query, key, value, kv_cache,
                              attn_metadata, self.attention_type.DECODER, 1.0,
                              output)
|
||||
|
||||
|
||||
class TestFusedExperts(TestBase):
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
|
||||
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
|
||||
@patch('torch_npu.npu_moe_init_routing_v2')
|
||||
@patch('torch_npu.npu_grouped_matmul')
|
||||
@patch('torch_npu.npu_swiglu')
|
||||
@patch('torch_npu.npu_moe_finalize_routing')
|
||||
def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu,
|
||||
mock_group_matmul,
|
||||
mock_init_routing,
|
||||
mock_get_ep_group,
|
||||
mock_quant_per_tensor):
|
||||
num_tokens = 32
|
||||
hidden_size = 128
|
||||
intermediate_size = 256
|
||||
num_experts = 4
|
||||
top_k = 2
|
||||
|
||||
hidden_states = torch.randn(num_tokens, hidden_size)
|
||||
|
||||
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
|
||||
w1_scale = torch.tensor([0.1])
|
||||
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
|
||||
w1_input_offset = torch.tensor([0])
|
||||
|
||||
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
|
||||
w2_scale = torch.tensor([0.1])
|
||||
w2_input_scale = torch.tensor([0.2])
|
||||
w2_input_offset = torch.tensor([0])
|
||||
|
||||
topk_weights = torch.rand(num_tokens, top_k)
|
||||
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
|
||||
expert_map = torch.arange(num_experts)
|
||||
|
||||
mock_get_ep_group.return_value.world_size = 8
|
||||
|
||||
mock_quant_per_tensor.return_value = torch.randint(-128,
|
||||
127,
|
||||
hidden_states.shape,
|
||||
dtype=torch.int8)
|
||||
|
||||
mock_init_routing.return_value = (torch.randn(num_tokens * top_k,
|
||||
hidden_size),
|
||||
torch.arange(num_tokens * top_k),
|
||||
torch.tensor([num_tokens // 2] * 2),
|
||||
torch.tensor(1.0))
|
||||
|
||||
mock_group_matmul.side_effect = [[
|
||||
torch.randn(num_tokens * top_k, intermediate_size * 2)
|
||||
], [torch.randn(num_tokens * top_k, hidden_size)]]
|
||||
|
||||
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
|
||||
intermediate_size)
|
||||
|
||||
expected_output = torch.randn(num_tokens, hidden_size)
|
||||
mock_finalize.return_value = expected_output
|
||||
|
||||
output = fused_experts(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w1_input_scale=w1_input_scale,
|
||||
w1_input_offset=w1_input_offset,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w2_input_scale=w2_input_scale,
|
||||
w2_input_offset=w2_input_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
|
||||
mock_init_routing.assert_called_once()
|
||||
|
||||
self.assertEqual(mock_group_matmul.call_count, 2)
|
||||
|
||||
self.assertEqual(output.shape, (num_tokens, hidden_size))
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
|
||||
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
def test_fused_experts_without_expert_map(self, mock_swiglu,
                                          mock_group_matmul,
                                          mock_get_ep_group,
                                          mock_quant_per_tensor):
    """``fused_experts`` called with ``expert_map=None`` must raise
    NotImplementedError."""
    # Problem sizes used to build every tensor below.
    n_tok, d_hidden, d_inter, n_exp, k = 16, 64, 128, 8, 1
    routed_rows = n_tok * k

    activations = torch.randn(n_tok, d_hidden)
    gate_up_weight = torch.randn(n_exp, d_inter * 2, d_hidden)
    down_weight = torch.randn(n_exp, d_hidden, d_inter)
    routing_weights = torch.rand(n_tok, k)
    routing_ids = torch.randint(0, n_exp, (n_tok, k))

    mock_get_ep_group.return_value.world_size = 8

    mock_quant_per_tensor.return_value = torch.randint(
        -128, 127, activations.shape, dtype=torch.int8)
    # One canned result per expected grouped-matmul invocation.
    first_matmul_out = [torch.randn(routed_rows, d_inter * 2)]
    second_matmul_out = [torch.randn(routed_rows, d_hidden)]
    mock_group_matmul.side_effect = [first_matmul_out, second_matmul_out]
    mock_swiglu.return_value = torch.randn(routed_rows, d_inter)

    with self.assertRaises(NotImplementedError):
        fused_experts(
            hidden_states=activations,
            w1=gate_up_weight,
            w1_scale=torch.tensor([0.1]),
            w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]),
            w1_input_offset=torch.tensor([0]),
            w2=down_weight,
            w2_scale=torch.tensor([0.1]),
            w2_input_scale=torch.tensor([0.1]),
            w2_input_offset=torch.tensor([0]),
            topk_weights=routing_weights,
            topk_ids=routing_ids,
            top_k=k,
            global_num_experts=n_exp,
            expert_map=None,
        )
|
||||
|
||||
|
||||
class TestFusedExperts310(TestBase):
    """Tests for ``fused_experts_310p`` with the NPU kernels mocked out."""

    @patch('torch_npu.npu_quant_grouped_matmul_dequant')
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch('vllm_ascend.quantization.w8a8.get_ep_group')
    @patch('torch_npu.npu_swiglu')
    def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
                                                mock_get_ep_group,
                                                mock_quant_per_tensor,
                                                mock_matmul_dequant):
        """With an expert map, the output keeps the input shape and the
        quantized grouped matmul is invoked exactly twice."""
        # Problem sizes used to build every tensor below.
        n_tok, d_hidden, d_inter, n_exp, k = 32, 128, 256, 4, 1

        activations = torch.randn(n_tok, d_hidden)
        gate_up_weight = torch.randn(n_exp, d_inter * 2, d_hidden)
        down_weight = torch.randn(n_exp, d_hidden, d_inter)
        routing_weights = torch.rand(n_tok, k)
        routing_ids = torch.randint(0, n_exp, (n_tok, k))

        mock_get_ep_group.return_value.world_size = 1
        mock_quant_per_tensor.return_value = torch.randint(
            -128, 127, activations.shape, dtype=torch.int8)
        mock_swiglu.return_value = torch.randn(n_tok * k, d_inter)
        # The mocked matmul simply echoes the input activations back.
        mock_matmul_dequant.return_value = activations

        result = fused_experts_310p(
            hidden_states=activations,
            w1=gate_up_weight,
            w1_scale=torch.tensor([0.1]),
            w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]),
            w2=down_weight,
            w2_scale=torch.tensor([0.1]),
            w2_input_scale=torch.tensor([0.2]),
            topk_weights=routing_weights,
            topk_ids=routing_ids,
            top_k=k,
            global_num_experts=n_exp,
            expert_map=torch.arange(n_exp),
        )

        self.assertEqual(result.shape, (n_tok, d_hidden))
        self.assertEqual(mock_matmul_dequant.call_count, 2)
|
||||
|
||||
|
||||
class TestSelectExperts(TestBase):
    """Tests for ``select_experts`` across its scoring and routing options."""

    def setUp(self):
        # Shared fixture sizes and random inputs.
        self.num_tokens = 10
        self.hidden_size = 32
        self.num_experts = 8
        self.top_k = 2

        self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
        self.router_logits = torch.randn(self.num_tokens, self.num_experts)

    def _fake_gating_result(self):
        """Build the 3-tuple used as the mocked gating kernel's return value:
        all-ones weights, all-zero int64 ids and an int32 index tensor."""
        shape = (self.num_tokens, self.top_k)
        fake_weights = torch.ones(*shape)
        fake_ids = torch.zeros(*shape, dtype=torch.long)
        flat_index = torch.arange(0,
                                  self.num_tokens * self.top_k,
                                  dtype=torch.int32)
        fake_rows = flat_index.view(self.top_k, -1).permute(1, 0).contiguous()
        return (fake_weights, fake_ids, fake_rows)

    def _assert_topk_shapes(self, weights, ids):
        """Both outputs must have shape (num_tokens, top_k)."""
        expected_shape = (self.num_tokens, self.top_k)
        self.assertEqual(weights.shape, expected_shape)
        self.assertEqual(ids.shape, expected_shape)

    @patch('torch_npu.npu_moe_gating_top_k_softmax')
    def test_softmax_scoring(self, mock_topk):
        """Softmax scoring yields correctly shaped weights and ids."""
        mock_topk.return_value = self._fake_gating_result()

        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
                                         router_logits=self.router_logits,
                                         top_k=self.top_k,
                                         use_grouped_topk=False,
                                         renormalize=False,
                                         scoring_func="softmax")

        self._assert_topk_shapes(weights, ids)

    def test_sigmoid_scoring(self):
        """Sigmoid scoring yields correctly shaped weights and ids."""
        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
                                         router_logits=self.router_logits,
                                         top_k=self.top_k,
                                         use_grouped_topk=False,
                                         renormalize=False,
                                         scoring_func="sigmoid")

        self._assert_topk_shapes(weights, ids)

    def test_invalid_scoring_func(self):
        """An unknown scoring function name raises ValueError."""
        with self.assertRaises(ValueError):
            select_experts(hidden_states=self.hidden_states,
                           router_logits=self.router_logits,
                           top_k=self.top_k,
                           use_grouped_topk=False,
                           renormalize=False,
                           scoring_func="invalid_func")

    @patch('torch.topk')
    def test_grouped_topk(self, mock_topk):
        """Grouped top-k goes through ``torch.topk`` and returns int32 ids."""
        shape = (self.num_tokens, self.top_k)
        mock_topk.return_value = (torch.ones(*shape),
                                  torch.zeros(*shape, dtype=torch.long))

        weights, ids, _ = select_experts(hidden_states=self.hidden_states,
                                         router_logits=self.router_logits,
                                         top_k=self.top_k,
                                         use_grouped_topk=True,
                                         renormalize=False,
                                         topk_group=4,
                                         num_expert_group=2)

        mock_topk.assert_called()
        self._assert_topk_shapes(weights, ids)
        self.assertEqual(ids.dtype, torch.int32)

    @patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
    def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
        """Grouped top-k with an expert score correction bias calls the
        native grouped top-k helper exactly once."""
        mock_grouped_topk.return_value = torch.ones(self.num_tokens,
                                                    self.num_experts)
        correction_bias = torch.randn(self.num_experts)

        weights, ids, _ = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
            use_grouped_topk=True,
            renormalize=False,
            topk_group=4,
            num_expert_group=2,
            e_score_correction_bias=correction_bias)

        mock_grouped_topk.assert_called_once()
        self._assert_topk_shapes(weights, ids)

    def test_custom_routing_function(self):
        """A user-supplied routing callable is invoked exactly once."""
        shape = (self.num_tokens, self.top_k)
        routing_stub = MagicMock()
        routing_stub.return_value = (torch.ones(*shape),
                                     torch.zeros(*shape, dtype=torch.int32))

        weights, ids, _ = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
            use_grouped_topk=False,
            renormalize=False,
            custom_routing_function=routing_stub)

        routing_stub.assert_called_once()
        self._assert_topk_shapes(weights, ids)
        self.assertEqual(ids.dtype, torch.int32)

    @patch('torch_npu.npu_moe_gating_top_k_softmax')
    def test_renormalize(self, mock_topk):
        """With ``renormalize=True`` each token's weights sum to one."""
        mock_topk.return_value = self._fake_gating_result()

        weights, ids, _ = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
            use_grouped_topk=False,
            renormalize=True,
        )

        per_token_total = weights.sum(dim=-1)
        self.assertTrue(
            torch.allclose(per_token_total,
                           torch.ones_like(per_token_total)))

    @patch('torch_npu.npu_moe_gating_top_k_softmax')
    def test_output_dtypes(self, mock_topk):
        """Weights keep the hidden-state dtype; ids come back as int32."""
        mock_topk.return_value = self._fake_gating_result()

        weights, ids, _ = select_experts(
            hidden_states=self.hidden_states,
            router_logits=self.router_logits,
            top_k=self.top_k,
            use_grouped_topk=False,
            renormalize=False,
        )

        self.assertEqual(weights.dtype, self.hidden_states.dtype)
        self.assertEqual(ids.dtype, torch.int32)
|
||||
|
||||
|
||||
class TestNativeGroupedTopkPartialMock(TestBase):
    """Tests for ``_native_grouped_topk`` with ``torch.topk`` patched so the
    selected group indices are fixed by the test."""

    def test_basic_group_selection(self):
        """Selecting both of two groups leaves the weights unchanged."""
        scores = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
                               [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1],
                               [0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
                               [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]],
                              dtype=torch.float32)
        chosen_groups = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]])

        with patch('torch.topk',
                   return_value=(None, chosen_groups)) as mock_topk:
            masked = _native_grouped_topk(topk_weights=scores,
                                          num_expert_group=2,
                                          topk_group=2)

        mock_topk.assert_called_once()
        self.assertTrue(torch.allclose(masked, scores))

    def test_partial_group_selection(self):
        """Selecting one of two groups zeroes the other group's weights."""
        scores = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
                               [0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]])
        chosen_groups = torch.tensor([[0], [1]])

        with patch('torch.topk', return_value=(None, chosen_groups)):
            masked = _native_grouped_topk(topk_weights=scores,
                                          num_expert_group=2,
                                          topk_group=1)

        wanted = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
                               [0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]])
        self.assertTrue(torch.allclose(masked, wanted))

    def test_single_group(self):
        """A single expert group still produces a non-empty result."""
        scores = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]])
        chosen_groups = torch.tensor([[0], [0]])

        with patch('torch.topk', return_value=(None, chosen_groups)):
            masked = _native_grouped_topk(topk_weights=scores,
                                          num_expert_group=1,
                                          topk_group=1)

        self.assertTrue(masked.numel() > 0)
|
||||
203
tests/ut/sample/test_rejection_sampler.py
Normal file
203
tests/ut/sample/test_rejection_sampler.py
Normal file
@@ -0,0 +1,203 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.sample.rejection_sampler import (
|
||||
expand_batch_to_tokens, expand_pytorch, rejection_greedy_sample_pytorch,
|
||||
rejection_random_sample_pytorch, sample_recovered_tokens_pytorch)
|
||||
|
||||
# Global constants
# Fill value the tests use to pre-populate output token buffers.
PLACEHOLDER_TOKEN_ID = -1
# NOTE(review): not referenced by any test in this file — confirm it is needed.
GREEDY_TEMPERATURE = 0.0
MAX_SPEC_LEN = 8  # Used as MAX_NUM_TOKENS in expand_batch_to_tokens
|
||||
|
||||
|
||||
class TestAscendRejectionSampler(TestBase):
    """Tests for the PyTorch rejection-sampling helpers."""

    def test_rejection_greedy_sample_pytorch(self):
        """Greedy rejection sampling: stop at the first draft/target
        mismatch, otherwise append the bonus token."""
        n_requests, spec_len = 2, 2
        out_tokens = torch.full((n_requests, spec_len + 1),
                                PLACEHOLDER_TOKEN_ID)

        cu_drafts = torch.tensor([2, 4])
        drafts_per_request = [2, 2]
        drafts = torch.tensor([10, 11, 20, 21])
        target_best = torch.tensor([10, 99, 20, 22])
        bonus = torch.tensor([[100], [200]])
        greedy_mask = torch.tensor([True, True])

        rejection_greedy_sample_pytorch(
            out_tokens,
            cu_drafts,
            drafts,
            target_best,
            bonus,
            drafts_per_request,
            spec_len,
            greedy_mask,
        )

        assert out_tokens[0, 0].item() == 10
        assert out_tokens[0, 1].item() == 99
        assert out_tokens[1, 0].item() == 20
        assert out_tokens[1, 2].item() == PLACEHOLDER_TOKEN_ID

    def test_rejection_random_sample_pytorch(self):
        """Random rejection sampling: acceptance is decided against the
        supplied uniform probabilities."""
        n_requests, spec_len, vocab = 2, 3, 4
        out_tokens = torch.full((n_requests, spec_len + 1),
                                PLACEHOLDER_TOKEN_ID)

        # NOTE(review): the second entry (1) is smaller than the first (2);
        # if these are cumulative counts, [2, 3] may have been intended —
        # confirm. Only request 0 is asserted below.
        cu_drafts = torch.tensor([2, 1])
        drafts = torch.tensor([1, 0, 2])
        draft_dist = torch.tensor([
            [0.0, 0.6, 0.0, 0.4],  # vocab_size=4
            [0.1, 0.2, 0.3, 0.4],
            [0.5, 0.5, 0.0, 0.0],
        ])
        target_dist = torch.tensor([
            [0.0, 0.8, 0.0, 0.2],
            [0.2, 0.1, 0.3, 0.4],
            [0.9, 0.1, 0.0, 0.0],
        ])
        bonus = torch.tensor([[100], [200]])
        recovered = torch.tensor([1, 2, 3])
        uniform = torch.tensor([0.7, 0.6, 0.5])
        greedy_mask = torch.tensor([False, False])

        rejection_random_sample_pytorch(
            out_tokens,
            cu_drafts,
            drafts,
            draft_dist,
            target_dist,
            bonus,
            recovered,
            uniform,
            greedy_mask,
            spec_len,
            vocab,
            IS_NGRAM=False,
        )

        assert out_tokens[0, 0].item() == 1
        assert out_tokens[0, 1].item() == 0
        assert out_tokens[0, 2].item() == 100

    def test_expand_pytorch(self):
        """``expand_pytorch`` repeats each input value across its token
        range given by the cumulative counts."""
        source = torch.tensor([10, 20, 30], dtype=torch.int32)
        cu_tokens = torch.tensor([2, 5, 7])
        destination = torch.empty(7, dtype=torch.int32)

        expand_pytorch(
            destination,
            source,
            cu_tokens,
            replace_from=0,
            replace_to=0,
            MAX_NUM_TOKENS=MAX_SPEC_LEN,
        )

        wanted = torch.tensor([10, 10, 20, 20, 20, 30, 30])
        assert torch.equal(destination, wanted)

    def test_expand_batch_to_tokens(self):
        """``expand_batch_to_tokens`` delegates to ``expand_pytorch`` and
        returns the expanded tensor."""
        x = torch.tensor([10, 20, 30])
        cu_num_tokens = torch.tensor([2, 5, 7])
        num_tokens = 7

        # First pass: kernel mocked, only the delegation is checked.
        with patch("vllm_ascend.sample.rejection_sampler.expand_pytorch"
                   ) as mock_kernel:
            expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
            mock_kernel.assert_called_once()
            positional = mock_kernel.call_args[0]
            assert (positional[1] == x).all()
            assert (positional[2] == cu_num_tokens).all()

        # Second pass: real kernel, check the actual values.
        expanded = expand_batch_to_tokens(x, cu_num_tokens, num_tokens)
        wanted = torch.tensor([10, 10, 20, 20, 20, 30, 30])
        assert torch.equal(expanded, wanted)

    def test_sample_recovered_tokens_pytorch_ngram(self):
        """Recovered-token sampling in n-gram mode (no draft probs)."""
        out_tokens = torch.empty(2, dtype=torch.int32)
        cu_drafts = torch.tensor([1, 2])
        drafts = torch.tensor([1, 2])
        target_dist = torch.tensor([
            [0.1, 0.2, 0.7],
            [0.3, 0.3, 0.4],
        ])
        q = torch.tensor([
            [0.1, 0.2, 0.7],
            [0.5, 0.4, 0.1],
        ])
        vocab = 3

        sample_recovered_tokens_pytorch(
            out_tokens,
            cu_drafts,
            drafts,
            None,
            target_dist,
            q,
            vocab,
            IS_NGRAM=True,
        )

        assert out_tokens[0].item() == 0
        assert out_tokens[1].item() == 1

    def test_sample_recovered_tokens_pytorch_autoregressive(self):
        """Recovered-token sampling when draft probabilities are given."""
        out_tokens = torch.empty(2, dtype=torch.int32)
        cu_drafts = torch.tensor([1, 1])
        drafts = torch.tensor([0, 1])
        draft_dist = torch.tensor([
            [0.6, 0.1, 0.3],
            [0.2, 0.7, 0.1],
        ])
        target_dist = torch.tensor([
            [0.8, 0.1, 0.1],
            [0.3, 0.6, 0.1],
        ])
        q = torch.tensor([
            [0.5, 0.3, 0.2],
            [0.1, 0.8, 0.1],
        ])
        vocab = 3

        sample_recovered_tokens_pytorch(
            out_tokens,
            cu_drafts,
            drafts,
            draft_dist,
            target_dist,
            q,
            vocab,
            IS_NGRAM=False,
        )

        assert out_tokens[0].item() == 0
|
||||
32
tests/ut/sample/test_sampler.py
Normal file
32
tests/ut/sample/test_sampler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.sample.sampler import AscendSampler, AscendTopKTopPSampler
|
||||
|
||||
|
||||
class TestAscendSampler(TestBase):
    """Construction test for ``AscendSampler``."""

    def test_init_with_raw_logprobs(self):
        """The sampler keeps the requested logprobs mode and owns an
        ``AscendTopKTopPSampler``."""
        mode = "raw_logprobs"
        sampler = AscendSampler(logprobs_mode=mode)

        self.assertEqual(sampler.logprobs_mode, mode)
        self.assertTrue(hasattr(sampler, 'topk_topp_sampler'))
        self.assertIsInstance(sampler.topk_topp_sampler,
                              AscendTopKTopPSampler)
|
||||
|
||||
|
||||
class TestAscendTopKTopPSampler(TestBase):
    """Tests for ``AscendTopKTopPSampler``."""

    @mock.patch("torch_npu.npu_top_k_top_p")
    def test_npu_topk_topp_called_when_optimized(self, mock_npu_op):
        """``forward_native`` calls the NPU top-k/top-p op once with
        ``(logits, p, k)``."""
        mock_npu_op.return_value = torch.randn(1, 3)
        sampler = AscendTopKTopPSampler()

        logits = torch.tensor([[1.0, 2.0, 3.0]])
        k = torch.tensor([2])
        p = torch.tensor([0.9])
        seeded = torch.Generator()
        seeded.manual_seed(42)
        generators = {0: seeded}

        sampler.forward_native(logits, generators, k, p)

        mock_npu_op.assert_called_once_with(logits, p, k)
|
||||
361
tests/ut/test_ascend_config.py
Normal file
361
tests/ut/test_ascend_config.py
Normal file
@@ -0,0 +1,361 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.ascend_config import (_check_torchair_supported,
|
||||
check_ascend_config,
|
||||
clear_ascend_config, get_ascend_config,
|
||||
init_ascend_config)
|
||||
|
||||
|
||||
class TestAscendConfig(TestBase):
    """Tests for initialising, reading, clearing and validating the Ascend
    config derived from ``VllmConfig.additional_config``."""

    @staticmethod
    def _clean_up_ascend_config(func):
        """Decorator that clears the global Ascend config before and after
        the wrapped test.

        The trailing clear runs in a ``finally`` so a failing test cannot
        leak its config into the tests that follow.
        """
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            clear_ascend_config()
            try:
                func(*args, **kwargs)
            finally:
                clear_ascend_config()

        return wrapper

    @staticmethod
    def _fake_model_config(model_type):
        """Build a ModelConfig over the local ``fake_weight`` directory whose
        HF config reports the given ``model_type``."""
        model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
        fake_model_config = ModelConfig(model=model_path)
        fake_model_config.hf_config = PretrainedConfig()
        fake_model_config.hf_config.model_type = model_type
        return fake_model_config

    @_clean_up_ascend_config
    def test_init_ascend_config_without_additional_config(self):
        """Without additional config every option takes its default."""
        test_vllm_config = VllmConfig()
        # No additional config given, check the default value here.
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertIsNone(ascend_config.expert_map_path)

        torchair_graph_config = ascend_config.torchair_graph_config
        self.assertFalse(torchair_graph_config.enabled)
        self.assertEqual(torchair_graph_config.mode, '')
        self.assertFalse(torchair_graph_config.use_cached_graph)
        self.assertEqual(torchair_graph_config.graph_batch_sizes, [])
        self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
        self.assertFalse(torchair_graph_config.enable_multistream_mla)
        self.assertFalse(torchair_graph_config.enable_multistream_moe)
        self.assertTrue(torchair_graph_config.enable_view_optimize)
        self.assertFalse(torchair_graph_config.enable_kv_nz)

        ascend_scheduler_config = ascend_config.ascend_scheduler_config
        self.assertFalse(ascend_scheduler_config.enabled)

    @_clean_up_ascend_config
    def test_init_ascend_config_with_additional_config(self):
        """Values supplied via additional config are reflected verbatim."""
        test_vllm_config = VllmConfig()
        test_vllm_config.additional_config = {
            "torchair_graph_config": {
                "enabled": True,
                "use_cached_graph": True,
                "graph_batch_sizes": [1, 2, 4],
                "graph_batch_sizes_init": False,
                "enable_multistream_mla": True,
                "enable_multistream_moe": True,
                "enable_view_optimize": True,
                "enable_kv_nz": True
            },
            "ascend_scheduler_config": {
                "enabled": True
            },
            "expert_map_path": "test_expert_map_path",
            "refresh": True,
        }
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertEqual(ascend_config.expert_map_path,
                         "test_expert_map_path")

        torchair_graph_config = ascend_config.torchair_graph_config
        self.assertTrue(torchair_graph_config.enabled)
        self.assertTrue(torchair_graph_config.use_cached_graph)
        self.assertEqual(torchair_graph_config.graph_batch_sizes, [1, 2, 4])
        self.assertFalse(torchair_graph_config.graph_batch_sizes_init)
        self.assertTrue(torchair_graph_config.enable_multistream_mla)
        self.assertTrue(torchair_graph_config.enable_multistream_moe)
        self.assertTrue(torchair_graph_config.enable_view_optimize)
        self.assertTrue(torchair_graph_config.enable_kv_nz)

        ascend_scheduler_config = ascend_config.ascend_scheduler_config
        self.assertTrue(ascend_scheduler_config.enabled)

    @_clean_up_ascend_config
    def test_init_ascend_config_with_refresh(self):
        """A second init only takes effect when ``refresh`` is set."""
        test_vllm_config = VllmConfig()
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertFalse(ascend_config.torchair_graph_config.enabled)

        # Without "refresh" the previously initialised config is kept.
        test_vllm_config.additional_config = {
            "torchair_graph_config": {
                "enabled": True,
            },
        }
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertFalse(ascend_config.torchair_graph_config.enabled)

        # With "refresh" the new values are picked up.
        test_vllm_config.additional_config = {
            "torchair_graph_config": {
                "enabled": True,
            },
            "refresh": True,
        }
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertTrue(ascend_config.torchair_graph_config.enabled)

    @_clean_up_ascend_config
    def test_init_ascend_config_with_wrong_input(self):
        """Malformed ``graph_batch_sizes`` settings are rejected."""
        test_vllm_config = VllmConfig()
        test_vllm_config.additional_config = {
            "torchair_graph_config": {
                "enabled": True,
                "graph_batch_sizes": "fake_size",
            },
            "refresh": True,
        }
        with self.assertRaises(TypeError):
            init_ascend_config(test_vllm_config)

        test_vllm_config.additional_config = {
            "torchair_graph_config": {
                "enabled": False,
                "graph_batch_sizes": [1, 2, 4, 8],
                "graph_batch_sizes_init": True,
            },
            "refresh": True,
        }
        with self.assertRaises(ValueError):
            init_ascend_config(test_vllm_config)

    @_clean_up_ascend_config
    def test_get_ascend_config(self):
        """``get_ascend_config`` returns the config created by init."""
        test_vllm_config = VllmConfig()
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertEqual(get_ascend_config(), ascend_config)

    @_clean_up_ascend_config
    def test_get_ascend_config_without_init(self):
        """Reading the config before init raises RuntimeError."""
        with self.assertRaises(RuntimeError):
            get_ascend_config()

    @_clean_up_ascend_config
    def test_clear_ascend_config(self):
        """After clearing, the config can no longer be read."""
        test_vllm_config = VllmConfig()
        ascend_config = init_ascend_config(test_vllm_config)
        self.assertEqual(get_ascend_config(), ascend_config)
        clear_ascend_config()
        with self.assertRaises(RuntimeError):
            get_ascend_config()

    @_clean_up_ascend_config
    def test_check_ascend_config_pass(self):
        """Valid configurations pass ``check_ascend_config`` silently."""
        test_vllm_config = VllmConfig()
        init_ascend_config(test_vllm_config)
        check_ascend_config(test_vllm_config, False)

        for torchair_enabled in (True, False):
            test_vllm_config.additional_config = {
                "torchair_graph_config": {
                    "enabled": torchair_enabled,
                },
                "refresh": True
            }
            init_ascend_config(test_vllm_config)
            check_ascend_config(test_vllm_config, False)

    @_clean_up_ascend_config
    def test_check_ascend_config_wrong_case(self):
        """Unsupported mode/model combinations are rejected."""
        test_vllm_config = VllmConfig()

        # torchair + eager mode
        with self.assertRaises(RuntimeError):
            test_vllm_config.additional_config = {
                "torchair_graph_config": {
                    "enabled": True,
                },
                "refresh": True
            }
            init_ascend_config(test_vllm_config)
            enforce_eager = True
            check_ascend_config(test_vllm_config, enforce_eager)
        # torchair + non deepseek model
        with self.assertRaises(NotImplementedError):
            test_vllm_config.additional_config = {
                "torchair_graph_config": {
                    "enabled": True,
                },
                "refresh": True
            }
            test_vllm_config.model_config = self._fake_model_config("llama")
            init_ascend_config(test_vllm_config)
            check_ascend_config(test_vllm_config, False)
        # aclgraph + deepseek model
        with self.assertRaises(NotImplementedError):
            test_vllm_config.additional_config = {
                "torchair_graph_config": {
                    "enabled": False,
                },
                "refresh": True
            }
            test_vllm_config.model_config = self._fake_model_config(
                "deepseek")
            init_ascend_config(test_vllm_config)
            check_ascend_config(test_vllm_config, False)

    def test_check_torchair_supported(self):
        """``_check_torchair_supported`` answers per model type."""
        test_cases = [('deepseek_v3', True), ('PanguProMoE', True),
                      ('qwen', True), ('llama', False)]
        for model_type, expected_output in test_cases:
            self.assertEqual(_check_torchair_supported(model_type),
                             expected_output)

    @_clean_up_ascend_config
    def test_ascend_config_load_error(self):
        """Invalid option combinations make ``init_ascend_config`` raise."""
        test_vllm_config = VllmConfig()

        # (reason, torchair_graph_config options, expected exception)
        invalid_torchair_cases = [
            ("graph_batch_sizes should be list", {
                "graph_batch_sizes": "fake_size",
            }, TypeError),
            ("use_cached_graph should not be enabled without torchair "
             "graph mode", {
                 "enabled": False,
                 "use_cached_graph": True,
             }, RuntimeError),
            ("use_cached_kv_cache_bytes should not be enabled without "
             "torchair graph mode", {
                 "enabled": False,
                 "use_cached_kv_cache_bytes": True,
             }, RuntimeError),
            ("graph_batch_sizes should not be set without torchair graph "
             "mode", {
                 "enabled": False,
                 "graph_batch_sizes": [1, 2, 4],
             }, RuntimeError),
            ("use_cached_kv_cache_bytes is valid only when torchair graph "
             "mode and use_cached_graph are enabled", {
                 "enabled": True,
                 "use_cached_graph": False,
                 "use_cached_kv_cache_bytes": True,
             }, RuntimeError),
            ("graph_batch_sizes_init should not be enabled without "
             "torchair graph mode", {
                 "enabled": False,
                 "graph_batch_sizes_init": True,
             }, RuntimeError),
            ("enable_multistream_mla should not be enabled without "
             "torchair graph mode", {
                 "enabled": False,
                 "enable_multistream_mla": True,
             }, RuntimeError),
            ("enable_multistream_moe should not be enabled without "
             "torchair graph mode", {
                 "enabled": False,
                 "enable_multistream_moe": True,
             }, RuntimeError),
            ("mode should not be configured without torchair graph mode", {
                "enabled": False,
                "mode": 'max-autotune',
            }, RuntimeError),
            ("enable_kv_nz should not be enabled without torchair graph "
             "mode", {
                 "enabled": False,
                 "enable_kv_nz": True,
             }, RuntimeError),
        ]
        for reason, torchair_options, expected_error in (
                invalid_torchair_cases):
            # subTest keeps the remaining cases running after a failure and
            # names the offending case in the report.
            with self.subTest(reason=reason):
                test_vllm_config.additional_config = {
                    "torchair_graph_config": torchair_options,
                    "refresh": True
                }
                with self.assertRaises(expected_error):
                    init_ascend_config(test_vllm_config)

        with self.assertRaises(AssertionError):
            test_vllm_config.additional_config = {
                "lmhead_tensor_parallel_size": 2,
                "refresh": True
            }
            test_vllm_config.parallel_config = ParallelConfig(
                data_parallel_size=4, tensor_parallel_size=2)
            init_ascend_config(test_vllm_config)
|
||||
62
tests/ut/test_envs.py
Normal file
62
tests/ut/test_envs.py
Normal file
@@ -0,0 +1,62 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from tests.ut.base import TestBase
|
||||
|
||||
|
||||
class TestEnvVariables(TestBase):
    """Consistency checks between ``vllm_ascend.envs`` attributes and the
    handlers registered in ``envs_ascend.env_variables``."""

    def setUp(self):
        # Names of every registered environment variable.
        self.env_vars = list(envs_ascend.env_variables)

    def test_env_vars_behavior(self):
        """Attribute access on the module must equal the handler's result.

        Each variable is checked once while unset and then with a couple of
        sample values. The sample values are picked by inspecting the
        handler's source text for an ``int(`` / ``bool(int(`` conversion.
        """
        for var_name in self.env_vars:
            with self.subTest(var=var_name):
                original_val = os.environ.get(var_name)
                var_handler = envs_ascend.env_variables[var_name]

                try:
                    # Unset case: both sides should fall back the same way.
                    os.environ.pop(var_name, None)
                    self.assertEqual(getattr(envs_ascend, var_name),
                                     var_handler())

                    handler_source = inspect.getsource(var_handler)
                    # 'bool(int(' contains 'int(', so the more specific
                    # pattern has to be tested first; otherwise the boolean
                    # branch can never be taken.
                    if 'bool(int(' in handler_source:
                        test_vals = ["0", "1"]
                    elif 'int(' in handler_source:
                        test_vals = ["123", "456"]
                    else:
                        test_vals = [f"test_{var_name}", f"custom_{var_name}"]

                    for test_val in test_vals:
                        os.environ[var_name] = test_val
                        self.assertEqual(getattr(envs_ascend, var_name),
                                         var_handler())

                finally:
                    # Always put the process environment back as it was.
                    if original_val is None:
                        os.environ.pop(var_name, None)
                    else:
                        os.environ[var_name] = original_val

    def test_dir_and_getattr(self):
        """``__dir__`` lists exactly the registered names and each resolves."""
        self.assertEqual(sorted(envs_ascend.__dir__()), sorted(self.env_vars))
        for var_name in self.env_vars:
            with self.subTest(var=var_name):
                getattr(envs_ascend, var_name)
|
||||
714
tests/ut/test_platform.py
Normal file
714
tests/ut/test_platform.py
Normal file
@@ -0,0 +1,714 @@
|
||||
import importlib
|
||||
import unittest
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.distributed.distributed_c10d import PrefixStore
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.config.compilation import CUDAGraphMode
|
||||
from vllm.platforms import PlatformEnum
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
|
||||
|
||||
|
||||
class TestNPUPlatform(TestBase):
    """Unit tests for ``vllm_ascend.platform.NPUPlatform``.

    All NPU / torch_npu entry points are replaced with mocks, so the tests
    only verify how the platform object delegates and which values it
    returns or writes back into the (mocked) vLLM config.
    """

    @staticmethod
    def mock_vllm_config():
        """Return a ``MagicMock`` shaped like a vLLM config object."""
        mock_vllm_config = MagicMock()
        mock_vllm_config.compilation_config = MagicMock()
        mock_vllm_config.model_config = MagicMock()
        mock_vllm_config.parallel_config = MagicMock()
        mock_vllm_config.cache_config = MagicMock()
        mock_vllm_config.scheduler_config = MagicMock()
        mock_vllm_config.speculative_config = None
        mock_vllm_config.compilation_config.pass_config.enable_sequence_parallelism = False
        mock_vllm_config.compilation_config.cudagraph_mode = None
        return mock_vllm_config

    @staticmethod
    def mock_vllm_ascend_config():
        """Return a mocked ascend config with torchair and the ascend
        scheduler both disabled."""
        mock_ascend_config = MagicMock()
        mock_ascend_config.torchair_graph_config.enabled = False
        mock_ascend_config.ascend_scheduler_config.enabled = False
        return mock_ascend_config

    @staticmethod
    def _reload_platform_module():
        """Re-import ``vllm_ascend.platform`` while patches are active.

        Per the original author's note, without this reload the platform
        code would run the unmocked implementations (e.g.
        ``init_ascend_config`` or ``ProcessGroup``) and the test would fail.
        """
        from vllm_ascend import platform

        importlib.reload(platform)

    def _get_attn_backend_cls(self, use_mla):
        """Call ``get_attn_backend_cls`` with the fixed arguments shared by
        the backend-selection tests; only ``use_mla`` varies."""
        return self.platform.get_attn_backend_cls(
            selected_backend="ascend",
            head_size=64,
            dtype="float16",
            kv_cache_dtype="float16",
            block_size=64,
            use_v1=True,
            use_mla=use_mla,
        )

    def setUp(self):
        self.platform = NPUPlatform()

    def test_class_variables(self):
        self.assertEqual(NPUPlatform._enum, PlatformEnum.OOT)
        self.assertEqual(NPUPlatform.device_name, "npu")
        self.assertEqual(NPUPlatform.device_type, "npu")
        self.assertEqual(NPUPlatform.simple_compile_backend, "eager")
        self.assertEqual(NPUPlatform.ray_device_key, "NPU")
        self.assertEqual(NPUPlatform.device_control_env_var,
                         "ASCEND_RT_VISIBLE_DEVICES")
        self.assertEqual(NPUPlatform.dispatch_key, "PrivateUse1")
        self.assertEqual(NPUPlatform.supported_quantization,
                         [ASCEND_QUANTIZATION_METHOD])

    def test_is_sleep_mode_available(self):
        self.assertTrue(self.platform.is_sleep_mode_available())

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_parser(self, mock_quant_config,
                                                 mock_adapt_patch):
        """The ascend quantization method is appended to the parser's
        ``--quantization`` choices."""
        mock_parser = MagicMock()
        mock_action = MagicMock()
        mock_action.choices = ["awq", "gptq"]
        mock_parser._option_string_actions = {"--quantization": mock_action}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

        self.assertIn(ASCEND_QUANTIZATION_METHOD, mock_action.choices)
        self.assertEqual(len(mock_action.choices), 3)  # original 2 + ascend

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_without_parser(self, mock_quant_config,
                                                    mock_adapt_patch):
        self.platform.pre_register_and_update(None)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_parser_no_quant_action(
            self, mock_quant_config, mock_adapt_patch):
        mock_parser = MagicMock()
        mock_parser._option_string_actions = {}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)

    @patch("vllm_ascend.utils.adapt_patch")
    @patch("vllm_ascend.quantization.quant_config.AscendQuantConfig")
    def test_pre_register_and_update_with_existing_ascend_quant(
            self, mock_quant_config, mock_adapt_patch):
        """An already-present ascend choice must not be duplicated."""
        mock_parser = MagicMock()
        mock_action = MagicMock()
        mock_action.choices = ["awq", ASCEND_QUANTIZATION_METHOD]
        mock_parser._option_string_actions = {"--quantization": mock_action}

        self.platform.pre_register_and_update(mock_parser)

        mock_adapt_patch.assert_called_once_with(is_global_patch=True)
        self.assertEqual(len(mock_action.choices), 2)

    def test_get_device_capability(self):
        self.assertIsNone(self.platform.get_device_capability(device_id=0))

    @patch("torch.npu.get_device_name")
    def test_get_device_name(self, mock_get_device_name):
        device_id = 0
        device_name = "Ascend910B2"
        mock_get_device_name.return_value = device_name
        self.assertEqual(self.platform.get_device_name(device_id), device_name)
        mock_get_device_name.assert_called_once_with(0)

    def test_is_async_output_supported(self):
        self.assertTrue(
            self.platform.is_async_output_supported(enforce_eager=None))
        self.assertTrue(
            self.platform.is_async_output_supported(enforce_eager=True))
        self.assertTrue(
            self.platform.is_async_output_supported(enforce_eager=False))

    @patch("torch.inference_mode")
    def test_inference_mode(self, mock_inference_mode):
        mock_inference_mode.return_value = None
        self.assertIsNone(self.platform.inference_mode())
        mock_inference_mode.assert_called_once()

    @patch("torch.npu.set_device")
    def test_set_device_normal(self, mock_set_device):
        device = torch.device("npu:0")
        self.platform.set_device(device)
        mock_set_device.assert_called_once_with(device)

    @patch("torch.npu.set_device",
           side_effect=RuntimeError("Device not available"))
    def test_set_device_failure(self, mock_set_device):
        device = torch.device("npu:0")
        with self.assertRaises(RuntimeError):
            self.platform.set_device(device)
        # Checked outside the context manager so it actually executes.
        mock_set_device.assert_called_once_with(device)

    @patch("torch.npu.empty_cache")
    def test_empty_cache_normal(self, mock_empty_cache):
        self.platform.empty_cache()
        mock_empty_cache.assert_called_once()

    @patch("torch.npu.empty_cache",
           side_effect=RuntimeError("Cache clearing failed"))
    def test_empty_cache_failure(self, mock_empty_cache):
        with self.assertRaises(RuntimeError):
            self.platform.empty_cache()
        mock_empty_cache.assert_called_once()

    @patch("torch.npu.synchronize")
    def test_synchronize_normal(self, mock_synchronize):
        self.platform.synchronize()
        mock_synchronize.assert_called_once()

    @patch("torch.npu.synchronize",
           side_effect=RuntimeError("Synchronization failed"))
    def test_synchronize_failure(self, mock_synchronize):
        with self.assertRaises(RuntimeError):
            self.platform.synchronize()
        mock_synchronize.assert_called_once()

    @patch("torch.npu.mem_get_info")
    def test_mem_get_info_normal(self, mock_mem_get_info):
        free_memory_size = 1024
        total_memory_size = 2048
        memory_info = (free_memory_size, total_memory_size)
        mock_mem_get_info.return_value = memory_info
        result = self.platform.mem_get_info()
        self.assertIsInstance(result, tuple)
        self.assertEqual(len(result), 2)
        self.assertEqual(result, memory_info)
        mock_mem_get_info.assert_called_once()

    @patch("torch.npu.mem_get_info",
           side_effect=RuntimeError("NPU not available"))
    def test_mem_get_info_failure(self, mock_mem_get_info):
        with self.assertRaises(RuntimeError):
            self.platform.mem_get_info()
        mock_mem_get_info.assert_called_once()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_normal(self, mock_reset_stats, mock_empty_cache,
                                     mock_gc_collect):
        self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_called_once()

    @patch("gc.collect", side_effect=Exception("GC failed"))
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_gc_collect_failure(self, mock_reset_stats,
                                                 mock_empty_cache,
                                                 mock_gc_collect):
        """A failing ``gc.collect`` aborts before any NPU call is made."""
        with self.assertRaises(Exception):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_not_called()
        mock_reset_stats.assert_not_called()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache",
           side_effect=RuntimeError("Cache clear failed"))
    @patch("torch.npu.reset_peak_memory_stats")
    def test_clear_npu_memory_empty_cache_failure(self, mock_reset_stats,
                                                  mock_empty_cache,
                                                  mock_gc_collect):
        """A failing ``empty_cache`` aborts before the stats reset."""
        with self.assertRaises(RuntimeError):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_not_called()

    @patch("gc.collect")
    @patch("torch.npu.empty_cache")
    @patch("torch.npu.reset_peak_memory_stats",
           side_effect=RuntimeError("Reset failed"))
    def test_clear_npu_memory_reset_stats_failure(self, mock_reset_stats,
                                                  mock_empty_cache,
                                                  mock_gc_collect):
        with self.assertRaises(RuntimeError):
            self.platform.clear_npu_memory()

        mock_gc_collect.assert_called_once()
        mock_empty_cache.assert_called_once()
        mock_reset_stats.assert_called_once()

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.update_aclgraph_sizes")
    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("os.environ", {})
    def test_check_and_update_config_basic_config_update(
            self, mock_is_310p, mock_update_acl, mock_init_ascend,
            mock_check_ascend):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.parallel_config.enable_expert_parallel = False

        # Reload so that check_and_update_config picks up the mocked
        # init_ascend_config instead of the real one.
        self._reload_platform_module()

        self.platform.check_and_update_config(vllm_config)

        mock_init_ascend.assert_called_once_with(vllm_config)
        mock_check_ascend.assert_called_once()

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_no_model_config_warning(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config = None

        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
        self.assertIn("Model config is missing", cm.output[0])

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_enforce_eager_mode(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = True

        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
        self.assertIn("Compilation disabled, using eager mode by default",
                      cm.output[0])
        self.assertEqual(
            vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )
        self.assertEqual(
            vllm_config.compilation_config.cudagraph_mode,
            CUDAGraphMode.NONE,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_unsupported_compilation_level(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = False
        vllm_config.compilation_config.level = CompilationLevel.DYNAMO_ONCE

        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
        self.assertIn("NPU does not support", cm.output[0])
        self.assertEqual(
            vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )
        self.assertEqual(
            vllm_config.compilation_config.cudagraph_mode,
            CUDAGraphMode.NONE,
        )

    @pytest.mark.skip(
        "Revert me when vllm support setting cudagraph_mode on oot platform")
    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_unsupported_cudagraph_mode(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = False
        vllm_config.compilation_config.cudagraph_mode = CUDAGraphMode.FULL

        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
        self.assertIn(
            "cudagraph_mode is not support on NPU. falling back to NONE",
            cm.output[0])
        self.assertEqual(
            vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )
        self.assertEqual(
            vllm_config.compilation_config.cudagraph_mode,
            CUDAGraphMode.NONE,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_disable_aclgraph_when_ray_enabled(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = False
        vllm_config.compilation_config.level = CompilationLevel.PIECEWISE
        vllm_config.parallel_config.distributed_executor_backend = "ray"

        with self.assertLogs(logger="vllm", level="WARNING") as cm:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
        self.assertIn(
            "Ray distributed executor backend is not compatible with ACL Graph mode",
            cm.output[0])
        self.assertEqual(
            vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )
        self.assertEqual(
            vllm_config.compilation_config.cudagraph_mode,
            CUDAGraphMode.NONE,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_torchair_enabled_compilation(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
        mock_ascend_config.torchair_graph_config.enabled = True
        mock_init_ascend.return_value = mock_ascend_config
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.model_config.enforce_eager = False
        vllm_config.compilation_config.level = CompilationLevel.PIECEWISE

        with self.assertLogs(logger="vllm", level="INFO") as cm:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
        self.assertIn("Torchair compilation enabled", cm.output[0])
        self.assertEqual(
            vllm_config.compilation_config.level,
            CompilationLevel.NO_COMPILATION,
        )
        self.assertEqual(
            vllm_config.compilation_config.cudagraph_mode,
            CUDAGraphMode.NONE,
        )

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_cache_config_block_size(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """An unset block size is replaced by 128."""
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.cache_config.block_size = None
        vllm_config.cache_config.enable_prefix_caching = True

        self._reload_platform_module()

        self.platform.check_and_update_config(vllm_config)

        self.assertEqual(vllm_config.cache_config.block_size, 128)

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_v1_worker_class_selection(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        """``worker_cls="auto"`` resolves to the plain or torchair worker."""
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.parallel_config.worker_cls = "auto"

        self._reload_platform_module()
        self.platform.check_and_update_config(vllm_config)

        self.assertEqual(
            vllm_config.parallel_config.worker_cls,
            "vllm_ascend.worker.worker_v1.NPUWorker",
        )

        # Second pass: same config, but with torchair graph mode enabled.
        test_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
        test_ascend_config.torchair_graph_config.enabled = True
        mock_init_ascend.return_value = test_ascend_config
        vllm_config.parallel_config.worker_cls = "auto"
        self.platform.check_and_update_config(vllm_config)
        self.assertEqual(
            vllm_config.parallel_config.worker_cls,
            "vllm_ascend.torchair.torchair_worker.NPUTorchairWorker",
        )

    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    @patch("vllm_ascend.utils.is_310p", return_value=True)
    def test_check_and_update_config_310p_no_custom_ops(
            self, mock_is_310p, mock_init_ascend, mock_check_ascend):
        mock_init_ascend.return_value = TestNPUPlatform.mock_vllm_ascend_config(
        )
        vllm_config = TestNPUPlatform.mock_vllm_config()
        vllm_config.compilation_config.custom_ops = []

        self._reload_platform_module()

        self.platform.check_and_update_config(vllm_config)
        self.assertEqual(vllm_config.compilation_config.custom_ops, [])

    @patch("vllm_ascend.utils.is_310p", return_value=False)
    @patch("vllm_ascend.ascend_config.check_ascend_config")
    @patch("vllm_ascend.ascend_config.init_ascend_config")
    def test_check_and_update_config_ascend_scheduler_config(
            self, mock_init_ascend, mock_check_ascend, mock_is_310p):
        mock_ascend_config = TestNPUPlatform.mock_vllm_ascend_config()
        mock_ascend_config.ascend_scheduler_config.enabled = True
        mock_init_ascend.return_value = mock_ascend_config

        vllm_config = TestNPUPlatform.mock_vllm_config()

        with patch("vllm_ascend.core.schedule_config.AscendSchedulerConfig"
                   ) as mock_scheduler:
            self._reload_platform_module()
            self.platform.check_and_update_config(vllm_config)
            mock_scheduler.initialize_from_config.assert_called_once()

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_mla(self, mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False

        mock_get_ascend_config.return_value = mock_config

        result = self._get_attn_backend_cls(use_mla=True)
        self.assertEqual(result,
                         "vllm_ascend.attention.mla_v1.AscendMLABackend")

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_mla_and_torchair(
            self, mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = True

        mock_get_ascend_config.return_value = mock_config

        result = self._get_attn_backend_cls(use_mla=True)
        self.assertEqual(
            result,
            "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend")

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_and_torchair(self,
                                                      mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = True

        mock_get_ascend_config.return_value = mock_config

        result = self._get_attn_backend_cls(use_mla=False)
        self.assertEqual(
            result,
            "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend"
        )

    @patch('vllm_ascend.platform.get_ascend_config')
    def test_get_attn_backend_cls_use_v1_only(self, mock_get_ascend_config):
        mock_config = MagicMock()
        mock_config.torchair_graph_config.enabled = False

        mock_get_ascend_config.return_value = mock_config

        result = self._get_attn_backend_cls(use_mla=False)
        self.assertEqual(
            result,
            "vllm_ascend.attention.attention_v1.AscendAttentionBackend")

    def test_get_punica_wrapper(self):
        result = self.platform.get_punica_wrapper()
        self.assertEqual(
            result,
            "vllm_ascend.lora.punica_wrapper.punica_npu.PunicaWrapperNPU")

    @patch("torch.npu.reset_peak_memory_stats")
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_with_specific_device(
            self, mock_max_memory, mock_reset_stats):
        max_memory_allocated_result = 1024.0
        mock_max_memory.return_value = max_memory_allocated_result
        test_device = torch.device("npu:0")
        result = self.platform.get_current_memory_usage(device=test_device)

        mock_reset_stats.assert_called_once_with(test_device)
        mock_max_memory.assert_called_once_with(test_device)
        self.assertEqual(result, max_memory_allocated_result)

    @patch("torch.npu.reset_peak_memory_stats")
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_with_default_device(
            self, mock_max_memory, mock_reset_stats):
        max_memory_allocated_result = 1024.0
        mock_max_memory.return_value = max_memory_allocated_result

        result = self.platform.get_current_memory_usage()

        mock_reset_stats.assert_called_once_with(None)
        mock_max_memory.assert_called_once_with(None)
        self.assertEqual(result, max_memory_allocated_result)

    @patch("torch.npu.reset_peak_memory_stats",
           side_effect=RuntimeError("Device error"))
    @patch("torch.npu.max_memory_allocated")
    def test_get_current_memory_usage_when_reset_stats_fails(
            self, mock_max_memory, mock_reset_stats):
        with self.assertRaises(RuntimeError):
            self.platform.get_current_memory_usage()
        mock_reset_stats.assert_called_once()
        mock_max_memory.assert_not_called()

    @patch("torch.npu.reset_peak_memory_stats")
    @patch(
        "torch.npu.max_memory_allocated",
        side_effect=RuntimeError("Memory query failed"),
    )
    def test_get_current_memory_usage_when_query_fails(self, mock_max_memory,
                                                       mock_reset_stats):
        with self.assertRaises(RuntimeError):
            self.platform.get_current_memory_usage()
        mock_reset_stats.assert_called_once()
        mock_max_memory.assert_called_once()

    def test_get_device_communicator_cls_returns_correct_value(self):
        self.assertEqual(
            self.platform.get_device_communicator_cls(),
            "vllm_ascend.distributed.communicator.NPUCommunicator",
        )

    def test_is_pin_memory_available_returns_true(self):
        self.assertTrue(self.platform.is_pin_memory_available())

    def test_supports_v1(self):
        from vllm.config import ModelConfig

        mock_config = MagicMock(spec=ModelConfig)
        self.assertTrue(self.platform.supports_v1(mock_config))

    def test_get_static_graph_wrapper_cls_returns_correct_value(self):
        self.assertEqual(
            self.platform.get_static_graph_wrapper_cls(),
            "vllm_ascend.compilation.acl_graph.ACLGraphWrapper",
        )

    @patch("torch.distributed.is_hccl_available", return_value=True)
    @patch("torch_npu._C._distributed_c10d.ProcessGroupHCCL")
    @patch("torch.distributed.ProcessGroup")
    def test_successful_initialization(self, mock_pg, mock_pg_hccl, _):
        mock_prefix = MagicMock(spec=PrefixStore)
        mock_backend = MagicMock()
        mock_pg_hccl.return_value = mock_backend
        group_rank = 0
        group_size = 4

        mock_pg_instance = MagicMock(spec=ProcessGroup)
        mock_pg.return_value = mock_pg_instance

        # Reload so stateless_init_device_torch_dist_pg sees the mocked
        # ProcessGroup rather than the real implementation.
        self._reload_platform_module()

        result = self.platform.stateless_init_device_torch_dist_pg(
            backend="hccl",
            prefix_store=mock_prefix,
            group_rank=group_rank,
            group_size=group_size,
            timeout=timedelta(seconds=30),
        )

        mock_pg.assert_called_once_with(mock_prefix, group_rank, group_size)
        mock_pg_hccl.assert_called_once_with(mock_prefix, group_rank,
                                             group_size, unittest.mock.ANY)
        mock_backend._set_sequence_number_for_group.assert_called_once()
        mock_pg_instance._register_backend.assert_called_once_with(
            torch.device("npu"), unittest.mock.ANY, mock_backend)
        self.assertEqual(result, mock_pg_instance)

    @patch("torch.distributed.is_hccl_available", return_value=False)
    def test_hccl_unavailable(self, _):
        with self.assertRaises(AssertionError):
            self._reload_platform_module()
            self.platform.stateless_init_device_torch_dist_pg(
                backend="hccl",
                prefix_store=MagicMock(),
                group_rank=0,
                group_size=4,
                timeout=timedelta(seconds=30),
            )
|
||||
351
tests/ut/test_utils.py
Normal file
351
tests/ut/test_utils.py
Normal file
@@ -0,0 +1,351 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import math
|
||||
import os
|
||||
from threading import Lock
|
||||
from unittest import mock
|
||||
|
||||
import torch
|
||||
from vllm.config import (CompilationConfig, ModelConfig, ParallelConfig,
|
||||
VllmConfig)
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend import utils
|
||||
|
||||
|
||||
class TestUtils(TestBase):
    """Unit tests for the helper functions in ``vllm_ascend.utils``."""

    def test_is_310p(self):
        """``is_310p`` follows the SoC version recorded in the build info."""
        # ``patch.object`` resets the module-level ``_IS_310P`` value for each
        # case and restores the previous value afterwards, so a result
        # computed from a mocked SoC version does not leak into other tests.
        with mock.patch.object(utils, "_IS_310P", None), \
                mock.patch("vllm_ascend._build_info.__soc_version__",
                           "Ascend310P3"):
            self.assertTrue(utils.is_310p())
        with mock.patch.object(utils, "_IS_310P", None), \
                mock.patch("vllm_ascend._build_info.__soc_version__",
                           "Ascend910P1"):
            self.assertFalse(utils.is_310p())

    def test_sleep_mode_enabled(self):
        """``sleep_mode_enabled`` follows the flag stored in the build info."""
        # Same isolation as in test_is_310p: reset the module-level value for
        # each case and restore it on exit.
        with mock.patch.object(utils, "_SLEEP_MODE_ENABLED", None), \
                mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
                           True):
            self.assertTrue(utils.sleep_mode_enabled())
        with mock.patch.object(utils, "_SLEEP_MODE_ENABLED", None), \
                mock.patch("vllm_ascend._build_info.__sleep_mode_enabled__",
                           False):
            self.assertFalse(utils.sleep_mode_enabled())

    def test_nd_to_nz_2d(self):
        """Check output shape, contiguity and values of ``nd_to_nz_2d``."""
        shape_cases = (
            # (input shape, expected size of every output dimension)
            ((32, 64), (1, 64 // 16, 32, 16)),  # can be divided by 16
            ((30, 62), (1, math.ceil(62 / 16), 32, 16)),  # cannot be divided
            ((8, 12), (1, 1, 16, 16)),  # pad to 16: 12->16, 8->16
        )
        for input_shape, expected_dims in shape_cases:
            converted = utils.nd_to_nz_2d(torch.randn(*input_shape))
            for axis, size in enumerate(expected_dims):
                self.assertEqual(converted.shape[axis], size)

        # The output must be contiguous.
        self.assertTrue(
            utils.nd_to_nz_2d(torch.randn(32, 64)).is_contiguous())

        # The input values must be preserved in the top-left corner of a
        # zero-filled 16x16 tile.
        source = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
        expected = torch.zeros((1, 1, 16, 16), dtype=source.dtype)
        expected[0, 0, :2, :4] = source
        self.assertTrue(torch.allclose(utils.nd_to_nz_2d(source), expected))

    def test_aligned_16(self):
        """``aligned_16`` yields a first dimension that is a multiple of 16."""
        for rows, aligned_rows in ((15, 16), (16, 16), (17, 32)):
            tensor = torch.randn(rows, 64)
            aligned = utils.aligned_16(tensor)
            self.assertEqual(aligned.shape[0], aligned_rows)
            if rows == aligned_rows:
                # An already aligned tensor must come back with equal content.
                self.assertTrue(torch.equal(tensor, aligned))

    @mock.patch('importlib.util.find_spec')
    @mock.patch('importlib.import_module')
    def test_try_register_lib(self, mock_import_module, mock_find_spec):
        """``try_register_lib`` must not raise in any of the three scenarios."""
        # Library is found and imports cleanly.
        mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.return_value = mock.MagicMock()
        utils.try_register_lib("existing_lib",
                               "Library found and imported successfully")

        # Library cannot be located.
        mock_find_spec.return_value = None
        utils.try_register_lib("non_existing_lib")

        # Library is located but importing it fails.
        mock_find_spec.return_value = mock.MagicMock()
        mock_import_module.side_effect = ImportError("import error")
        utils.try_register_lib("error_lib")

    def test_enable_custom_op(self):
        """``enable_custom_op`` is true normally and false if imports fail."""
        self.assertTrue(utils.enable_custom_op())

        # Reset the module-level result so the function evaluates again while
        # every import fails; ``patch.object`` restores the earlier value on
        # exit instead of leaving the mocked ``False`` behind.
        with mock.patch.object(utils, "_CUSTOM_OP_ENABLED", None), \
                mock.patch('builtins.__import__',
                           side_effect=ImportError("import error")):
            self.assertFalse(utils.enable_custom_op())

    def test_find_hccl_library(self):
        """Cover the env override, the missing-CANN error and the default."""
        custom_path = "/path/to/hccl/libhccl.so"
        with mock.patch.dict(os.environ, {"HCCL_SO_PATH": custom_path}):
            self.assertEqual(utils.find_hccl_library(), custom_path)
        with mock.patch("torch.version.cann", None):
            self.assertRaises(ValueError, utils.find_hccl_library)
        with mock.patch("torch.version.cann", "Ascend910"):
            self.assertEqual(utils.find_hccl_library(), "libhccl.so")

    def test_current_stream(self):
        """``current_stream`` returns what ``torch.npu.current_stream`` does."""
        with mock.patch("torch.npu.current_stream") as stream_factory:
            actual = utils.current_stream()
            self.assertEqual(actual, stream_factory())

    def test_vllm_version_is(self):
        """Check version matching with and without ``VLLM_VERSION``, then caching."""
        version_is = utils.vllm_version_is
        # ``__wrapped__`` bypasses the cache so each mocked version is seen.
        uncached = version_is.__wrapped__

        # With VLLM_VERSION set, the environment value wins regardless of the
        # installed vllm version.
        with mock.patch.dict(os.environ, {"VLLM_VERSION": "1.0.0"}):
            for installed in ("1.0.0", "2.0.0"):
                with mock.patch("vllm.__version__", installed):
                    self.assertTrue(uncached("1.0.0"))
                    self.assertFalse(uncached("2.0.0"))

        # Otherwise the installed vllm version decides.
        for installed, other in (("1.0.0", "2.0.0"), ("2.0.0", "1.0.0")):
            with mock.patch("vllm.__version__", installed):
                self.assertTrue(uncached(installed))
                self.assertFalse(uncached(other))

        # Test caching takes effect
        version_is.cache_clear()
        version_is("1.0.0")
        first_info = version_is.cache_info()
        self.assertEqual(first_info.misses, 1)
        self.assertEqual(first_info.hits, 0)
        version_is("1.0.0")
        self.assertEqual(version_is.cache_info().hits, 1)

    def test_get_max_hidden_layers(self):
        """``get_max_hidden_layers`` returns the largest ``num_hidden_layers``."""
        from transformers import PretrainedConfig

        class SimpleConfig(PretrainedConfig):

            def __init__(self, num_hidden_layers=12):
                self.num_hidden_layers = num_hidden_layers

            def to_dict(self):
                return {"num_hidden_layers": self.num_hidden_layers}

        class NestedConfig(PretrainedConfig):

            def to_dict(self):
                return {
                    "model": {
                        "encoder": {
                            "num_hidden_layers": 8
                        },
                        "decoder": {
                            "num_hidden_layers": 12
                        }
                    },
                    "other_setting": True
                }

        class MultiValueConfig(PretrainedConfig):

            def to_dict(self):
                return {
                    "num_hidden_layers": 6,
                    "submodule": {
                        "num_hidden_layers": 18,
                        "subsub": {
                            "num_hidden_layers": 9
                        }
                    }
                }

        class NoLayerConfig(PretrainedConfig):

            def to_dict(self):
                return {"attention_heads": 8}

        # Flat configs, then nested dictionaries at several depths.
        self.assertEqual(utils.get_max_hidden_layers(SimpleConfig()), 12)
        self.assertEqual(utils.get_max_hidden_layers(SimpleConfig(24)), 24)
        self.assertEqual(utils.get_max_hidden_layers(NestedConfig()), 12)
        self.assertEqual(utils.get_max_hidden_layers(MultiValueConfig()), 18)

        # A config without the key must raise and name the missing key.
        with self.assertRaises(ValueError) as context:
            utils.get_max_hidden_layers(NoLayerConfig())
        self.assertIn("num_hidden_layers", str(context.exception))

    def _update_sizes_without_and_with_aiv(self, vllm_config):
        """Call ``update_aclgraph_sizes`` as-is, then with AIV expansion set."""
        utils.update_aclgraph_sizes(vllm_config)
        # ``patch.dict`` restores the environment even when the call raises,
        # and puts back any value that was set before the test.
        with mock.patch.dict(os.environ, {"HCCL_OP_EXPANSION_MODE": "AIV"}):
            utils.update_aclgraph_sizes(vllm_config)

    def test_update_aclgraph_sizes(self):
        """Check the number of capture sizes left by ``update_aclgraph_sizes``."""
        model_path = os.path.join(os.path.dirname(__file__), "fake_weight")
        model_config = ModelConfig(model=model_path, enforce_eager=True)
        parallel_config = ParallelConfig()

        size_cases = (
            # max_num_batch_sizes < len(original_sizes)
            (list(range(150)), 147),
            # max_num_batch_sizes >= len(original_sizes)
            ([1, 2, 3], 3),
        )
        for capture_sizes, expected_count in size_cases:
            vllm_config = VllmConfig(
                model_config=model_config,
                compilation_config=CompilationConfig(
                    cudagraph_capture_sizes=capture_sizes),
                parallel_config=parallel_config,
            )
            self._update_sizes_without_and_with_aiv(vllm_config)
            self.assertEqual(
                expected_count,
                len(vllm_config.compilation_config.cudagraph_capture_sizes))

    @mock.patch("vllm.model_executor.custom_op.CustomOp")
    @mock.patch("vllm_ascend.ops.activation.AscendQuickGELU")
    @mock.patch("vllm_ascend.ops.activation.AscendSiluAndMul")
    @mock.patch("vllm_ascend.ops.layernorm.AscendRMSNorm")
    def test_register_ascend_customop(self, mock_ascend_rmsnorm,
                                      mock_ascend_silu_and_mul,
                                      mock_ascend_quick_gelu, mock_customop):
        """Custom ops are registered once; a second call is a no-op."""
        # Number of ``register_oot`` calls this test expects from one
        # registration pass.
        expected_registrations = 12

        # Start from the "not registered" state and restore the previous flag
        # afterwards: the registrations below only reach a mock, so the flag
        # must not stay ``True`` for later tests.
        with mock.patch.object(utils, "_ASCEND_CUSTOMOP_IS_REIGISTERED",
                               False):
            # ascend custom op is not registered
            utils.register_ascend_customop()
            self.assertEqual(mock_customop.register_oot.call_count,
                             expected_registrations)
            self.assertTrue(utils._ASCEND_CUSTOMOP_IS_REIGISTERED)

            # ascend custom op is already registered: the call count must not
            # grow any further.
            utils.register_ascend_customop()
            self.assertEqual(mock_customop.register_oot.call_count,
                             expected_registrations)
|
||||
|
||||
|
||||
class TestProfileExecuteDuration(TestBase):
    """Tests for the singleton behaviour of ``utils.ProfileExecuteDuration``."""

    def setUp(self):
        """Reset the class-level singleton state before every test."""
        profiler_cls = utils.ProfileExecuteDuration
        profiler_cls._instance = None
        profiler_cls._observations = []
        profiler_cls._lock = Lock()

    def test_singleton_creation(self):
        """Two constructions return the same, class-registered instance."""
        first = utils.ProfileExecuteDuration()
        self.assertIsNotNone(first)
        self.assertIs(first, utils.ProfileExecuteDuration._instance)

        second = utils.ProfileExecuteDuration()
        self.assertIs(first, second)

    def test_thread_safety(self):
        """Instances created from ten threads are all the same object."""
        from threading import Thread

        created = []

        def build_instance():
            created.append(utils.ProfileExecuteDuration())

        workers = [Thread(target=build_instance) for _ in range(10)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        reference = created[0]
        for other in created[1:]:
            self.assertIs(reference, other)

    def test_atexit_registration(self):
        """Creating the instance registers its ``destroy`` with ``atexit``."""
        with mock.patch('atexit.register') as register_spy:
            profiler = utils.ProfileExecuteDuration()
            register_spy.assert_called_once_with(profiler.destroy)

    def test_lock_usage(self):
        """Creating the instance enters and exits the class-level lock."""
        real_lock = utils.ProfileExecuteDuration._lock

        with mock.patch.object(utils.ProfileExecuteDuration,
                               '_lock',
                               wraps=real_lock) as lock_spy:
            utils.ProfileExecuteDuration()
            lock_spy.__enter__.assert_called()
            lock_spy.__exit__.assert_called()

    def test_observations_initialization(self):
        """A fresh instance starts with an empty observation list."""
        profiler = utils.ProfileExecuteDuration()
        self.assertEqual(profiler._observations, [])
|
||||
0
tests/ut/torchair/__init__.py
Normal file
0
tests/ut/torchair/__init__.py
Normal file
195
tests/ut/torchair/models/test_torchair_deepseek_mtp.py
Normal file
195
tests/ut/torchair/models/test_torchair_deepseek_mtp.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import pytest
|
||||
import torch
|
||||
from pytest_mock import MockerFixture
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_mtp import (
|
||||
TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor,
|
||||
TorchairDeepSeekMultiTokenPredictorLayer)
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase):
    """Tests for ``TorchairDeepSeekMultiTokenPredictorLayer``."""

    @pytest.fixture
    def setup_mtp_layer(self, mocker: MockerFixture):
        """Build a predictor layer with its sub-module constructors stubbed."""
        hf_config = PretrainedConfig(vocab_size=1000,
                                     hidden_size=768,
                                     rms_norm_eps=1e-5)
        decoder_init_target = (
            "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__"
        )
        # Constructors replaced by no-op stubs, patched in this order.
        stubbed_inits = (
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            "vllm.model_executor.layers.layernorm.RMSNorm.__init__",
            "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__",
            decoder_init_target,
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
        )
        stubs = {
            target: mocker.patch(target, return_value=None)
            for target in stubbed_inits
        }
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        layer = TorchairDeepSeekMultiTokenPredictorLayer(hf_config, "", None)
        # The layer must have built exactly one decoder block.
        stubs[decoder_init_target].assert_called_once()
        return layer

    def test_init(self, mocker: MockerFixture, setup_mtp_layer):
        """The fixture yields an instance of the layer class."""
        layer = setup_mtp_layer
        assert isinstance(layer, TorchairDeepSeekMultiTokenPredictorLayer)

    def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
        """Calling the layer returns a tensor of shape ``(2, 3, 768)``."""
        layer = setup_mtp_layer
        hidden_shape = (2, 3, 768)

        # Replace the attribute hooks of nn.Module with mocks before touching
        # any attribute of the layer.
        for hook in ("__setattr__", "__getattr__", "__delattr__"):
            mocker.patch(f"torch.nn.Module.{hook}")
        mocker.patch.object(layer,
                            'eh_proj',
                            return_value=torch.randn(*hidden_shape))
        mocker.patch("torch.cat", return_value=torch.randn(*hidden_shape))
        layer.mtp_block.return_value = (torch.randn(*hidden_shape),
                                        torch.randn(*hidden_shape))

        token_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
        token_positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
        cache = torch.randn(*hidden_shape)
        prev_hidden = torch.randn(*hidden_shape)
        embeds = torch.tensor([[1.0, 2.0, 3.0]])

        result = layer(token_ids, token_positions, cache, None, prev_hidden,
                       embeds, 0)
        assert result.shape == hidden_shape
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMultiTokenPredictor(PytestBase):
    """Tests for ``TorchairDeepSeekMultiTokenPredictor``."""

    @pytest.fixture
    def setup_predictor(self, mocker: MockerFixture):
        """Build a predictor from mocked configs with layer creation stubbed."""
        hf_config = mocker.MagicMock()
        hf_config.num_hidden_layers = 12
        hf_config.num_nextn_predict_layers = 3
        hf_config.vocab_size = 30000

        model_config = mocker.MagicMock(spec=ModelConfig)
        model_config.hf_config = hf_config

        vllm_config = mocker.MagicMock(spec=VllmConfig)
        vllm_config.model_config = model_config
        vllm_config.cache_config = CacheConfig()
        vllm_config.quant_config = mocker.MagicMock()

        # Constructors replaced by no-op stubs, patched in this order.
        stubbed_inits = (
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__",
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
        )
        for target in stubbed_inits:
            mocker.patch(target, return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        return TorchairDeepSeekMultiTokenPredictor(vllm_config=vllm_config)

    def test_init(self, mocker: MockerFixture, setup_predictor):
        """The predictor records the configured number of MTP layers."""
        predictor = setup_predictor
        assert predictor.num_mtp_layers == 3
        assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor)

    @pytest.mark.parametrize(
        'kv_caches, inputs_embeds',
        [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
    def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
                     inputs_embeds):
        """``forward`` calls the single stub layer once and returns its output."""
        predictor = setup_predictor
        layer_stub = mocker.MagicMock(
            return_value=torch.tensor([1.0, 2.0, 3.0]))
        predictor.layers_list = [layer_stub]

        # TODO: decide whether ``predictor.num_mtp_layers = 1`` must be set
        # here as well (left unset, as in the original test).
        token_ids = torch.tensor([[1, 2, 3]])
        token_positions = torch.tensor([[0, 1, 2]])
        mocker.patch(
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
            return_value=torch.tensor([[1.0, 2.0, 3.0]]))

        result = predictor.forward(token_ids, token_positions, kv_caches,
                                   None, None, inputs_embeds, 0)
        layer_stub.assert_called_once()
        assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0]))

    def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
        """``compute_logits`` calls the logits processor once and returns its result."""
        states = torch.tensor([[1, 2, 3], [4, 5, 6]])
        predictor = setup_predictor

        layer_stub = mocker.MagicMock(
            return_value=torch.tensor([1.0, 2.0, 3.0]))
        # Assigned before the nn.Module attribute hooks are mocked below.
        predictor.layers_list = [layer_stub]

        for hook in ("__setattr__", "__getattr__", "__delattr__"):
            mocker.patch(f"torch.nn.Module.{hook}")
        mocker.patch(
            "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
            return_value=None)
        predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])

        logits = predictor.compute_logits(hidden_states=states,
                                          sampling_metadata=None)
        predictor.logits_processor.assert_called_once()
        assert torch.allclose(logits, torch.tensor([1.0, 2.0, 3.0]))
|
||||
|
||||
|
||||
class TestTorchairDeepSeekMTP(PytestBase):
    """Tests for the top-level ``TorchairDeepSeekMTP`` model wrapper."""

    @pytest.fixture
    def setup_mtp(self, mocker: MockerFixture):
        """Build the model from a mocked vLLM config with internals stubbed."""
        config = mocker.MagicMock()
        config.model_config.hf_config.num_hidden_layers = 12
        config.model_config.hf_config.num_nextn_predict_layers = 3
        config.cache_config = mocker.MagicMock()
        config.quant_config = mocker.MagicMock()

        # Mock the nn.Module attribute hooks first, then the callables that
        # are replaced by stubs returning ``None`` (same order as before).
        for hook in ("__setattr__", "__getattr__", "__delattr__"):
            mocker.patch(f"torch.nn.Module.{hook}")
        none_returning_targets = (
            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
            "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__",
            "vllm.model_executor.layers.sampler.get_sampler",
            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
        )
        for target in none_returning_targets:
            mocker.patch(target, return_value=None)
        mocker.patch("vllm_ascend.utils.get_ascend_config",
                     return_value=mocker.Mock())

        return TorchairDeepSeekMTP(vllm_config=config)

    def test_init(self, mocker: MockerFixture, setup_mtp):
        """The fixture yields an instance of the model class."""
        model = setup_mtp
        assert isinstance(model, TorchairDeepSeekMTP)

    def test_forward(self, mocker: MockerFixture, setup_mtp):
        """``forward`` returns whatever the inner ``model`` call returns."""
        token_ids = torch.tensor([[1, 2, 3]])
        token_positions = torch.tensor([[0, 1, 2]])
        caches = [torch.tensor([[0.1, 0.2, 0.3]])]
        prev_hidden = torch.tensor([[0.1, 0.2, 0.3]])
        embeds = torch.tensor([[0.1, 0.2, 0.3]])
        step_index = 0
        setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])

        result = setup_mtp.forward(token_ids, token_positions, caches, None,
                                   prev_hidden, embeds, step_index)
        assert torch.allclose(result, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
325
tests/ut/torchair/models/test_torchair_deepseek_v2.py
Normal file
325
tests/ut/torchair/models/test_torchair_deepseek_v2.py
Normal file
@@ -0,0 +1,325 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import CacheConfig
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
from vllm_ascend.torchair.models.torchair_deepseek_v2 import (
|
||||
TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM,
|
||||
TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention,
|
||||
TorchairDeepseekV2MLP, TorchairDeepseekV2MoE,
|
||||
TorchairDeepseekV2RowParallelLinear,
|
||||
TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
|
||||
TorchairDeepseekV2SiluAndMul)
|
||||
|
||||
|
||||
@pytest.fixture
def base_config():
    """Return a small ``PretrainedConfig`` shared by the tests in this module."""
    settings = {
        "hidden_size": 128,
        "num_attention_heads": 8,
        "num_hidden_layers": 2,
        "intermediate_size": 256,
        "hidden_act": "silu",
        "rms_norm_eps": 1e-6,
        "rope_theta": 10000.0,
        "max_position_embeddings": 2048,
        "n_routed_experts": 4,
        "n_shared_experts": 1,
        "moe_intermediate_size": 256,
        "num_experts_per_tok": 2,
        "routed_scaling_factor": 1.0,
        "first_k_dense_replace": 0,
        "moe_layer_freq": 1,
        "kv_lora_rank": 16,
        "qk_nope_head_dim": 16,
        "qk_rope_head_dim": 16,
        "v_head_dim": 32,
        "topk_method": "noaux_tc",
        "scoring_func": "softmax",
        "norm_topk_prob": True,
        "n_group": 1,
        "topk_group": 1,
        "vocab_size": 10000,
    }
    return PretrainedConfig(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def vllm_config(base_config):
    """Return a mock vLLM config wrapping ``base_config``.

    The model config is a plain namespace, the cache config is a real
    ``CacheConfig`` and no quantization config is set.
    """
    config = Mock()
    config.model_config = SimpleNamespace(
        hf_config=base_config,
        tensor_parallel_size=1,
        dtype=torch.float32,
        use_mla=False,
        quant_config=None,
        max_model_len=2048,
    )
    config.cache_config = CacheConfig()
    config.quant_config = None
    return config
|
||||
|
||||
|
||||
@pytest.fixture
def mock_distributed():
    """Patch the distributed helpers so layers see a single-rank setup.

    While the fixture is active, the parallel-group getters used by the
    torchair DeepSeek-V2 model return mocks with rank 0 and world size 1, the
    module-level group globals are replaced through ``patch.dict``, and the
    fused-MoE module gets a mocked current vLLM config.
    """

    def _single_rank_group():
        """Build a ``GroupCoordinator`` mock with rank 0 and world size 1."""
        group = Mock(spec=GroupCoordinator)
        group.rank_in_group = 0
        group.world_size = 1
        return group

    tp_group = _single_rank_group()
    tp_group.device_group = Mock()
    dp_group = _single_rank_group()
    ep_group = _single_rank_group()
    pp_group = _single_rank_group()

    mock_vllm_config = Mock()
    mock_vllm_config.scheduler_config = Mock(max_num_seqs=256)
    mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None)

    model_module = "vllm_ascend.torchair.models.torchair_deepseek_v2"
    # ``get_pp_group`` used to be patched twice in this ``with`` statement.
    # The later patch shadowed the earlier one, so only the effective patch
    # (neither first nor last pipeline rank) is kept. ``pp_group`` is still
    # installed as the ``_PP`` global below.
    with patch(f"{model_module}.get_tensor_model_parallel_rank", return_value=0), \
            patch(f"{model_module}.get_tensor_model_parallel_world_size", return_value=1), \
            patch(f"{model_module}.get_tp_group", return_value=tp_group), \
            patch(f"{model_module}.get_ep_group", return_value=ep_group), \
            patch(f"{model_module}.get_dp_group", return_value=dp_group), \
            patch(f"{model_module}.get_pp_group",
                  return_value=Mock(is_first_rank=False, is_last_rank=False)), \
            patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config",
                  return_value=mock_vllm_config), \
            patch.dict("vllm.distributed.parallel_state.__dict__",
                       _TP=tp_group, _EP=ep_group, _DP=dp_group, _PP=pp_group), \
            patch.dict("vllm_ascend.distributed.parallel_state.__dict__",
                       _MC2=ep_group):
        yield
|
||||
|
||||
|
||||
@pytest.fixture
def mock_forward_context():
    """Make ``get_forward_context`` in the model module return a decode-time mock."""
    context = Mock(in_profile_run=False, with_prefill=False)
    target = (
        "vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context")
    with patch(target, return_value=context):
        yield
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_silu_and_mul():
    """Check output shapes of ``TorchairDeepseekV2SiluAndMul.forward_oot``.

    Covers a plain tensor input without weight scale, and a
    ``(tensor, scale)`` tuple input with a weight scale where the NPU op is
    patched out.
    """
    torch.set_default_device("cpu")

    plain_act = TorchairDeepseekV2SiluAndMul()
    assert plain_act.weight_scale is None

    plain_out = plain_act.forward_oot(torch.randn(2, 4))
    assert plain_out.shape == (2, 2)

    scale_stub = Mock(return_value=torch.tensor(0.1))
    scaled_act = TorchairDeepseekV2SiluAndMul(weight_scale=scale_stub)
    quantized = torch.randint(-128, 127, (2, 4), dtype=torch.int32)
    per_token_scale = torch.randn(2, 1)
    with patch("torch_npu.npu_dequant_swiglu_quant",
               return_value=torch.randn(2, 4)):
        scaled_out = scaled_act.forward_oot((quantized, per_token_scale))
        assert scaled_out.shape == (2, 4)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed):
    """Load one shard into the merged linear and reject a mis-sized shard."""
    shard_sizes = [64, 64]
    layer = TorchairDeepseekV2MergedReplicatedLinear(input_size=128,
                                                     output_sizes=shard_sizes,
                                                     bias=False,
                                                     quant_config=None)
    assert layer.output_sizes == [64, 64]

    # Minimal stand-in for a weight parameter.
    fake_param = Mock()
    fake_param.data = torch.zeros(128, 128)
    fake_param.output_dim = 1
    fake_param.is_gguf_weight = False
    fake_param.is_gguf_weight_type = False

    matching_shard = torch.randn(128, 64)
    layer.weight_loader(fake_param, matching_shard, loaded_shard_id=0)

    # A shard whose size does not match must trip an assertion.
    with pytest.raises(AssertionError):
        layer.weight_loader(fake_param,
                            torch.randn(128, 32),
                            loaded_shard_id=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("cls", [
    TorchairDeepseekV2RowParallelLinearReplaceAllreduce,
    TorchairDeepseekV2RowParallelLinear
])
def test_row_parallel_linear(cls, mock_distributed):
    """Call both row-parallel variants with and without pre-split input."""
    out_shape = (2, 4, 64)
    layer = cls(input_size=128, output_size=64, bias=False, quant_config=None)
    layer.quant_method = Mock()
    layer.quant_method.apply.return_value = torch.randn(*out_shape)

    batch = torch.randn(2, 4, 128)
    split_target = (
        "vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim"
    )
    with patch(split_target, return_value=[torch.randn(*out_shape)]):
        # (input_is_parallel, is_prefill) pairs, run in the original order.
        for already_split, prefill in ((False, True), (True, False)):
            layer.input_is_parallel = already_split
            result = layer(batch, is_prefill=prefill)
            assert result[0].shape == out_shape
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_mlp(mock_distributed, base_config):
    """Check the MLP forward shape and its two constructor error paths."""
    size_kwargs = {"hidden_size": 128, "intermediate_size": 256}

    mlp = TorchairDeepseekV2MLP(hidden_act="silu",
                                quant_config=None,
                                **size_kwargs)
    assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul)

    features = torch.randn(2, 4, 128)
    assert mlp(features).shape == (2, 4, 128)

    # A quantization config named "w8a8dynamic" combined with
    # force_replicate=False is expected to be rejected.
    with patch(
            "vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig"
    ) as quant_config_stub:
        quant_config_stub.name = "w8a8dynamic"
        with pytest.raises(NotImplementedError):
            TorchairDeepseekV2MLP(hidden_act="silu",
                                  quant_config=quant_config_stub,
                                  force_replicate=False,
                                  **size_kwargs)

    # An activation other than "silu" is expected to be rejected.
    with pytest.raises(ValueError):
        TorchairDeepseekV2MLP(hidden_act="relu",
                              quant_config=None,
                              **size_kwargs)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_moe(mock_distributed, base_config,
                                  mock_forward_context):
    """Build the MoE block and run it with the fused-MoE call patched out."""
    base_config.n_shared_experts = 1
    moe_block = TorchairDeepseekV2MoE(config=base_config,
                                      quant_config=None,
                                      prefix="mlp")
    assert moe_block.top_k == 2

    hidden = torch.randn(2, 4, 128)
    metadata = Mock(num_prefills=1)
    fused_moe_call = (
        "vllm_ascend.torchair.ops.torchair_fused_moe.TorchairAscendFusedMoE.__call__"
    )
    fused_moe_result = (torch.randn(2, 4, 128), torch.randn(2, 4, 128))
    with patch(fused_moe_call, return_value=fused_moe_result):
        result = moe_block(hidden, metadata)
        assert result.shape == (2, 4, 128)
|
||||
|
||||
|
||||
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed,
                                            base_config):
    """Construct MLA attention with and without a ``q_lora_rank``."""
    mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))

    # Arguments shared by both constructions below.
    common_kwargs = {
        "config": base_config,
        "hidden_size": 128,
        "num_heads": 8,
        "qk_nope_head_dim": 16,
        "qk_rope_head_dim": 16,
        "v_head_dim": 32,
        "kv_lora_rank": 16,
    }

    lora_attn = TorchairDeepseekV2MLAAttention(q_lora_rank=16,
                                               cache_config=CacheConfig(),
                                               quant_config=None,
                                               prefix="layers.0.self_attn",
                                               **common_kwargs)
    assert lora_attn.debug_layer_idx == 0

    hidden = torch.randn(2, 4, 128)
    token_positions = torch.arange(4).repeat(2, 1)
    # With the inner attention call patched, the forward pass is expected to
    # end in an AssertionError.
    with patch.object(lora_attn.mla_attn,
                      "__call__",
                      return_value=torch.randn(2, 4, 128)):
        with pytest.raises(AssertionError):
            lora_attn(token_positions, hidden)

    # Without a q_lora_rank the layer must expose a ``q_proj`` attribute.
    plain_attn = TorchairDeepseekV2MLAAttention(q_lora_rank=None,
                                                prefix="layers.1.self_attn",
                                                **common_kwargs)
    assert hasattr(plain_attn, "q_proj")
|
||||
assert hasattr(attn, "q_proj")
|
||||
|
||||
|
||||
@patch("torch_npu.npu_add_rms_norm")
@patch("torch_npu.npu_rms_norm")
def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm,
                                            mock_distributed, base_config,
                                            vllm_config):
    """Check which MLP flavour the decoder layer builds and its output shape.

    With ``n_routed_experts = 4`` the layer must use the MoE block; with
    ``n_routed_experts = None`` it must use the dense MLP.
    """
    mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
    mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128),
                                  torch.randn(2, 128))

    # --- routed experts configured: MoE layer ---
    base_config.n_routed_experts = 4
    moe_layer = TorchairDeepseekV2DecoderLayer(
        config=base_config,
        prefix="layers.0",
        model_config=vllm_config.model_config,
        cache_config=CacheConfig(),
        quant_config=None)
    assert isinstance(moe_layer.mlp, TorchairDeepseekV2MoE)

    hidden = torch.randn(2, 4, 128)
    pos = torch.arange(4).repeat(2, 1)

    # Stub both sub-module forwards so only the layer's own wiring runs.
    attn_stub = Mock(return_value=torch.randn(2, 4, 128))
    with patch.object(moe_layer.self_attn, "forward", attn_stub):
        mlp_stub = Mock(return_value=torch.randn(2, 4, 128))
        with patch.object(moe_layer.mlp, "forward", mlp_stub):
            hidden_states, residual = moe_layer(pos, hidden, None)
            assert hidden_states.shape == (2, 4, 128)

    # --- no routed experts: dense MLP layer ---
    base_config.n_routed_experts = None
    dense_layer = TorchairDeepseekV2DecoderLayer(
        config=base_config,
        prefix="layers.0",
        model_config=vllm_config.model_config,
        quant_config=None)
    assert isinstance(dense_layer.mlp, TorchairDeepseekV2MLP)
|
||||
|
||||
|
||||
def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
    """Smoke-test the causal-LM wrapper: forward shape and ``load_weights``."""
    lm = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config)

    token_ids = torch.randint(0, 10000, (2, 4))
    pos = torch.arange(4).repeat(2, 1)
    # The inner model's forward is stubbed; the wrapper should pass its
    # result through with the same shape.
    with patch.object(lm.model,
                      "forward",
                      return_value=torch.randn(2, 4, 128)):
        logits = lm(token_ids, pos)
        assert logits.shape == (2, 4, 128)

    named_weights = [("model.embed_tokens.weight", torch.randn(10000, 128))]
    loader_target = "vllm.model_executor.model_loader.weight_utils.default_weight_loader"
    with patch(loader_target):
        loaded = lm.load_weights(named_weights)
        assert loaded is not None
|
||||
410
tests/ut/torchair/ops/test_torchair_fused_moe.py
Normal file
410
tests/ut/torchair/ops/test_torchair_fused_moe.py
Normal file
@@ -0,0 +1,410 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import List, TypedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from pytest_mock import MockerFixture
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
|
||||
|
||||
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
|
||||
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
|
||||
from vllm_ascend.quantization.quantizer import W8A8Quantizer
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import (
|
||||
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
|
||||
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
|
||||
|
||||
# Runs once at import time, before any test in this module executes.
# NOTE(review): presumably installs the vllm-ascend patches; the meaning of the
# ``True`` argument is not visible here — confirm in vllm_ascend.utils.
adapt_patch(True)
|
||||
|
||||
|
||||
def mock_ep_and_mc2_group(mocker):
    """Return a mock process group used for the EP and MC2 group getters.

    The mock reports rank 0 in a world of size 4, and its ``all_to_all``
    returns a fixed random ``(8, 8)`` tensor.
    """
    return mocker.MagicMock(
        rank_in_group=0,
        rank=0,
        world_size=4,
        device_group="mock_group_ep",
        all_to_all=MagicMock(return_value=torch.randn(8, 8)),
    )
|
||||
|
||||
|
||||
def mock_dp_and_tp_group(mocker):
    """Return a mock process group used for the DP and TP group getters.

    The mock reports rank 0 in a world of size 2, and its ``all_gather``
    returns a fixed random ``(10, 32)`` tensor.
    """
    return mocker.MagicMock(
        rank_in_group=0,
        world_size=2,
        device_group="mock_group",
        all_gather=MagicMock(return_value=torch.randn(10, 32)),
    )
|
||||
|
||||
|
||||
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
    """Replace distributed, config and forward-context lookups with mocks.

    The fixture yields inside the ``with`` block, so every patch below stays
    active until the requesting test finishes. Each patched group getter
    receives its own freshly built mock group.
    """
    ascend_config = MagicMock(
        torchair_graph_config=MagicMock(enabled=False,
                                        enable_multistream_moe=False),
        expert_map_path=None)
    # (local expert count, expert map) pair returned by determine_expert_map.
    expert_map_result = (3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))
    forward_context = MagicMock(
        max_tokens_across_dp=10,
        dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]))
    current_vllm_config = MagicMock(
        parallel_config=MagicMock(tensor_parallel_size=2),
        scheduler_config=MagicMock(max_num_seqs=4),
        model_config=MagicMock(max_model_len=2048))

    with patch('torch.distributed.get_rank', return_value=0), \
            patch('torch.distributed.get_world_size', return_value=4), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ep_group',
                  return_value=mock_ep_and_mc2_group(mocker)), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_mc2_group',
                  return_value=mock_ep_and_mc2_group(mocker)), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_tp_group',
                  return_value=mock_dp_and_tp_group(mocker)), \
            patch('vllm.distributed.parallel_state.get_tp_group',
                  return_value=mock_dp_and_tp_group(mocker)), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_dp_group',
                  return_value=mock_dp_and_tp_group(mocker)), \
            patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group',
                  return_value=mock_dp_and_tp_group(mocker)), \
            patch('torch.distributed.all_gather',
                  return_value=MagicMock(return_value=torch.randn(10, 32))), \
            patch('torch.distributed.all_to_all_single',
                  return_value=torch.randn(8, 32)), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.tensor_model_parallel_all_reduce',
                  return_value=torch.randn(5, 32)), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.data_parallel_reduce_scatter',
                  return_value=torch.randn(5, 32)), \
            patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
                  return_value=mock_dp_and_tp_group(mocker)), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_config',
                  return_value=ascend_config), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.determine_expert_map',
                  return_value=expert_map_result), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context',
                  return_value=forward_context), \
            patch('vllm_ascend.torchair.ops.torchair_fused_moe.get_current_vllm_config',
                  return_value=current_vllm_config):
        yield
|
||||
|
||||
|
||||
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):
    """Stub the ``torch_npu`` MoE kernels with fixed-shape random outputs.

    The ``*_v2`` dispatch/combine ops are patched only when the installed
    ``torch_npu`` exposes ``npu_moe_distribute_dispatch_v2``. The fixture
    yields while all patches are active.
    """

    def small():
        # Output stand-in with 8 rows.
        return torch.randn(8, 2)

    def wide():
        # Output stand-in with 16 rows.
        return torch.randn(16, 2)

    def expert_ids():
        return torch.randint(0, 8, (8, 2))

    def row_index():
        return torch.tensor([0, 1, 2, 4, 6, 2, 7, 1])

    with patch('torch_npu.npu_moe_gating_top_k',
               return_value=(small(), expert_ids(), None)), \
            patch('torch_npu.npu_moe_init_routing',
                  return_value=(small(), expert_ids(), row_index())), \
            patch("torch_npu.npu_moe_compute_expert_tokens",
                  return_value=small()), \
            patch("torch_npu.npu_moe_distribute_dispatch",
                  return_value=wide()), \
            patch("torch_npu.npu_moe_distribute_combine",
                  return_value=wide()), \
            patch("torch_npu.npu_grouped_matmul", return_value=[wide()]), \
            patch("torch_npu.npu_swiglu", return_value=wide()), \
            patch("torch_npu.npu_moe_gating_top_k_softmax",
                  return_value=(small(), expert_ids(), row_index())), \
            patch("torch_npu.npu_moe_finalize_routing", return_value=wide()):
        if not hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
            yield
            return
        with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                   return_value=wide()), \
                patch("torch_npu.npu_moe_distribute_combine_v2",
                      return_value=wide()):
            yield
|
||||
|
||||
|
||||
@pytest.fixture
def default_moe_config():
    """Return the baseline keyword arguments used to build the MoE layer."""
    return dict(
        num_experts=8,
        top_k=2,
        hidden_size=512,
        intermediate_size=1024,
    )
|
||||
|
||||
|
||||
@pytest.fixture
def moe_method(mock_dist_env):
    """Build the unquantized fused-MoE method on top of a mocked MoE config."""
    moe_config = MagicMock()
    # NOTE(review): this configures the value returned when
    # ``moe_parallel_config`` is *called*; if the attribute is read directly,
    # ``ep_size`` will be an auto-generated mock instead of 4 — confirm intent.
    moe_config.moe_parallel_config.return_value = MagicMock(ep_size=4)
    method = TorchairAscendUnquantizedFusedMoEMethod(moe_config)
    return method
|
||||
|
||||
|
||||
class Device(TypedDict):
    """Typed record describing one device entry in the mock data."""

    # Identifier of the device.
    device_id: int
    # Integer expert entries associated with the device.
    device_expert: List[int]
|
||||
|
||||
|
||||
class Layer(TypedDict):
    """Typed record describing one layer entry in the mock data."""

    # Identifier of the layer.
    layer_id: int
    # Count of devices recorded for the layer.
    device_count: int
    # Per-device records for the layer.
    device_list: List[Device]
|
||||
|
||||
|
||||
class MockData(TypedDict):
    """Typed top-level record of the mock data: a list of layer records."""

    # Count of MoE layers recorded.
    moe_layer_count: int
    # Per-layer records.
    layer_list: List[Layer]
|
||||
|
||||
|
||||
class MockQuantMethod(nn.Module):
    """Stand-in quant method whose ``apply`` returns canned random tensors.

    When ``shared_experts`` is truthy, ``apply`` returns a pair of tensors
    shaped ``(num_tokens, 32)`` and ``(num_tokens, 10)``; otherwise it
    returns a single ``(num_tokens, 32)`` tensor.
    """

    def __init__(self, shared_experts, num_tokens):
        super().__init__()
        routed = torch.randn(num_tokens, 32)
        canned = (routed, torch.randn(num_tokens,
                                      10)) if shared_experts else routed
        # Overrides nn.Module.apply on this instance with a call-recording mock.
        self.apply = MagicMock(return_value=canned)
|
||||
|
||||
|
||||
class MockFusedMoEMethod(FusedMoEMethodBase):
    """Do-nothing ``FusedMoEMethodBase`` subclass used as a quant-method stub."""

    # Shared mock handed to the base-class constructor by every instance.
    moe = MagicMock()

    def __init__(self):
        super().__init__(self.moe)

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        """Intentionally create nothing; returns ``None``."""

    def apply(self, hidden_states: torch.Tensor,
              expert_weights: torch.Tensor) -> torch.Tensor:
        """Intentionally compute nothing; returns ``None``."""
|
||||
|
||||
|
||||
class TestTorchairAscendFusedMoe:
    """Construction and forward tests for ``TorchairAscendFusedMoE``."""

    def test_init_no_quant(self, mock_dist_env, default_moe_config):
        """Build the layer without quantization and check argument validation."""
        cfg = default_moe_config
        layer = TorchairAscendFusedMoE(**cfg)

        n_experts = cfg['num_experts']
        hidden = cfg['hidden_size']
        inter = cfg['intermediate_size']
        layer.w13_weight = nn.Parameter(
            torch.randn(n_experts, inter * 2, hidden))
        layer.w2_weight = nn.Parameter(torch.randn(n_experts, hidden, inter))

        assert layer.num_experts == cfg['num_experts']
        assert layer.top_k == cfg['top_k']
        assert hasattr(layer, 'w13_weight')
        assert hasattr(layer, 'w2_weight')

        # check group_topk
        with pytest.raises(AssertionError):
            error_config = cfg.copy()
            error_config['use_grouped_topk'] = True
            layer = TorchairAscendFusedMoE(**error_config)

        # check scoring_func
        with pytest.raises(ValueError):
            error_config = cfg.copy()
            error_config['scoring_func'] = "random"
            layer = TorchairAscendFusedMoE(**error_config)

    def test_init_with_quant(self, mock_dist_env, default_moe_config):
        """A non-skipped layer must end up with an ``AscendFusedMoEMethod``."""
        quant_config = MagicMock()
        quant_config.get_quant_method.return_value = MockFusedMoEMethod()
        quant_config.is_layer_skipped_ascend.return_value = False

        quantizer_target = 'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer'
        with patch(quantizer_target, return_value=W8A8Quantizer):
            moe = TorchairAscendFusedMoE(**default_moe_config,
                                         quant_config=quant_config)

            assert moe.quant_method is not None
            assert isinstance(moe.quant_method, AscendFusedMoEMethod)

    def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
        """A skipped layer must fall back to the unquantized method."""
        quant_config = MagicMock()
        quant_config.get_quant_method.return_value = MockFusedMoEMethod()
        quant_config.is_layer_skipped_ascend.return_value = True

        moe = TorchairAscendFusedMoE(**default_moe_config,
                                     quant_config=quant_config)

        assert moe.quant_method is not None
        assert isinstance(moe.quant_method,
                          TorchairAscendUnquantizedFusedMoEMethod)

    @pytest.mark.parametrize(
        "others_param",
        [[None,
          MagicMock(return_value=torch.randn(5, 32)), False, 5, None],
         [2, None, False, 5, None], [None, None, True, 5, None],
         [None, None, False, 1, None], [None, None, True, 5, 1],
         [None, None, False, 5, 1]])
    def test_forward(self, mock_dist_env, default_moe_config, others_param):
        """Run ``forward`` for each scenario and check the output shape.

        ``others_param`` unpacks to
        ``[top_k, shared_experts, is_prefill, num_tokens, ep_size]``:

        1. shared_experts is provided
        2. an explicit top_k is provided
        3. is_prefill is True
        4. a single token (decode)
        5. ep_size is 1 and is_prefill is True
        6. ep_size is 1 and is_prefill is False
        """
        top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param
        hidden = torch.randn(num_tokens, 32)
        logits = torch.randn(num_tokens, 8)
        moe = TorchairAscendFusedMoE(**default_moe_config)

        if ep_size == 1:
            moe.moe_parallel_config.ep_size = 1

        moe.quant_method = MockQuantMethod(shared_experts, num_tokens)
        ctx = MagicMock(mc2_mask=torch.zeros(num_tokens, dtype=torch.bool),
                        padded_num_tokens=num_tokens)
        with patch(
                "vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
                return_value=ctx):
            output = moe.forward(hidden,
                                 logits,
                                 is_prefill=is_prefill,
                                 top_k=top_k,
                                 shared_experts=shared_experts)

        moe.quant_method.apply.assert_called_once()

        if not shared_experts:
            assert output.shape == (num_tokens, 32)
        else:
            assert output[0].shape == (num_tokens, 32)
            assert output[1].shape == (num_tokens, 10)

    def test_forward_ms_fused_moe_comp(self, mock_dist_env,
                                       default_moe_config):
        """``_forward_ms_fused_moe_comp`` must call ``apply`` once and keep shape."""
        hidden = torch.randn(5, 32)
        logits = torch.randn(5, 8)
        moe = TorchairAscendFusedMoE(**default_moe_config)
        moe.quant_method = MockQuantMethod(None, 5)

        output = moe._forward_ms_fused_moe_comp(hidden,
                                                logits,
                                                is_prefill=False,
                                                real_top_k=1)

        moe.quant_method.apply.assert_called_once()
        assert output.shape == (5, 32)
|
||||
|
||||
|
||||
class TestTorchairAscendUnquantizedFusedMoEMethod:
    """Tests for ``TorchairAscendUnquantizedFusedMoEMethod``."""

    def test_process_weights_after_loading(self, moe_method, mock_dist_env):
        """Both weights must become non-trainable ``torch.nn.Parameter`` objects."""
        weight_names = ("w13_weight", "w2_weight")
        layer = MagicMock()
        for name in weight_names:
            getattr(layer, name).data = torch.randn(16, 32)

        moe_method.process_weights_after_loading(layer)

        for name in weight_names:
            assert isinstance(getattr(layer, name), torch.nn.Parameter)
        for name in weight_names:
            assert not getattr(layer, name).requires_grad

    @pytest.mark.parametrize("others_param",
                             [[256, 4], [128, 1], [128, 1], [128, 4]])
    def test_apply_without_expert_map(self, moe_method, mock_dist_env,
                                      mock_moe_env, others_param):
        """Call ``apply`` without an expert map and check the result shape.

        ``others_param`` unpacks to ``[global_num_experts, ep_size]``;
        ``is_deepseek_v3_r1`` is derived as ``global_num_experts == 256``.
        The original description of the four cases was:

        1. is_deepseek_v3_r1 is true, using fused_experts_with_all2all
        2. select_experts and fused_experts
        3. select_gating_topk_softmax_experts and fused_experts
        4. select_experts and fused_experts_with_all2all_buffer

        NOTE(review): cases 2 and 3 use the identical parameters ``[128, 1]``,
        so they cannot exercise different paths — confirm the intended values.
        """
        global_num_experts, ep_size = others_param
        is_prefill = False
        is_deepseek_v3_r1 = global_num_experts == 256
        moe_state = _get_fused_moe_state(ep_size, is_prefill,
                                         is_deepseek_v3_r1)
        ctx = MagicMock(fused_moe_state=moe_state)
        with patch(
                "vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
                return_value=ctx):
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            logits = torch.randn(8, 8)
            layer = MagicMock()
            layer.w13_weight = torch.randn(8, 16, 1)
            layer.w2_weight = torch.randn(16, 8, 1)
            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=global_num_experts,
                                      is_prefill=is_prefill)

            expected_shape = (16, 2) if ep_size == 1 else x.shape
            assert result.shape == expected_shape

    @pytest.mark.parametrize("others_param", [16, 1, 4])
    def test_apply_with_expert_map(self, moe_method, mock_dist_env,
                                   mock_moe_env, others_param):
        """Call ``apply`` with an expert map and check the result shape.

        ``others_param`` is the ``ep_size``. The original description listed
        these cases:

        1. select_experts and fused_experts_with_mc2
        2. select_experts and fused_experts_with_all2all_buffer
        3. select_experts and fused_experts_with_all2all
        4. select_experts and fused_experts

        NOTE(review): four cases are described but only three ``ep_size``
        values are parametrized — confirm which case each value covers.
        """
        ep_size = others_param
        is_prefill = False
        ctx = MagicMock(
            fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True))
        with patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_forward_context",
                   return_value=ctx), \
                patch("vllm_ascend.torchair.ops.torchair_fused_moe.get_ascend_soc_version",
                      return_value=AscendSocVersion.A3):
            expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1])
            moe_method.ep_size = ep_size
            x = torch.randn(8, 2, 2)
            if ep_size == 1:
                # This case feeds a flattened 2-D input.
                x = x.view(-1, 2)
            logits = torch.randn(8, 8)
            layer = MagicMock()
            layer.w13_weight = torch.randn(8, 16, 1)
            layer.w2_weight = torch.randn(16, 8, 1)
            result = moe_method.apply(layer=layer,
                                      x=x,
                                      router_logits=logits,
                                      top_k=2,
                                      renormalize=True,
                                      global_num_experts=128,
                                      expert_map=expert_map,
                                      is_prefill=is_prefill)

            if ep_size in (16, 1):
                assert result.shape == (16, 2)
            else:
                assert result.shape == x.shape
|
||||
332
tests/ut/torchair/ops/test_torchair_rotary_embedding.py
Normal file
332
tests/ut/torchair/ops/test_torchair_rotary_embedding.py
Normal file
@@ -0,0 +1,332 @@
|
||||
import math
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.ops.torchair_rotary_embedding import (
|
||||
custom_rotary_embedding_enabled, native_rope_deepseek_forward,
|
||||
rope_forward_oot, rotate_half, yarn_find_correction_dim, yarn_get_mscale)
|
||||
|
||||
|
||||
class TestCustomRotaryEmbeddingEnabled(TestBase):
    """Tests for the ``custom_rotary_embedding_enabled`` predicate."""

    def setUp(self):
        # Common setup for tests
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 4, dtype=torch.float16)
        self.key = torch.randn(3, 4, dtype=torch.float16)
        self.head_size = 32
        self.cos_sin_cache = torch.randn(3, 4)

        # Mock self object for rope_forward_oot
        rope_self = MagicMock()
        rope_self.head_size = self.head_size
        rope_self.cos_sin_cache = self.cos_sin_cache
        rope_self.is_neox_style = True
        rope_self.forward_native.return_value = (self.query, self.key)
        self.mock_self = rope_self

    def _enabled(self, query, neox_style, head_size, custom_op):
        """Evaluate the predicate with ``enable_custom_op`` forced to *custom_op*."""
        with patch(
                'vllm_ascend.torchair.ops.torchair_rotary_embedding.enable_custom_op',
                return_value=custom_op):
            return custom_rotary_embedding_enabled(query, neox_style,
                                                   head_size)

    def test_custom_rotary_embedding_enabled(self):
        """Only the all-conditions-met combination may return True."""
        # Test when all conditions are True
        self.assertTrue(
            self._enabled(self.query, True, self.head_size, True))

        # Test when dtype is not float16
        fp32_query = self.query.to(torch.float32)
        self.assertFalse(
            self._enabled(fp32_query, True, self.head_size, True))

        # Test when neox_style is False
        self.assertFalse(
            self._enabled(self.query, False, self.head_size, True))

        # Test when head_size is not divisible by 32
        self.assertFalse(
            self._enabled(self.query, True, self.head_size + 1, True))

        # Test when custom op is disabled
        self.assertFalse(
            self._enabled(self.query, True, self.head_size, False))
|
||||
|
||||
|
||||
class TestRopeForwardOot(TestBase):
    """Tests for the dispatch paths of ``rope_forward_oot``."""

    def setUp(self):
        # Common setup for tests
        self.positions = torch.tensor([1, 2, 3])
        self.query = torch.randn(3, 4, dtype=torch.float16)
        self.key = torch.randn(3, 4, dtype=torch.float16)
        self.head_size = 32
        self.cos_sin_cache = torch.randn(3, 4)

        # Mock self object for rope_forward_oot
        rope_self = MagicMock()
        rope_self.head_size = self.head_size
        rope_self.cos_sin_cache = self.cos_sin_cache
        rope_self.is_neox_style = True
        rope_self.forward_native.return_value = (self.query, self.key)
        self.mock_self = rope_self

    @staticmethod
    def _set_torchair_graph(mock_get_ascend_config, enabled):
        """Make the patched ``get_ascend_config`` report the given graph flag."""
        ascend_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = enabled
        mock_get_ascend_config.return_value = ascend_config

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    def test_rope_forward_oot_torchair_enabled_base(self,
                                                    mock_get_ascend_config):
        """With the torchair graph enabled, the call goes to ``forward_native``."""
        # Setup mock for torchair enabled
        self._set_torchair_graph(mock_get_ascend_config, True)

        out_q, out_k = rope_forward_oot(self.mock_self, self.positions,
                                        self.query, self.key)

        self.mock_self.forward_native.assert_called_once_with(
            self.positions, self.query, self.key, None)
        self.assertTrue(torch.equal(out_q, self.query))
        self.assertTrue(torch.equal(out_k, self.key))

    @patch('torch.ops._C')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch('vllm_ascend.torchair.ops.torchair_rotary_embedding.is_310p',
           return_value=False)
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
        return_value=True)
    @patch('torch.ops._npu_rotary_embedding')
    def test_rope_forward_oot_custom_kernel(self, mock_rotary_embedding,
                                            mock_custom_enabled, mock_is_310p,
                                            mock_get_ascend_config, mock__c):
        """With the custom op enabled, output shapes must match the inputs."""
        self._set_torchair_graph(mock_get_ascend_config, False)

        # Setup mock for custom kernel path
        mock__c.rotary_embedding.return_value = self.query, self.key

        out_q, out_k = rope_forward_oot(self.mock_self, self.positions,
                                        self.query, self.key)

        self.assertEqual(out_q.shape, self.query.shape)
        self.assertEqual(out_k.shape, self.key.shape)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
        return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_contiguous(self, mock_npu_rotary,
                                         mock_custom_enabled,
                                         mock_get_ascend_config):
        """Non-contiguous inputs must reach the NPU op once and keep their shape."""
        self._set_torchair_graph(mock_get_ascend_config, False)

        # Test contiguous path when custom is disabled
        strided_q = self.query.transpose(0, 1)
        strided_k = self.key.transpose(0, 1)

        out_q, out_k = rope_forward_oot(self.mock_self, self.positions,
                                        strided_q, strided_k)

        mock_npu_rotary.assert_called_once()
        self.assertEqual(out_q.shape, strided_q.shape)
        self.assertEqual(out_k.shape, strided_k.shape)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    def test_rope_forward_oot_with_offsets(self, mock_get_ascend_config):
        """Passing ``offsets`` must raise ``NotImplementedError``."""
        self._set_torchair_graph(mock_get_ascend_config, False)

        # Test that NotImplementedError is raised when offsets is provided
        offsets = torch.tensor([1, 2, 3])
        with self.assertRaises(NotImplementedError):
            rope_forward_oot(self.mock_self, self.positions, self.query,
                             self.key, offsets)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.get_ascend_config')
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.custom_rotary_embedding_enabled',
        return_value=False)
    @patch('torch_npu._npu_rotary_embedding')
    def test_rope_forward_oot_neox_style_override(self, mock_npu_rotary,
                                                  mock_custom_enabled,
                                                  mock_get_ascend_config):
        """``is_neox_style_override=False`` must reach the NPU op as its last arg."""
        self._set_torchair_graph(mock_get_ascend_config, False)

        # Test neox_style override
        out_q, out_k = rope_forward_oot(self.mock_self,
                                        self.positions,
                                        self.query,
                                        self.key,
                                        is_neox_style_override=False)

        # Check that neox_style=False was passed to the NPU function
        positional_args = mock_npu_rotary.call_args[0]
        self.assertFalse(positional_args[-1])
|
||||
|
||||
|
||||
class MockRopeModule:
    """Minimal attribute holder passed as the module to the rope helpers."""

    def __init__(self, max_seq_len=2048, is_neox_style=True):
        # Caller-controlled settings.
        self.max_seq_len, self.is_neox_style = max_seq_len, is_neox_style
        # Caches start out unset.
        self.cos_cached = self.sin_cached = None
        # Fixed placeholder values.
        self.rotary_dim = 1
        self.base = 1
|
||||
|
||||
|
||||
class TestNativeRopeDeepseekForward(TestBase):
    """Shape checks for ``native_rope_deepseek_forward`` with a stubbed rope op."""

    @staticmethod
    def _forward(mock_rope_forward_oot, module, query, key, **extra):
        """Stub ``rope_forward_oot`` to echo *query*/*key* and run the forward."""
        positions = torch.tensor([1, 2, 3])
        mock_rope_forward_oot.return_value = (query, key)
        return native_rope_deepseek_forward(module, positions, query, key,
                                            **extra)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_base(self, mock_rope_forward_oot):
        """Default module: outputs keep the query/key shapes."""
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 8, 128)

        q_pe, k_pe = self._forward(mock_rope_forward_oot, MockRopeModule(),
                                   query, key)

        assert q_pe.shape == query.shape
        assert k_pe.shape == key.shape

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding._set_cos_sin_cache'
    )
    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_cache_handling(
            self, mock_rope_forward_oot, mock_set_cache):
        """Requested ``max_seq_len`` (2048) larger than the module's (1024)."""
        # Test cache situation is true
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 8, 128)

        q_pe, k_pe = self._forward(mock_rope_forward_oot,
                                   MockRopeModule(max_seq_len=1024),
                                   query,
                                   key,
                                   max_seq_len=2048)

        assert q_pe.shape == query.shape
        assert k_pe.shape == key.shape

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_key_reshaping(
            self, mock_rope_forward_oot):
        """A 2-D key must come back with shape ``(1, 128)``."""
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 128)

        q_pe, k_pe = self._forward(mock_rope_forward_oot, MockRopeModule(),
                                   query, key)

        assert q_pe.shape == query.shape
        assert k_pe.shape == (1, 128)

    @patch(
        'vllm_ascend.torchair.ops.torchair_rotary_embedding.rope_forward_oot')
    def test_native_rope_deepseek_forward_non_neox_style(
            self, mock_rope_forward_oot):
        """Module with ``is_neox_style=False``: outputs keep the input shapes."""
        query = torch.randn(1, 8, 128)
        key = torch.randn(1, 8, 128)

        q_pe, k_pe = self._forward(mock_rope_forward_oot,
                                   MockRopeModule(is_neox_style=False), query,
                                   key)

        assert q_pe.shape == query.shape
        assert k_pe.shape == key.shape
|
||||
|
||||
|
||||
class TestRotateHalf(TestBase):
    """Tests for ``rotate_half``."""

    def test_rotate_half_even_dim(self):
        """[1, 2, 3, 4] must become [-3, -4, 1, 2]."""
        # Test with even dimension
        vec = torch.tensor([1.0, 2.0, 3.0, 4.0])
        want = torch.tensor([-3.0, -4.0, 1.0, 2.0])
        self.assertTrue(torch.allclose(rotate_half(vec), want))
|
||||
|
||||
|
||||
class TestYarnFindCorrectionDim(TestBase):
    """Tests for ``yarn_find_correction_dim``."""

    def test_basic_case(self):
        """Compare against dim * ln(max_pos / (2*pi*rot)) / (2 * ln(base))."""
        # Test with standard values
        num_rotations = 100
        dim = 512
        base = 10000
        max_position_embeddings = 2048

        result = yarn_find_correction_dim(num_rotations, dim, base,
                                          max_position_embeddings)

        # Calculate expected value manually
        ratio = torch.tensor(max_position_embeddings) / (num_rotations * 2 *
                                                         torch.pi)
        numerator = dim * torch.log(ratio)
        denominator = 2 * torch.log(torch.tensor(base))
        expected = numerator / denominator

        self.assertTrue(torch.allclose(result, expected))
|
||||
|
||||
|
||||
class TestYarnGetMscale(TestBase):
    """Tests for ``yarn_get_mscale``."""

    def test_scale_less_than_or_equal_1(self):
        """Scales at or below 1 must yield exactly 1.0."""
        for small_scale in (0.5, 1.0, 0.999):
            self.assertEqual(yarn_get_mscale(scale=small_scale), 1.0)

    def test_scale_greater_than_1(self):
        """Scales above 1: compare against hand-computed expected values."""
        # Each entry is (scale, mscale, expected result).
        table = [
            (2.0, 1.0, 1.0 + 0.1 * math.log(2.0)),
            (10.0, 1.0, 1.0 + 0.1 * math.log(10.0)),
            (5.0, 2.0, 1.0 + 0.2 * math.log(5.0)),
            (math.e, 1.0, 1.0 + 0.1),
        ]

        for scale, mscale, expected in table:
            actual = yarn_get_mscale(scale, mscale)
            self.assertAlmostEqual(
                actual,
                expected,
                places=6,
                msg=f"Failed for scale={scale}, mscale={mscale}")
|
||||
176
tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
Normal file
176
tests/ut/torchair/quantization/test_torchair_w4a8_dynamic.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import copy
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w4a8_dynamic import (
|
||||
TorchairAscendW4A8DynamicFusedMoEMethod,
|
||||
TorchairAscendW4A8DynamicLinearMethod)
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicLinearMethod(TestBase):
    """Shape and dtype checks for ``TorchairAscendW4A8DynamicLinearMethod``."""

    def setUp(self):
        linear_method = TorchairAscendW4A8DynamicLinearMethod()
        linear_method.group_size = 8
        self.method = linear_method

    def test_get_weight(self):
        tensors = self.method.get_weight(8, 32, torch.bfloat16)
        # The weight is stored as int8 with shape (32, 8) for these arguments.
        self.assertEqual(tensors["weight"].dtype, torch.int8)
        self.assertEqual(tensors["weight"].shape, (32, 8))

    def test_get_pergroup_param(self):
        params = self.method.get_pergroup_param(8, 32, torch.bfloat16)
        # All four per-group tensors share the same dtype and shape here.
        param_names = (
            "weight_scale",
            "weight_offset",
            "weight_scale_second",
            "weight_offset_second",
        )
        for key in param_names:
            self.assertEqual(params[key].dtype, torch.bfloat16)
            self.assertEqual(params[key].shape, (32, 1))
|
||||
|
||||
|
||||
class TestAscendW4A8DynamicFusedMoEMethod(TestBase):
    """Shape and dtype checks for ``TorchairAscendW4A8DynamicFusedMoEMethod``."""

    # Small dimensions shared by every test in this class.
    experts = 8
    input_size = 16
    output_size = 56
    group_size = 2

    @patch(
        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_current_vllm_config'
    )
    @patch(
        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_ep_group')
    @patch("vllm_ascend.ascend_config.get_ascend_config")
    @patch(
        'vllm_ascend.torchair.quantization.torchair_w4a8_dynamic.get_mc2_group'
    )
    @patch('torch.distributed.get_rank', return_value=0)
    def setUp(self, mock_get_rank, mock_get_mc2_group, mock_get_ascend_config,
              mock_get_ep_group, get_current_vllm_config):
        # Mock arguments are injected bottom-up relative to the decorators.
        mock_get_ascend_config.return_value = Mock(
            torchair_graph_config=Mock(enabled=False))

        quant_description = {
            "group_size": self.group_size,
            "version": "0.0.0"
        }
        get_current_vllm_config.return_value = Mock(
            quant_config=Mock(quant_description=quant_description),
            parallel_config=Mock(enable_expert_parallel=True),
        )

        self.quant_method = TorchairAscendW4A8DynamicFusedMoEMethod()

    def test_get_weight(self):
        dims = (self.experts, self.input_size, self.output_size)

        # Old quant version: w13 is expected with 2 * input_size rows.
        w13 = self.quant_method.get_weight(*dims, torch.bfloat16)["w13_weight"]
        self.assertEqual(w13.dtype, torch.int8)
        self.assertEqual(
            w13.shape, (self.experts, 2 * self.input_size, self.output_size))

        # New quant version: w13 is expected with input_size rows.
        self.quant_method.new_quant_version = True
        w13 = self.quant_method.get_weight(*dims, torch.bfloat16)["w13_weight"]
        self.assertEqual(w13.dtype, torch.int8)
        self.assertEqual(w13.shape, dims)

    def test_get_dynamic_quant_param(self):
        call_args = (self.experts, self.input_size, self.output_size,
                     torch.bfloat16)

        # Old quant version: four bfloat16 scale tensors with these shapes.
        param_dict = self.quant_method.get_dynamic_quant_param(*call_args)
        expected_shapes = (
            ("w13_weight_scale", (self.experts, 2 * self.input_size, 1)),
            ("w13_weight_scale_second",
             (self.experts, 2 * self.input_size,
              self.output_size // self.group_size)),
            ("w2_weight_scale", (self.experts, self.output_size, 1)),
            ("w2_weight_scale_second",
             (self.experts, self.output_size,
              self.input_size // self.group_size)),
        )
        for key, shape in expected_shapes:
            self.assertEqual(param_dict[key].dtype, torch.bfloat16)
            self.assertEqual(param_dict[key].shape, shape)

        # New quant version: a float32 w2_scale_bias is also returned.
        self.quant_method.new_quant_version = True
        param_dict = self.quant_method.get_dynamic_quant_param(*call_args)
        bias = param_dict["w2_scale_bias"]
        self.assertEqual(bias.dtype, torch.float32)
        self.assertEqual(
            bias.shape,
            (self.experts, self.output_size, 16 // self.quant_method.tp_size))

    @patch('torch_npu.npu_quantize')
    @patch('torch.Tensor.npu')
    def test_process_weights_after_loading(self, mock_npu, mock_npu_quantize):

        def frozen(tensor):
            # Wrap a tensor as a parameter that does not require gradients.
            return torch.nn.Parameter(tensor, requires_grad=False)

        n_exp = self.experts
        n_in = self.input_size
        n_out = self.output_size
        groups = self.group_size

        # Old quant version layer.
        layer = torch.nn.Module()
        layer.w13_weight = frozen(
            torch.zeros((n_exp, 2 * n_in, n_out), dtype=torch.int8))
        layer.w2_weight = frozen(
            torch.zeros((n_exp, n_out, n_in), dtype=torch.int8))
        layer.w13_weight_scale = frozen(
            torch.ones((n_exp, 2 * n_in, 1), dtype=torch.bfloat16))
        layer.w13_weight_scale_second = frozen(
            torch.ones((n_exp, 2 * n_in, n_out // groups),
                       dtype=torch.bfloat16))
        layer.w2_weight_scale = frozen(
            torch.ones((n_exp, n_out, 1), dtype=torch.bfloat16))
        layer.w2_weight_scale_second = frozen(
            torch.ones((n_exp, n_out, n_in // groups), dtype=torch.bfloat16))
        # Copy the layer before processing so the new-version half of the
        # test starts from unprocessed parameters.
        new_layer = copy.deepcopy(layer)

        mock_npu.return_value = torch.Tensor()
        mock_npu_quantize.return_value = torch.Tensor()
        self.quant_method.process_weights_after_loading(layer)

        self.assertTrue(hasattr(layer, "w13_scale_bias"))
        self.assertEqual(layer.w13_scale_bias.data.shape,
                         (n_exp, 2 * n_in))
        self.assertEqual(layer.w13_scale_bias.data.dtype, torch.float32)
        self.assertTrue(hasattr(layer, "w2_scale_bias"))
        self.assertEqual(layer.w2_scale_bias.data.shape, (n_exp, n_out))
        self.assertEqual(layer.w2_scale_bias.data.dtype, torch.float32)

        # New quant version layer: replace the weights and provide the scale
        # biases up front.
        self.quant_method.new_quant_version = True
        new_layer.w13_weight.data = torch.zeros((n_exp, n_in, n_out),
                                                dtype=torch.int8)
        new_layer.w2_weight.data = torch.zeros((n_exp, n_out // 2, n_in),
                                               dtype=torch.int8)
        new_layer.w13_scale_bias = frozen(
            torch.zeros((n_exp, 2 * n_in, 1), dtype=torch.float32))
        new_layer.w2_scale_bias = frozen(
            torch.zeros((n_exp, n_out, 16 // self.quant_method.tp_size),
                        dtype=torch.float32))

        self.quant_method.process_weights_after_loading(new_layer)

        self.assertEqual(new_layer.w13_scale_bias.data.shape,
                         (n_exp, 2 * n_in))
        self.assertEqual(new_layer.w2_scale_bias.data.shape, (n_exp, n_out))
|
||||
75
tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py
Normal file
75
tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
torchair_fused_experts_with_all2all
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
    """Smoke test for ``torchair_fused_experts_with_all2all`` with mocked ops."""

    def setUp(self):
        self.hidden_size = 128
        self.num_tokens = 128
        # One bfloat16 tensor reused for every floating-point input/output.
        self.placeholder = torch.randn(self.num_tokens,
                                       self.hidden_size,
                                       dtype=torch.bfloat16)

    @patch("torch.distributed.all_to_all_single")
    @patch("torch_npu.npu_moe_re_routing")
    @patch("torch_npu.npu_grouped_matmul")
    @patch("torch_npu.npu_swiglu")
    @patch("torch_npu.npu_dynamic_quant")
    @patch("torch_npu.npu_moe_finalize_routing")
    @patch("torch_npu.npu_moe_init_routing")
    def test_torchair_fused_experts_with_all2all(
            self, mock_moe_init_routing, mock_moe_finalize_routing,
            mock_dynamic_quant, mock_swiglu, mock_grouped_matmul,
            mock_moe_re_routing, mock_all_to_all_single):

        def copy_through(output, input, *args, **kwargs):
            # Stand-in for the collective: copy the input into the output.
            return output.copy_(input)

        bf16 = self.placeholder
        int8_tokens = torch.randint(0,
                                    100, (self.num_tokens, self.hidden_size),
                                    dtype=torch.int8)
        int32_ones = torch.ones(self.num_tokens, dtype=torch.int32)

        mock_all_to_all_single.side_effect = copy_through
        mock_moe_init_routing.return_value = (int8_tokens, int32_ones,
                                              int32_ones)
        mock_moe_re_routing.return_value = (
            int8_tokens,
            bf16,
            torch.randint(0, 100, (self.num_tokens, ), dtype=torch.int32),
            bf16,
        )
        mock_grouped_matmul.return_value = bf16
        mock_swiglu.return_value = bf16
        mock_dynamic_quant.return_value = (int8_tokens,
                                           torch.randn(self.num_tokens))
        mock_moe_finalize_routing.return_value = bf16

        result = torchair_fused_experts_with_all2all(
            hidden_states=bf16,
            w1=bf16,
            w1_scale=bf16,
            w2=bf16,
            w2_scale=bf16,
            topk_weights=bf16,
            topk_ids=bf16,
            top_k=8,
            expert_map=MagicMock(),
            ep_group=MagicMock(),
            log2phy=None,
            global_redundant_expert_num=256,
        )

        # Only existence, dtype and shape of the result are checked.
        self.assertIsNotNone(result)
        self.assertEqual(result.dtype, torch.bfloat16)
        self.assertEqual(result.shape, (128, 128))
|
||||
817
tests/ut/torchair/test_torchair_mla.py
Normal file
817
tests/ut/torchair/test_torchair_mla.py
Normal file
@@ -0,0 +1,817 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.torchair.torchair_mla import (
|
||||
AscendMLATorchairBackend, AscendMLATorchairDecodeMetadata,
|
||||
AscendMLATorchairImpl, AscendMLATorchairMetadata,
|
||||
AscendMLATorchairMetadataBuilder, AscendMLATorchairPrefillMetadata)
|
||||
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
|
||||
|
||||
|
||||
class TestAscendMLATorchairBackend(TestBase):
    """Checks the class-level accessors of ``AscendMLATorchairBackend``."""

    def test_get_name(self):
        backend_name = AscendMLATorchairBackend.get_name()
        self.assertEqual(backend_name, "ASCEND_MLA_TORCHAIR")

    def test_get_metadata_cls(self):
        metadata_cls = AscendMLATorchairBackend.get_metadata_cls()
        self.assertEqual(metadata_cls, AscendMLATorchairMetadata)

    def test_get_builder_cls(self):
        builder_cls = AscendMLATorchairBackend.get_builder_cls()
        self.assertEqual(builder_cls, AscendMLATorchairMetadataBuilder)

    def test_get_kv_cache_shape(self):
        # The four arguments are expected back unchanged as a tuple.
        shape_args = (2, 4, 8, 128)
        self.assertEqual(
            AscendMLATorchairBackend.get_kv_cache_shape(*shape_args),
            shape_args)

    def test_get_impl_cls(self):
        impl_cls = AscendMLATorchairBackend.get_impl_cls()
        self.assertEqual(impl_cls, AscendMLATorchairImpl)
|
||||
|
||||
|
||||
class TestAscendMLATorchairPrefillMetadata(TestBase):
    """Checks that ``AscendMLATorchairPrefillMetadata`` stores its arguments."""

    def _check_fields(self, holder, fields):
        # Tensors must be the very same object; plain values compare equal.
        for name, value in fields.items():
            stored = getattr(holder, name)
            if isinstance(value, torch.Tensor):
                self.assertIs(stored, value)
            else:
                self.assertEqual(stored, value)

    def test_ascend_mla_prefill_metadata_default(self):
        fields = {
            "attn_mask": torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
            "query_lens": [1, 2],
            "seq_lens": [2, 2],
            "context_lens": torch.tensor([1, 2]),
            "input_positions": torch.tensor([0, 1, 0, 1]),
            "query_start_loc": torch.tensor([0, 1, 3]),
            "block_table": torch.tensor([[0, 1], [2, 3]]),
            "max_query_len": 2,
            "max_seq_lens": 2,
        }

        metadata = AscendMLATorchairPrefillMetadata(**fields)

        self._check_fields(metadata, fields)
        # No chunked context was supplied, so the attribute stays None.
        self.assertIsNone(metadata.chunked_context)

    def test_ascend_mla_prefill_metadata_with_chunked_context(self):
        chunk_fields = {
            "cu_seq_lens": torch.tensor([0, 2, 4]),
            "starts": torch.tensor([0, 2]),
            "seq_tot": [2, 2],
            "max_seq_lens": [2, 2],
            "workspace": torch.randn(2, 4),
            "chunk_seq_lens": torch.tensor([2, 2]),
        }
        chunked_context = (
            AscendMLATorchairPrefillMetadata.TorchairChunkedContextMetadata(
                **chunk_fields))

        metadata = AscendMLATorchairPrefillMetadata(
            attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
            query_lens=[1, 2],
            seq_lens=[2, 2],
            context_lens=torch.tensor([1, 2]),
            input_positions=torch.tensor([0, 1, 0, 1]),
            query_start_loc=torch.tensor([0, 1, 3]),
            block_table=torch.tensor([[0, 1], [2, 3]]),
            max_query_len=2,
            max_seq_lens=2,
            chunked_context=chunked_context)

        self.assertIsNotNone(metadata.chunked_context)
        self._check_fields(metadata.chunked_context, chunk_fields)
|
||||
|
||||
|
||||
class TestAscendMLATorchairDecodeMetadata(TestBase):
    """Checks that ``AscendMLATorchairDecodeMetadata`` stores its arguments."""

    def test_ascend_mla_decode_metadata_default(self):
        input_positions = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
        block_table = torch.tensor([[0, 3, 2, 1], [0, 2, 1, 3]])
        seq_lens = torch.tensor([[2], [3]])
        max_seq_lens = 4
        seq_lens_list = [2, 3]
        attn_mask = None

        metadata = AscendMLATorchairDecodeMetadata(input_positions,
                                                   block_table, seq_lens,
                                                   max_seq_lens, seq_lens_list,
                                                   attn_mask)

        # Tensors are checked by identity, plain values by equality.
        self.assertIs(metadata.input_positions, input_positions)
        self.assertIs(metadata.block_table, block_table)
        self.assertIs(metadata.seq_lens, seq_lens)
        self.assertEqual(metadata.max_seq_lens, max_seq_lens)
        self.assertEqual(metadata.seq_lens_list, seq_lens_list)
        # Check the attribute on the built object: asserting on the local
        # ``attn_mask`` variable would pass regardless of what was stored.
        self.assertIsNone(metadata.attn_mask)
|
||||
|
||||
|
||||
class TestAscendMLATorchairMetadata(TestBase):
    """Checks that ``AscendMLATorchairMetadata`` stores its arguments."""

    def test_ascend_mla_metadata_default(self):
        # Insertion order matters: the values are passed positionally below.
        fields = {
            "num_actual_tokens": 100,
            "slot_mapping": torch.randn(100, 4, 1024),
            "query_start_loc": torch.tensor([1, 2, 3, 4]),
            "seq_lens": [30, 50],
            "block_tables": torch.randint(0, 100, (100, 4)),
            "num_decodes": 4,
            "num_decode_tokens": 8,
            "num_prefills": 8,
            "num_input_tokens": 2,
            "query_lens": None,
            "head_dim": None,
            "attn_mask": None,
            "attn_state": AscendAttentionState.ChunkedPrefill,
            "decode": None,
            "prefill": None,
        }

        metadata = AscendMLATorchairMetadata(*fields.values())

        # Tensors must be the very same object; everything else compares
        # equal to what was passed in.
        for name, value in fields.items():
            stored = getattr(metadata, name)
            if isinstance(value, torch.Tensor):
                self.assertIs(stored, value)
            else:
                self.assertEqual(stored, value)
|
||||
|
||||
|
||||
class TestAscendMLATorchairMetadataBuilder(TestBase):
    """Unit tests for ``AscendMLATorchairMetadataBuilder``.

    Every test builds the builder from a ``MagicMock`` vLLM config and a
    patched ``get_ascend_config`` so no real device or config is needed.
    """

    def test_ascend_mla_metadata_builder_default(self):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.model_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_device = 'cpu'

        ascend_config = MagicMock()
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = True
        with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
                                                       mock_device)

        # The builder mirrors these values from the supplied configs.
        self.assertEqual(builder.block_size,
                         mock_vllm_config.cache_config.block_size)
        self.assertEqual(
            builder.chunked_prefill_enabled,
            mock_vllm_config.scheduler_config.chunked_prefill_enabled)
        self.assertEqual(builder.torchair_graph_enabled, True)

    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def test_reorder_batch_with_torchair_graph(self, ascend_config):
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_device = 'cpu'
        # ``ascend_config`` is the patched ``get_ascend_config`` callable, so
        # the flag has to be set on its return value. Setting it on the
        # callable itself would leave the builder with an auto-created
        # MagicMock that is merely truthy.
        ascend_config.return_value.torchair_graph_config.enabled = True

        builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
                                                   mock_device)

        input_batch = MagicMock()
        input_batch.req_ids = [0, 1, 2, 3]

        scheduler_output = MagicMock()
        scheduler_output.num_scheduled_tokens = {0: 2, 1: 1, 2: 3, 3: 1}
        scheduler_output.scheduled_spec_decode_tokens = {
            0: [1],
            1: [],
            2: [1, 1],
            3: []
        }

        input_batch.swap_states = MagicMock()

        modified = builder.reorder_batch(input_batch, scheduler_output)

        # No reordering is expected for this batch in graph mode.
        self.assertFalse(modified)
        input_batch.swap_states.assert_not_called()

    def test_reorder_batch_without_torchair_graph(self):
        ascend_config = MagicMock()
        ascend_config.torchair_graph_config = MagicMock()
        ascend_config.torchair_graph_config.enabled = False

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.max_num_seqs = 4
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_device = 'cpu'

        with patch("vllm_ascend.torchair.torchair_mla.get_ascend_config",
                   return_value=ascend_config):
            builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
                                                       mock_device)

        input_batch = MagicMock()
        input_batch.req_ids = [0, 1, 2, 3]

        scheduler_output = MagicMock()
        scheduler_output.num_scheduled_tokens = {0: 1, 1: 3, 2: 1, 3: 2}
        scheduler_output.scheduled_spec_decode_tokens = {
            0: [],
            1: [1],
            2: [],
            3: []
        }

        input_batch.swap_states = MagicMock()

        modified = builder.reorder_batch(input_batch, scheduler_output)

        # Exactly one swap, between batch positions 1 and 2, is expected.
        self.assertTrue(modified)
        input_batch.swap_states.assert_called_once_with(1, 2)

    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def test_get_graph_runner_block_tables_normal(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_device = 'cpu'

        builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
                                                   mock_device)
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)
        # The 10 input columns are expected to be padded out to 64
        # (presumably max_model_len // block_size -- TODO confirm).
        self.assertEqual(result.shape[0], 3)
        self.assertEqual(result.shape[1], 64)
        self.assertTrue(torch.equal(result[:, :10], block_tables))

    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def test_get_graph_runner_block_tables_truncated(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 64
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_device = 'cpu'

        builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
                                                   mock_device)
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)
        # With the smaller max_model_len only the first 4 columns survive.
        self.assertEqual(result.shape[0], 3)
        self.assertEqual(result.shape[1], 4)
        self.assertTrue(torch.equal(result, block_tables[:, :4]))

    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def test_get_graph_runner_block_tables_from_numpy(self,
                                                      mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False
        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_device = 'cpu'

        builder = AscendMLATorchairMetadataBuilder(mock_vllm_config,
                                                   mock_device)

        # NOTE(review): despite the test name, the input is a torch tensor,
        # so this currently exercises the same path as the "normal" test.
        block_tables = torch.randint(0, 100, (3, 10), dtype=torch.int32)

        result = builder._get_graph_runner_block_tables(3, block_tables)

        self.assertEqual(result.shape[0], 3)
        self.assertEqual(result.shape[1], 64)
        self.assertTrue(torch.equal(result[:, :10], block_tables))

    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def test_build_dummy(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_device = 'cpu'

        builder = AscendMLATorchairMetadataBuilder(
            mock_vllm_config,
            mock_device,
            metadata_cls=AscendMLATorchairMetadata)
        # Set explicitly so the sin/cos tensors get a concrete last dimension.
        builder.rope_dim = 64

        # Replace block-table padding with a pass-through for this build.
        with patch.object(builder,
                          "_get_graph_runner_block_tables",
                          side_effect=lambda x, y: y):
            common_attn_metadata = TorchairCommonAttentionMetadata(
                num_reqs=3,
                num_actual_tokens=3,
                decode_token_per_req=1,
                actual_seq_lengths_q=[0, 1, 2],
                attn_mask=torch.zeros((1, 1), dtype=torch.bool),
                spec_attn_mask=torch.zeros((1, 1), dtype=torch.bool),
            )
            metadata = builder.build_torchair_graph_dummy(common_attn_metadata)

        sin_golden = torch.ones(3,
                                1,
                                1,
                                64,
                                dtype=torch.float16,
                                device=mock_device)
        cos_golden = torch.ones(3,
                                1,
                                1,
                                64,
                                dtype=torch.float16,
                                device=mock_device)

        self.assertIsInstance(metadata, AscendMLATorchairMetadata)
        self.assertEqual(metadata.num_input_tokens, 3)
        self.assertEqual(metadata.num_actual_tokens, 3)
        self.assertEqual(metadata.num_decodes, 1)
        self.assertEqual(metadata.num_decode_tokens, 1)
        self.assertEqual(metadata.num_prefills, 0)
        self.assertEqual(metadata.attn_state, AscendAttentionState.DecodeOnly)
        self.assertIsNone(metadata.prefill)
        self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
        self.assertEqual(metadata.block_tables.shape[0], 3)
        self.assertEqual(metadata.block_tables.shape[1], 64)
        self.assertEqual(metadata.seq_lens.shape[0], 3)
        self.assertEqual(metadata.slot_mapping.shape[0], 3)
        self.assertEqual(metadata.query_start_loc.shape[0], 3)
        # unittest assertions instead of bare ``assert`` so the checks are
        # not stripped under ``python -O`` and match the rest of the class.
        self.assertTrue(torch.equal(sin_golden, metadata.decode.sin))
        self.assertTrue(torch.equal(cos_golden, metadata.decode.cos))

    @patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
    def test_build_decode(self, mock_ascend_config):
        ascend_config = MagicMock()
        mock_ascend_config.return_value = ascend_config
        ascend_config.torchair_graph_config.enabled = False

        mock_vllm_config = MagicMock()
        mock_vllm_config.model_config.max_model_len = 1024
        mock_vllm_config.cache_config.block_size = 16
        mock_vllm_config.scheduler_config.chunked_prefill_enabled = False
        mock_vllm_config.get_head_size.return_value = 64
        mock_vllm_config.model_config.dtype = torch.float16
        mock_device = 'cpu'
        model = MagicMock(spec=nn.Module)
        model.model = MagicMock(spec=nn.Module)

        builder = AscendMLATorchairMetadataBuilder(
            mock_vllm_config,
            mock_device,
            metadata_cls=AscendMLATorchairMetadata)
        builder.rope_dim = 64

        # Pre-populated caches are set directly on the builder for this test.
        builder.sin_cache = torch.tensor([10, 10])
        builder.cos_cache = torch.tensor([10, 10])

        # Replace block-table padding with a pass-through for this build.
        with patch.object(builder,
                          "_get_graph_runner_block_tables",
                          side_effect=lambda x, y: y):
            common_attn_metadata = AscendCommonAttentionMetadata(
                query_start_loc=torch.tensor([0, 1, 2, 3]),
                query_start_loc_cpu=torch.tensor([0, 1, 2, 3]),
                seq_lens_cpu=torch.tensor([1, 1, 1]),
                num_reqs=3,
                num_actual_tokens=3,
                max_query_len=1,
                decode_token_per_req=torch.tensor([1, 1, 1]),
                block_table_tensor=torch.zeros((10, 10)),
                slot_mapping_cpu=torch.tensor(range(20)),
                actual_seq_lengths_q=torch.tensor([0, 1, 2]),
                positions=torch.tensor([1, 1]),
                attn_mask=torch.ones((15, 15)),
                spec_attn_mask=None,
                attn_state=AscendAttentionState.ChunkedPrefill)

            metadata = builder.build(common_attn_metadata, model)

        self.assertIsInstance(metadata, AscendMLATorchairMetadata)
        self.assertEqual(metadata.num_input_tokens, 0)
        self.assertEqual(metadata.num_actual_tokens, 3)
        self.assertEqual(metadata.num_decodes, 3)
        self.assertEqual(metadata.num_decode_tokens, 3)
        self.assertEqual(metadata.num_prefills, 0)
        self.assertEqual(metadata.attn_state,
                         AscendAttentionState.ChunkedPrefill)
        self.assertIsNone(metadata.prefill)
        self.assertIsInstance(metadata.decode, AscendMLATorchairDecodeMetadata)
        self.assertEqual(metadata.block_tables.shape[0], 3)
        self.assertEqual(metadata.block_tables.shape[1], 10)
        self.assertEqual(metadata.seq_lens.shape[0], 3)
        self.assertEqual(metadata.slot_mapping.shape[0], 3)
        self.assertEqual(metadata.query_start_loc.shape[0], 4)
|
||||
|
||||
|
||||
class TestAscendMLATorchairImpl(TestBase):
|
||||
|
||||
@patch('vllm.distributed.parallel_state._TP',
       new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("vllm.distributed.get_tensor_model_parallel_world_size",
       return_value=2)
@patch("vllm.config.get_current_vllm_config")
@patch("vllm_ascend.torchair.torchair_mla.get_ascend_config")
def setUp(self, ascend_config, vllm_config, mock_get_tp_size, mock_tp):
    # Build one AscendMLATorchairImpl with mocked projections for all tests.
    # NOTE(review): ``ascend_config`` and ``vllm_config`` are the patched
    # callables themselves; the attributes below are set on them rather
    # than on their ``return_value`` -- confirm this is intended.
    mock_tp.world_size = 2
    ascend_config.torchair_graph_config.enabled = True
    ascend_config.torchair_graph_config.enable_kv_nz = False
    vllm_config.speculative_config = MagicMock(num_speculative_tokens=4)

    kv_a_layernorm = MagicMock()
    kv_a_layernorm.weight = torch.randn(96)
    kv_a_layernorm.variance_epsilon = 1e-6

    # MLA-specific constructor arguments, forwarded as keyword arguments.
    mla_kwargs = dict(
        q_lora_rank=64,
        kv_lora_rank=32,
        qk_nope_head_dim=64,
        qk_rope_head_dim=32,
        qk_head_dim=96,
        v_head_dim=128,
        rotary_emb=MagicMock(),
        q_proj=MagicMock(),
        kv_b_proj=MagicMock(),
        o_proj=MagicMock(),
        kv_a_proj_with_mqa=MagicMock(),
        kv_a_layernorm=kv_a_layernorm,
    )

    self.impl = AscendMLATorchairImpl(num_heads=256,
                                      head_size=1024,
                                      scale=0.1,
                                      num_kv_heads=8,
                                      alibi_slopes=None,
                                      sliding_window=None,
                                      kv_cache_dtype="auto",
                                      blocksparse_params=None,
                                      logits_soft_cap=None,
                                      attn_type=None,
                                      kv_sharing_target_layer_name=None,
                                      **mla_kwargs)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.impl.num_heads, 256)
|
||||
self.assertEqual(self.impl.head_size, 1024)
|
||||
self.assertEqual(self.impl.scale, 0.1)
|
||||
self.assertEqual(self.impl.num_kv_heads, 8)
|
||||
self.assertEqual(self.impl.kv_cache_dtype, "auto")
|
||||
self.assertEqual(self.impl.q_lora_rank, 64)
|
||||
self.assertEqual(self.impl.kv_lora_rank, 32)
|
||||
self.assertEqual(self.impl.qk_nope_head_dim, 64)
|
||||
self.assertEqual(self.impl.qk_rope_head_dim, 32)
|
||||
self.assertEqual(self.impl.qk_head_dim, 96)
|
||||
self.assertEqual(self.impl.v_head_dim, 128)
|
||||
self.assertIsNotNone(self.impl.rotary_emb)
|
||||
self.assertIsNotNone(self.impl.q_proj)
|
||||
self.assertIsNotNone(self.impl.kv_b_proj)
|
||||
self.assertIsNotNone(self.impl.o_proj)
|
||||
self.assertIsNotNone(self.impl.kv_a_proj_with_mqa)
|
||||
self.assertIsNotNone(self.impl.kv_a_layernorm)
|
||||
self.assertEqual(self.impl.num_queries_per_kv, 32)
|
||||
self.assertEqual(self.impl.tp_size, 2)
|
||||
self.assertTrue(self.impl.torchair_graph_enabled)
|
||||
|
||||
def test_v_up_proj_and_o_proj(self):
|
||||
batch_size = 4
|
||||
x = torch.randn(batch_size, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
|
||||
self.impl.o_proj.return_value = (torch.randn(
|
||||
batch_size, self.impl.num_heads * self.impl.v_head_dim), )
|
||||
if not hasattr(self.impl, 'W_UV') or self.impl.W_UV is None:
|
||||
self.impl.W_UV = torch.randn(self.impl.num_heads,
|
||||
self.impl.kv_lora_rank,
|
||||
self.impl.v_head_dim)
|
||||
result = self.impl._v_up_proj_and_o_proj(x)
|
||||
|
||||
self.assertEqual(result.shape[0], batch_size)
|
||||
self.assertEqual(result.shape[1],
|
||||
self.impl.num_heads * self.impl.v_head_dim)
|
||||
|
||||
def test_q_proj_and_k_up_proj(self):
|
||||
batch_size = 4
|
||||
x = torch.randn(batch_size, self.impl.num_heads, self.impl.qk_head_dim)
|
||||
q_proj_output = torch.randn(batch_size, self.impl.num_heads,
|
||||
self.impl.qk_head_dim)
|
||||
self.impl.q_proj.return_value = (q_proj_output, )
|
||||
if not hasattr(self.impl, 'W_UK_T') or self.impl.W_UK_T is None:
|
||||
self.impl.W_UK_T = torch.randn(self.impl.num_heads,
|
||||
self.impl.qk_nope_head_dim,
|
||||
self.impl.kv_lora_rank)
|
||||
result = self.impl._q_proj_and_k_up_proj(x)
|
||||
ql_nope, q_pe = result
|
||||
self.assertEqual(ql_nope.shape[0], batch_size)
|
||||
self.assertEqual(ql_nope.shape[1], self.impl.num_heads)
|
||||
self.assertEqual(ql_nope.shape[2], self.impl.kv_lora_rank)
|
||||
self.assertEqual(q_pe.shape[0], batch_size)
|
||||
self.assertEqual(q_pe.shape[1], self.impl.num_heads)
|
||||
self.assertEqual(q_pe.shape[2], self.impl.qk_rope_head_dim)
|
||||
|
||||
def test_process_weights_after_loading(self):
|
||||
layer = MagicMock(spec=LinearBase)
|
||||
layer.input_size_per_partition = 10
|
||||
quant_method = MagicMock()
|
||||
apply = MagicMock()
|
||||
quant_method.apply = apply
|
||||
layer.quant_method = quant_method
|
||||
shape_0 = self.impl.num_heads * (self.impl.qk_nope_head_dim +
|
||||
self.impl.v_head_dim)
|
||||
shape_1 = self.impl.kv_lora_rank
|
||||
layer.weight = torch.randn(shape_0, shape_1)
|
||||
self.impl.kv_b_proj = layer
|
||||
apply.return_value = layer.weight.T
|
||||
self.impl.process_weights_after_loading(torch.bfloat16)
|
||||
|
||||
self.assertEqual(self.impl.W_UK_T.shape[0], self.impl.num_heads)
|
||||
self.assertEqual(self.impl.W_UK_T.shape[1], self.impl.qk_nope_head_dim)
|
||||
self.assertEqual(self.impl.W_UK_T.shape[2], self.impl.kv_lora_rank)
|
||||
|
||||
self.assertEqual(self.impl.W_UV.shape[0], self.impl.num_heads)
|
||||
self.assertEqual(self.impl.W_UV.shape[1], self.impl.kv_lora_rank)
|
||||
self.assertEqual(self.impl.W_UV.shape[2], self.impl.v_head_dim)
|
||||
|
||||
def test_compute_prefill_context_none(self):
|
||||
batch_size = 4
|
||||
kv_cache = torch.randn(10, 1, 1, 192)
|
||||
query = torch.randn(batch_size, self.impl.num_heads,
|
||||
self.impl.qk_head_dim)
|
||||
metadata = MagicMock()
|
||||
metadata.prefill = None
|
||||
prefix_out = torch.randn(2, 16, 128)
|
||||
prefix_lse = torch.randn(2, 16, 8)
|
||||
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
|
||||
metadata, prefix_out,
|
||||
prefix_lse)
|
||||
|
||||
self.assertTrue(torch.equal(prefix_out, out))
|
||||
self.assertTrue(torch.equal(prefix_lse, lse))
|
||||
|
||||
@patch("torch_npu.atb.npu_paged_cache_load")
|
||||
@patch("torch_npu.atb.npu_ring_mla")
|
||||
def test_compute_prefill_context(self, mock_ring, mock_load):
|
||||
S, N, D, VD = 2, self.impl.num_heads, self.impl.qk_head_dim, self.impl.v_head_dim
|
||||
_, AND = self.impl.qk_rope_head_dim, self.impl.qk_nope_head_dim
|
||||
latent_kv_dim = self.impl.kv_lora_rank
|
||||
num_blocks, block_size = 100, 20
|
||||
query = torch.randn(S, N, D)
|
||||
kv_cache_0 = torch.randn(num_blocks, block_size, N, latent_kv_dim)
|
||||
kv_cache_1 = torch.randn(num_blocks, block_size, N, D)
|
||||
kv_cache = [kv_cache_0, kv_cache_1]
|
||||
prefix_out = torch.randn(S, N, 128)
|
||||
prefix_lse = torch.randn(S, N)
|
||||
|
||||
self.impl.kv_b_proj.return_value = (torch.randn(8, N, VD + AND), )
|
||||
|
||||
chunk_ctx = MagicMock()
|
||||
chunk_ctx.seq_tot = [8]
|
||||
chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
|
||||
chunk_ctx.starts = [torch.tensor([0])]
|
||||
|
||||
prefill_meta = MagicMock()
|
||||
prefill_meta.chunked_context = chunk_ctx
|
||||
prefill_meta.query_lens = [8]
|
||||
prefill_meta.block_table = torch.randint(0, 100, (S, 4))
|
||||
|
||||
meta = MagicMock()
|
||||
meta.prefill = prefill_meta
|
||||
|
||||
out, lse = self.impl._compute_prefill_context(query, kv_cache, 32,
|
||||
meta, prefix_out,
|
||||
prefix_lse)
|
||||
|
||||
mock_load.assert_called_once()
|
||||
mock_ring.assert_called_once()
|
||||
|
||||
self.assertEqual(out.shape, prefix_out.shape)
|
||||
self.assertEqual(lse.shape, prefix_lse.shape)
|
||||
|
||||
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
|
||||
def test_exec_kv(self, mock_kv_cache):
|
||||
batch_size = 2
|
||||
hidden = torch.randn(batch_size, 128)
|
||||
cos = torch.randn(batch_size, 32)
|
||||
sin = torch.randn(batch_size, 32)
|
||||
kv_cache = (torch.randn(
|
||||
4, 8, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
|
||||
torch.randn(
|
||||
4, 8,
|
||||
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
|
||||
slots = torch.arange(batch_size, dtype=torch.long)
|
||||
|
||||
proj_out = torch.randn(
|
||||
batch_size, self.impl.num_kv_heads, 1,
|
||||
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
|
||||
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
|
||||
|
||||
mock_kv_cache.return_value = (torch.randn(batch_size,
|
||||
self.impl.num_kv_heads, 1,
|
||||
self.impl.qk_rope_head_dim),
|
||||
torch.randn(batch_size,
|
||||
self.impl.num_kv_heads, 1,
|
||||
self.impl.kv_lora_rank),
|
||||
None, None)
|
||||
|
||||
k_pe, k_nope, kv = self.impl.exec_kv(hidden, cos, sin, kv_cache, slots)
|
||||
|
||||
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden)
|
||||
mock_kv_cache.assert_called_once()
|
||||
self.assertEqual(k_pe.shape, (batch_size, self.impl.num_kv_heads, 1,
|
||||
self.impl.qk_rope_head_dim))
|
||||
self.assertEqual(
|
||||
k_nope.shape,
|
||||
(batch_size, self.impl.num_kv_heads, 1, self.impl.kv_lora_rank))
|
||||
self.assertEqual(kv.shape,
|
||||
(batch_size, self.impl.num_kv_heads, 1,
|
||||
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim))
|
||||
|
||||
@patch("torch_npu.npu_kv_rmsnorm_rope_cache")
|
||||
def test_exec_kv_prefill(self, mock_kv):
|
||||
B, N, S, H = 2, self.impl.num_kv_heads, 1, 128
|
||||
hidden_states = torch.randn(B, N, S, H)
|
||||
cos = torch.randn(B, S, 32)
|
||||
sin = torch.randn(B, S, 32)
|
||||
kv_cache = (
|
||||
torch.randn(100, 8,
|
||||
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
|
||||
torch.randn(100, 8,
|
||||
self.impl.kv_lora_rank + self.impl.qk_rope_head_dim),
|
||||
)
|
||||
|
||||
slots = torch.arange(B * S, dtype=torch.long)
|
||||
|
||||
proj_out = torch.randn(
|
||||
B, N, S, self.impl.kv_lora_rank + self.impl.qk_rope_head_dim)
|
||||
self.impl.kv_a_proj_with_mqa.return_value = (proj_out, )
|
||||
|
||||
mock_kv.return_value = (None, None,
|
||||
torch.randn(B, self.impl.num_kv_heads, S,
|
||||
self.impl.qk_rope_head_dim),
|
||||
torch.randn(B, self.impl.num_kv_heads, S,
|
||||
self.impl.kv_lora_rank))
|
||||
|
||||
k_pe, k_nope = self.impl.exec_kv_prefill(hidden_states, cos, sin,
|
||||
kv_cache, slots)
|
||||
|
||||
self.impl.kv_a_proj_with_mqa.assert_called_once_with(hidden_states)
|
||||
mock_kv.assert_called_once()
|
||||
|
||||
self.assertEqual(
|
||||
k_pe.shape,
|
||||
(B, self.impl.num_kv_heads, S, self.impl.qk_rope_head_dim))
|
||||
self.assertEqual(
|
||||
k_nope.shape,
|
||||
(B, self.impl.num_kv_heads, S, self.impl.kv_lora_rank))
|
||||
|
||||
@patch("torch_npu.npu_interleave_rope")
|
||||
def test_rope_single(self, mock_rope):
|
||||
B, N, D = 2, 16, 1024
|
||||
x = torch.randn(B, N, D)
|
||||
cos = torch.randn(B, N, 1, D)
|
||||
sin = torch.randn(B, N, 1, D)
|
||||
mock_rope.return_value = x.view(B, N, 1, D)
|
||||
result = self.impl.rope_single(x, cos, sin)
|
||||
self.assertEqual(result.shape[0], B)
|
||||
self.assertEqual(result.shape[1], N)
|
||||
self.assertEqual(result.shape[2], D)
|
||||
mock_rope.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._v_up_proj_and_o_proj"
|
||||
)
|
||||
@patch("torch_npu._npu_paged_attention_mla")
|
||||
def test_forward_decode_without_graph(self, mock_page_attention_mla,
|
||||
mock_up_proj):
|
||||
self.impl.running_in_graph = False
|
||||
self.impl.running_chunkprefilll_with_torchair = False
|
||||
num_tokens = 100
|
||||
num_blocks = 256
|
||||
block_size = 4
|
||||
q_nope = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.qk_nope_head_dim)
|
||||
q_pe = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim)
|
||||
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
|
||||
self.impl.num_heads,
|
||||
self.impl.kv_lora_rank)
|
||||
metadata = MagicMock()
|
||||
metadata.decode = MagicMock()
|
||||
metadata.decode.block_table = MagicMock()
|
||||
metadata.decode.seq_lens = 10
|
||||
mock_page_attention_mla.return_value = torch.randn(
|
||||
num_tokens, self.impl.num_heads, self.impl.kv_lora_rank)
|
||||
mock_up_proj.return_value = torch.randn(num_tokens,
|
||||
self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
result = self.impl._forward_decode(q_nope, q_pe, None, None,
|
||||
kv_c_and_k_pe_cache, metadata)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
self.assertEqual(result.shape[1], self.impl.num_heads)
|
||||
self.assertEqual(result.shape[2], self.impl.v_head_dim)
|
||||
mock_up_proj.assert_called_once()
|
||||
mock_page_attention_mla.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairImpl._forward_prefill"
|
||||
)
|
||||
@patch("torch_npu._npu_reshape_and_cache")
|
||||
def test_forward_without_graph(self, _, mock_forward_prefill):
|
||||
self.impl.running_in_graph = False
|
||||
self.impl.torchair_graph_enabled = False
|
||||
|
||||
num_tokens = 100
|
||||
num_blocks = 256
|
||||
block_size = 4
|
||||
rotary_emb_return_value = (torch.randn(num_tokens, 16,
|
||||
self.impl.kv_lora_rank),
|
||||
torch.randn(0, 1, self.impl.kv_lora_rank))
|
||||
self.impl.rotary_emb.side_effect = lambda *args, **kwargs: rotary_emb_return_value
|
||||
self.impl.o_proj.side_effect = lambda *args, **kwargs: torch.randn(
|
||||
1, num_blocks, 128)
|
||||
|
||||
hidden_states_or_q_c = torch.randn(num_tokens, self.impl.q_lora_rank)
|
||||
hidden_states_or_kv_c_normed = torch.randn(num_tokens,
|
||||
self.impl.kv_lora_rank)
|
||||
k_pe = torch.randn(num_tokens, self.impl.qk_rope_head_dim)
|
||||
kv_cache = (torch.randn(num_blocks, block_size, self.impl.num_heads,
|
||||
self.impl.kv_lora_rank),
|
||||
torch.randn(num_blocks, block_size, self.impl.num_heads,
|
||||
self.impl.qk_rope_head_dim))
|
||||
output = torch.randn(num_tokens, self.impl.num_heads,
|
||||
self.impl.v_head_dim)
|
||||
|
||||
metadata = MagicMock()
|
||||
metadata.num_decodes = 0
|
||||
metadata.num_prefills = num_tokens
|
||||
mock_forward_prefill.return_value = torch.randn(
|
||||
0, self.impl.num_heads * self.impl.v_head_dim)
|
||||
result = self.impl.forward(None, hidden_states_or_q_c,
|
||||
hidden_states_or_kv_c_normed, k_pe,
|
||||
kv_cache, metadata, output, False)
|
||||
self.assertEqual(result.shape[0], num_tokens)
|
||||
149
tests/ut/torchair/test_utils.py
Normal file
149
tests/ut/torchair/test_utils.py
Normal file
@@ -0,0 +1,149 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.quantization.quantizer import SUPPORT_ASCEND_QUANTIZER_TYPE
|
||||
from vllm_ascend.torchair import utils
|
||||
|
||||
|
||||
class TestTorchairUtils(TestBase):
    """Unit tests for the helpers in ``vllm_ascend.torchair.utils``."""

    def test_get_torchair_current_work_dir(self):
        """The work dir is the cache dir, optionally joined with a sub-path."""
        cache_dir = utils.TORCHAIR_CACHE_DIR
        work_dir = utils._get_torchair_current_work_dir()
        self.assertEqual(cache_dir, work_dir)
        work_dir = utils._get_torchair_current_work_dir("test")
        self.assertEqual(os.path.join(cache_dir, "test"), work_dir)

    def test_torchair_cache_dir(self):
        """Write, read back, then delete the kv-cache-bytes cache for rank 0."""
        utils.write_kv_cache_bytes_to_file(0, 100)
        self.assertTrue(utils.check_torchair_cache_exist(),
                        "Create torchair cache dir failed")
        self.assertTrue(utils.check_kv_cache_bytes_cache_exist(),
                        "Create kv cache bytes cache dir failed")
        kv_cache_bytes = utils.read_kv_cache_bytes_from_file(0)
        self.assertEqual(100, kv_cache_bytes)
        utils.delete_torchair_cache_file()
        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    def test_torchair_cache_dir_multiple_ranks(self):
        """Concurrent writes for several ranks can each be read back."""
        ranks = [0, 1, 2, 3]
        values = [100, 200, 300, 400]

        with ThreadPoolExecutor() as executor:
            # ``Executor.map`` is lazy about errors: an exception raised in a
            # worker is only re-raised when its result is retrieved, so the
            # iterator is drained to make a failed write fail the test here.
            list(executor.map(utils.write_kv_cache_bytes_to_file, ranks,
                              values))
        for rank, expected in zip(ranks, values):
            self.assertEqual(expected,
                             utils.read_kv_cache_bytes_from_file(rank))
        utils.delete_torchair_cache_file()

        self.assertFalse(utils.check_torchair_cache_exist(),
                         "Delete torchair cache dir failed")
        self.assertFalse(utils.check_kv_cache_bytes_cache_exist(),
                         "Delete kv cache bytes cache dir failed")

    def test_delete_torchair_cache_file_multiple_times(self):
        """Deleting an already-deleted cache must not raise FileNotFoundError."""
        utils.write_kv_cache_bytes_to_file(0, 100)
        utils.delete_torchair_cache_file()
        for i in range(5):
            try:
                utils.delete_torchair_cache_file()
            except FileNotFoundError:
                # ``i + 2`` because one delete already ran before the loop.
                self.fail(
                    f"Unexpected FileNotFoundError on delete call #{i+2}")

    @patch('vllm.ModelRegistry')
    def test_register_torchair_model(self, mock_model_registry):
        """``register_torchair_model`` registers the six models in order."""
        mock_registry = MagicMock()
        mock_model_registry.return_value = mock_registry
        utils.register_torchair_model()

        self.assertEqual(mock_model_registry.register_model.call_count, 6)
        call_args_list = mock_model_registry.register_model.call_args_list

        expected_registrations = [
            ("DeepSeekMTPModel",
             "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
             ),
            ("DeepseekV2ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
             ),
            ("DeepseekV3ForCausalLM",
             "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
             ),
            ("Qwen2ForCausalLM",
             "vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM"),
            ("Qwen3MoeForCausalLM",
             "vllm_ascend.torchair.models.qwen3_moe:CustomQwen3MoeForCausalLM"
             ),
            ("PanguProMoEForCausalLM",
             "vllm_ascend.torchair.models.torchair_pangu_moe:PanguProMoEForCausalLM"
             ),
        ]

        # The call count was asserted above, so zip pairs every call with
        # exactly one expectation.
        for (args, _kwargs), (expected_name,
                              expected_path) in zip(call_args_list,
                                                    expected_registrations):
            self.assertEqual(args[0], expected_name)
            self.assertEqual(args[1], expected_path)

    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    def test_converting_weight_acl_format(self, mock_npu_cast,
                                          mock_get_format):
        """Weights whose reported format differs from the target are cast."""
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_get_format.return_value = 1
        mock_npu_cast.return_value = 1

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
        # ``data`` now holds the value returned by the mocked cast.
        self.assertEqual(fused_moe.w13_weight.data, 1)

    @mock.patch('torch_npu.get_npu_format')
    @mock.patch('torch_npu.npu_format_cast')
    @mock.patch('vllm.model_executor.layers.fused_moe.layer.FusedMoE',
                new=mock.MagicMock)
    def test_converting_weight_acl_format_format_true(self, mock_npu_cast,
                                                      mock_get_format):
        """No cast is issued when the reported format already matches."""
        ACL_FORMAT_FRACTAL_NZ = 29
        mock_get_format.return_value = ACL_FORMAT_FRACTAL_NZ
        mock_npu_cast.return_value = 1

        fused_moe = mock.MagicMock()
        fused_moe.w13_weight = mock.MagicMock()
        fused_moe.w2_weight = mock.MagicMock()
        fused_moe.w13_weight.data = torch.randn(128, 256)
        fused_moe.w2_weight.data = torch.randn(256, 128)
        model = mock.MagicMock()
        model.modules.return_value = [fused_moe]

        utils.converting_weight_acl_format(model, ACL_FORMAT_FRACTAL_NZ)
        mock_npu_cast.assert_not_called()

    def test_torchair_quant_method_register(self):
        """Registration replaces the two dynamic quantizer entries."""
        TorchairW8A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
            "W8A8_DYNAMIC"]
        TorchairW4A8DYNAMICQuantizer = SUPPORT_ASCEND_QUANTIZER_TYPE[
            "W4A8_DYNAMIC"]
        utils.torchair_quant_method_register()
        self.assertNotEqual(TorchairW8A8DYNAMICQuantizer,
                            SUPPORT_ASCEND_QUANTIZER_TYPE["W8A8_DYNAMIC"])
        self.assertNotEqual(TorchairW4A8DYNAMICQuantizer,
                            SUPPORT_ASCEND_QUANTIZER_TYPE["W4A8_DYNAMIC"])
|
||||
372
tests/ut/worker/test_input_batch.py
Normal file
372
tests/ut/worker/test_input_batch.py
Normal file
@@ -0,0 +1,372 @@
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
import inspect
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.v1.pool.metadata import PoolingMetadata
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
|
||||
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
VOCAB_SIZE = 1024
|
||||
NUM_OUTPUT_TOKENS = 20
|
||||
MAX_PROMPT_SIZE = 100
|
||||
MAX_NUM_PROMPT_TOKENS = 64
|
||||
|
||||
|
||||
def _compare_objs(obj1,
                  obj2,
                  skip: Sequence = ("logitsprocs", "batch_update_builder")):
    """Assert that two objects agree on every non-routine, non-dunder attribute.

    Attribute names are collected from ``obj1``. Tensors and NumPy arrays are
    compared with ``allclose`` (two empty tensors count as equal), block
    tables and sampling/pooling metadata are compared recursively, and
    anything else with ``==``.

    Args:
        obj1: Object whose attributes drive the comparison.
        obj2: Object expected to carry equal attributes.
        skip: Attribute names to ignore; also applied to nested comparisons.

    Raises:
        AssertionError: If any compared attribute differs.
    """
    attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
    attr_names = {
        name
        for name, _ in attrs
        if not (name.startswith('__') and name.endswith('__'))
    }
    for attr_name in attr_names:
        if attr_name in skip:
            continue

        a = getattr(obj1, attr_name)
        b = getattr(obj2, attr_name)

        is_same = False
        if isinstance(a, torch.Tensor):
            if a.numel() == 0 or b.numel() == 0:
                # allclose is not meaningful for empty tensors; both sides
                # must simply be empty.
                is_same = a.numel() == 0 and b.numel() == 0
            elif torch.allclose(a, b):
                is_same = True
        elif isinstance(a, np.ndarray):
            if np.allclose(a, b):
                is_same = True
        elif isinstance(a, MultiGroupBlockTable):
            # zip() stops at the shorter sequence, so a differing number of
            # groups has to be rejected explicitly.
            assert len(a.block_tables) == len(b.block_tables), (
                f"Attribute {attr_name} has a different number of block "
                f"tables in {obj1} and {obj2}")
            for a_i, b_i in zip(a.block_tables, b.block_tables):
                _compare_objs(a_i, b_i, skip)
            is_same = True
        elif isinstance(a, (BlockTable, SamplingMetadata, PoolingMetadata)):
            _compare_objs(a, b, skip)
            is_same = True  # if we make it here must be same
        elif a == b:
            is_same = True
        assert is_same, f"Attribute {attr_name} is different"\
            f" in {obj1} and {obj2}: {a} != {b}"
|
||||
|
||||
|
||||
def _remove_requests(input_batch: InputBatch, batch_size: int,
|
||||
reqs: list[CachedRequestState]) -> set[str]:
|
||||
"""
|
||||
Remove some requests randomly from the batch and returns
|
||||
set of request removed
|
||||
"""
|
||||
|
||||
num_reqs_to_remove = np.random.randint(0, batch_size)
|
||||
req_indices_to_remove: set[int] = set()
|
||||
for _ in range(num_reqs_to_remove):
|
||||
req_index_to_remove = np.random.randint(0, batch_size)
|
||||
req_indices_to_remove.add(req_index_to_remove)
|
||||
|
||||
req_ids_to_remove: set[str] = set()
|
||||
for index in req_indices_to_remove:
|
||||
input_batch.remove_request(reqs[index].req_id)
|
||||
req_ids_to_remove.add(reqs[index].req_id)
|
||||
return req_ids_to_remove
|
||||
|
||||
|
||||
def _construct_expected_sampling_metadata(
    reqs: list[CachedRequestState],
    req_ids_retained: set[str],
    req_id_index_in_input_batch: dict[str, int],
    device: torch.device,
) -> SamplingMetadata:
    """Build the SamplingMetadata expected for the retained requests.

    Args:
        reqs: Every request that was originally added to the batch.
        req_ids_retained: Ids of the requests still present in the batch.
        req_id_index_in_input_batch: Maps a request id to its row index.
        device: Device on which the expected tensors are created.

    Returns:
        A SamplingMetadata whose per-request fields are laid out by the
        row index of each retained request.
    """
    num_reqs = len(req_ids_retained)
    output_token_ids: list[list[int]] = [[] for _ in range(num_reqs)]
    prompt_token_ids: list[list[int]] = [[] for _ in range(num_reqs)]
    presence_penalties = [0.0] * num_reqs
    frequency_penalties = [0.0] * num_reqs
    repetition_penalties = [1.0] * num_reqs
    top_k = [0] * num_reqs
    top_p = [0.0] * num_reqs
    temperature = [0.0] * num_reqs
    min_tokens = {}
    logit_bias = [None] * num_reqs
    allowed_token_ids_mask = torch.zeros(num_reqs,
                                         VOCAB_SIZE,
                                         dtype=torch.bool,
                                         device=device)
    bad_words_token_ids = {}
    for req in reqs:
        if req.req_id not in req_ids_retained:
            continue
        index_in_input_batch = req_id_index_in_input_batch[req.req_id]
        params = req.sampling_params
        output_token_ids[index_in_input_batch] = req.output_token_ids
        prompt_token_ids[index_in_input_batch] = req.prompt_token_ids
        presence_penalties[index_in_input_batch] = params.presence_penalty
        frequency_penalties[index_in_input_batch] = params.frequency_penalty
        repetition_penalties[index_in_input_batch] = (
            params.repetition_penalty)
        top_k[index_in_input_batch] = params.top_k
        top_p[index_in_input_batch] = params.top_p
        temperature[index_in_input_batch] = params.temperature
        min_tokens[index_in_input_batch] = (params.min_tokens,
                                            params.all_stop_token_ids)
        logit_bias[index_in_input_batch] = params.logit_bias
        if params.allowed_token_ids:
            allowed_token_ids_mask[index_in_input_batch][
                params.allowed_token_ids] = True
        if params.bad_words_token_ids:
            bad_words_token_ids[
                index_in_input_batch] = params.bad_words_token_ids

    return SamplingMetadata(
        temperature=torch.tensor(temperature, dtype=torch.float,
                                 device=device),
        all_greedy=False,
        all_random=True,
        # top_p / top_k collapse to None when every request uses the value
        # tested for here (1.0 and 0 respectively).
        top_p=None if all(x == 1.0 for x in top_p) else torch.tensor(
            top_p, dtype=torch.float, device=device),
        top_k=None if all(x == 0 for x in top_k) else torch.tensor(
            top_k, dtype=torch.int, device=device),
        generators={},
        max_num_logprobs=0,
        prompt_token_ids=make_tensor_with_pad(
            prompt_token_ids,
            pad=VOCAB_SIZE,
            device=device,
            dtype=torch.int64,
        ),
        frequency_penalties=torch.tensor(frequency_penalties,
                                         dtype=torch.float,
                                         device=device),
        presence_penalties=torch.tensor(presence_penalties,
                                        dtype=torch.float,
                                        device=device),
        repetition_penalties=torch.tensor(repetition_penalties,
                                          dtype=torch.float,
                                          device=device),
        output_token_ids=output_token_ids,
        no_penalties=(all(x == 0 for x in presence_penalties)
                      and all(x == 0 for x in frequency_penalties)
                      and all(x == 1 for x in repetition_penalties)),
        allowed_token_ids_mask=allowed_token_ids_mask,
        bad_words_token_ids=bad_words_token_ids,
        logitsprocs=LogitsProcessors(),
    )
|
||||
|
||||
|
||||
def _create_sampling_params():
    """Return a SamplingParams whose fields are drawn at random with NumPy."""
    # The draws happen in the same order as the keyword arguments below.
    top_k = np.random.randint(1, 10)
    top_p = np.random.uniform(0.0, 1.0)
    presence_penalty = np.random.uniform(-2.0, 2.0)
    repetition_penalty = np.random.uniform(0.0, 2.0)
    frequency_penalty = np.random.uniform(-2.0, 2.0)
    min_tokens = np.random.randint(1, 10)
    num_stop_tokens = np.random.randint(10)
    stop_token_ids = [
        np.random.randint(0, VOCAB_SIZE) for _ in range(num_stop_tokens)
    ]
    logit_bias = {0: np.random.uniform(-3.0, 3.0)}

    return SamplingParams(
        top_k=top_k,
        top_p=top_p,
        presence_penalty=presence_penalty,
        repetition_penalty=repetition_penalty,
        frequency_penalty=frequency_penalty,
        min_tokens=min_tokens,
        stop_token_ids=stop_token_ids,
        logit_bias=logit_bias,
    )
|
||||
|
||||
|
||||
def _construct_cached_request_state(req_id_suffix: int):
    """Create a CachedRequestState with random prompt and output tokens.

    The request id is ``req_id_<req_id_suffix>``.
    """
    prompt_len = np.random.randint(0, MAX_PROMPT_SIZE)
    prompt_tokens = [
        np.random.randint(0, VOCAB_SIZE) for _ in range(prompt_len)
    ]
    output_len = np.random.randint(0, NUM_OUTPUT_TOKENS)
    output_tokens = [
        np.random.randint(0, VOCAB_SIZE) for _ in range(output_len)
    ]

    return CachedRequestState(
        req_id=f"req_id_{req_id_suffix}",
        prompt_token_ids=prompt_tokens,
        sampling_params=_create_sampling_params(),
        pooling_params=None,
        mm_kwargs=[],
        mm_positions=[],
        block_ids=([], ),
        generator=None,
        num_computed_tokens=len(output_tokens),
        output_token_ids=output_tokens,
        mm_hashes=None,
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [1, 2, 32, 64])
def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
    """
    Tests the logic for managing sampling metadata in the InputBatch.

    This test involves adding a set of requests to the InputBatch,
    followed by removing a subset of them. Afterward, the batch is compacted,
    and the `_make_sampling_metadata` method is invoked on the batch. The
    output of `_make_sampling_metadata` is then compared against the expected
    results to ensure correctness.

    Note: Ignore logits processor logic, which is tested separately
    """
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=False,
        vocab_size=1024,
        block_sizes=[1],
    )
    reqs: list[CachedRequestState] = []
    req_id_reqs = {}

    # Add requests
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert req_index == assigned_req_index
        reqs.append(req)
        req_id_reqs[req.req_id] = req

    # Remove some requests
    req_ids_to_remove = _remove_requests(input_batch, batch_size, reqs)
    req_ids_retained = set(req_id_reqs.keys()) - req_ids_to_remove

    # Compact the input batch
    input_batch.condense()

    # Generate the sampling metadata
    sampling_metadata = input_batch._make_sampling_metadata()

    # Create expected output.
    expected_sampling_metadata = _construct_expected_sampling_metadata(
        reqs,
        req_ids_retained,
        input_batch.req_id_to_index,
        device=torch.device(device))

    def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
        """Return True when both are None or both are element-wise close."""
        return (t1 is None
                and t2 is None) or (t1 is not None and t2 is not None
                                    and torch.allclose(t1, t2))

    # Assert the actual and expected output.
    assert torch.allclose(expected_sampling_metadata.temperature,
                          sampling_metadata.temperature)
    assert same(expected_sampling_metadata.top_p, sampling_metadata.top_p)
    assert same(expected_sampling_metadata.top_k, sampling_metadata.top_k)
    assert torch.allclose(
        expected_sampling_metadata.frequency_penalties,
        sampling_metadata.frequency_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.presence_penalties,
        sampling_metadata.presence_penalties,
    )
    assert torch.allclose(
        expected_sampling_metadata.repetition_penalties,
        sampling_metadata.repetition_penalties,
    )
    assert torch.allclose(expected_sampling_metadata.prompt_token_ids,
                          sampling_metadata.prompt_token_ids)
    assert (expected_sampling_metadata.output_token_ids ==
            sampling_metadata.output_token_ids)
    assert expected_sampling_metadata.no_penalties == \
        sampling_metadata.no_penalties
    # Compare against None explicitly: the truth value of a tensor holding
    # more than one element is ambiguous and raises.
    if sampling_metadata.allowed_token_ids_mask is not None:
        assert torch.allclose(
            expected_sampling_metadata.allowed_token_ids_mask,
            sampling_metadata.allowed_token_ids_mask)
    assert expected_sampling_metadata.bad_words_token_ids == \
        sampling_metadata.bad_words_token_ids
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu"])
@pytest.mark.parametrize("batch_size", [32])
@pytest.mark.parametrize("swap_list", [((0, 1), )])
def test_swap_states_in_input_batch(device: str, batch_size: int,
                                    swap_list: list):
    """
    Tests that `InputBatch.swap_states` matches adding requests pre-swapped.

    `batch_size` requests are added to one InputBatch and every index pair
    in `swap_list` is then swapped in place via `swap_states`. A second,
    reference InputBatch is filled with the same requests, added directly in
    the post-swap order. After `refresh_metadata` is called on both, the two
    batches are handed to `_compare_objs`, which performs the comparison.

    Args:
        device: Device string passed to `torch.device` for both batches.
        batch_size: Number of requests to add; also used as `max_num_reqs`.
        swap_list: Iterable of `(i, j)` index pairs to swap, applied in order.
    """
    # Both batches must be configured identically so that any difference
    # found by the comparison comes from the swap itself.
    batch_kwargs = dict(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=False,
        vocab_size=1024,
        block_sizes=[1],
    )
    input_batch: InputBatch = InputBatch(**batch_kwargs)
    ref_input_batch: InputBatch = InputBatch(**batch_kwargs)

    # Add requests to the batch under test; each one is expected to be
    # assigned the next free index.
    reqs: list[CachedRequestState] = []
    for req_index in range(batch_size):
        req: CachedRequestState = _construct_cached_request_state(req_index)
        assigned_req_index = input_batch.add_request(req)
        assert assigned_req_index == req_index
        reqs.append(req)

    # Apply each swap both to the batch under test and to a plain list that
    # records the expected final request order.
    reordered_reqs = reqs.copy()
    for first, second in swap_list:
        reordered_reqs[first], reordered_reqs[second] = \
            reordered_reqs[second], reordered_reqs[first]
        input_batch.swap_states(first, second)

    # Build the reference batch directly in the expected order.
    for req_index, req in enumerate(reordered_reqs):
        assigned_req_index = ref_input_batch.add_request(req)
        assert assigned_req_index == req_index

    input_batch.refresh_metadata()
    ref_input_batch.refresh_metadata()

    _compare_objs(input_batch, ref_input_batch)
|
||||
1143
tests/ut/worker/test_worker_v1.py
Normal file
1143
tests/ut/worker/test_worker_v1.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user