[2/N][Pangu][MoE] Remove Pangu Related Code (#5130)

### What this PR does / why we need it?
Remove Pangu Related Code

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
e2e & ut

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
This commit is contained in:
weichen
2025-12-19 09:00:07 +08:00
committed by GitHub
parent 1b47fca0e8
commit ca6f631cba
11 changed files with 8 additions and 1444 deletions

View File

@@ -1,6 +1,5 @@
from unittest.mock import MagicMock, patch
import torch
from vllm.attention.layer import Attention
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
@@ -8,8 +7,7 @@ from vllm.model_executor.layers.linear import LinearBase
from tests.ut.base import TestBase
from vllm_ascend.ops.linear import AscendUnquantizedLinearMethod
from vllm_ascend.quantization.quant_config import (AscendKVCacheMethod,
AscendQuantConfig)
from vllm_ascend.quantization.quant_config import AscendQuantConfig
from vllm_ascend.utils import ASCEND_QUANTIZATION_METHOD
@@ -19,7 +17,6 @@ class TestAscendQuantConfig(TestBase):
self.sample_config = {
"weight": "INT8",
"fa_quant_type": "C8",
"kv_quant_type": "C8",
"layer1.weight": "INT8",
"layer2.weight": "FLOAT",
"fused_layer.weight": "FLOAT",
@@ -115,16 +112,6 @@ class TestAscendQuantConfig(TestBase):
attention_layer, ".attn")
self.assertIs(method, mock_ascend_kvcache.return_value)
with patch("vllm_ascend.quantization.quant_config.get_current_vllm_config", return_value=mock_config), \
patch('vllm_ascend.quantization.quant_config.AscendKVCacheMethod', \
return_value=MagicMock()) as mock_ascend_kvcache:
# Test with kv_quant_type
modified_config = {"kv_quant_type": "C8"}
config = AscendQuantConfig(modified_config)
config.packed_modules_mapping = None
method = config.get_quant_method(attention_layer, "attn")
self.assertIs(method, mock_ascend_kvcache.return_value)
def test_get_quant_method_for_fused_moe(self):
fused_moe_layer = MagicMock(spec=FusedMoE)
fused_moe_layer.moe = MagicMock(spec=FusedMoEConfig)
@@ -169,73 +156,3 @@ class TestAscendQuantConfig(TestBase):
def test_get_scaled_act_names(self):
self.assertEqual(self.ascend_config.get_scaled_act_names(), [])
class TestAscendKVCacheMethod(TestBase):
def setUp(self):
# Setup common test fixtures
self.mock_quant_config = MagicMock(spec=AscendQuantConfig)
self.mock_quant_config.quant_description = {"kv_quant_type": "C8"}
self.prefix = "layer.attn"
# Mock quant_method
self.mock_quant_method = MagicMock()
self.patcher = patch(
'vllm_ascend.quantization.quant_config.get_quant_method')
self.mock_get_quant_method = self.patcher.start()
self.mock_get_quant_method.return_value = self.mock_quant_method
# Create instance
self.kv_cache_method = AscendKVCacheMethod(self.mock_quant_config,
self.prefix)
def tearDown(self):
self.patcher.stop()
def test_create_weights(self):
"""Test create_weights delegates to quant_method."""
mock_layer = MagicMock()
self.kv_cache_method.create_weights(mock_layer)
self.mock_quant_method.create_weights.assert_called_once_with(
mock_layer)
def test_process_weights_after_loading_with_method(self):
"""Test process_weights when quant_method has the method."""
mock_layer = MagicMock()
self.kv_cache_method.process_weights_after_loading(mock_layer)
self.mock_quant_method.process_weights_after_loading.assert_called_once_with(
mock_layer)
def test_process_weights_after_loading_without_method(self):
"""Test process_weights when quant_method lacks the method."""
# Reset mock to remove the method
del self.mock_quant_method.process_weights_after_loading
mock_layer = MagicMock()
# Should not raise exception
self.kv_cache_method.process_weights_after_loading(mock_layer)
def test_apply_delegation(self):
"""Test apply properly delegates to quant_method."""
mock_layer = MagicMock()
mock_query = torch.randn(1, 32, 128)
mock_key = torch.randn(1, 32, 128)
mock_value = torch.randn(1, 32, 128)
mock_kv_cache = MagicMock()
mock_attn_metadata = MagicMock()
mock_scale = 1.0
mock_output = torch.zeros(1, 32, 128)
mock_attn_type = MagicMock()
expected_result = torch.randn(1, 32, 128)
self.mock_quant_method.apply.return_value = expected_result
result = self.kv_cache_method.apply(mock_layer, mock_query, mock_key,
mock_value, mock_kv_cache,
mock_attn_metadata, mock_attn_type,
mock_scale, mock_output)
self.mock_quant_method.apply.assert_called_once_with(
mock_layer, mock_query, mock_key, mock_value, mock_kv_cache,
mock_attn_metadata, mock_attn_type, mock_scale, mock_output)
self.assertTrue(torch.equal(result, expected_result))

View File

@@ -39,18 +39,6 @@ class TestGetQuantMethod(TestBase):
"moe")
self.assertIsInstance(method, cls)
def test_with_fa_quant_type(self):
quant_description = {"fa_quant_type": "C8"}
method = get_quant_method(quant_description, ".attn", "attention")
self.assertIsInstance(
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
def test_with_kv_quant_type(self):
quant_description = {"kv_quant_type": "C8"}
method = get_quant_method(quant_description, ".attn", "attention")
self.assertIsInstance(
method, ASCEND_QUANTIZATION_METHOD_MAP["C8"]["attention"])
def test_invalid_layer_type(self):
quant_description = {"linear_layer.weight": "W8A8"}
with self.assertRaises(NotImplementedError):

View File

@@ -1,16 +1,9 @@
import unittest
from unittest.mock import MagicMock, patch
import torch
from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.fused_moe.experts_selector import (_native_grouped_topk,
select_experts)
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod,
fused_experts, fused_experts_310p,
from vllm_ascend.quantization.w8a8 import (AscendW8A8LinearMethod,
quant_per_tensor)
from vllm_ascend.utils import AscendDeviceType
@@ -194,791 +187,3 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(layer.weight_scale.data.shape, (128, ))
self.assertEqual(layer.weight_offset.data.shape, (128, ))
mock_npu_format_cast.assert_called_once()
class TestAscendW8A8FusedMoEMethod(TestBase):
def setUp(self):
self.moe_method = AscendW8A8FusedMoEMethod()
self.num_experts = 4
self.intermediate_size = 64
self.hidden_size = 128
self.dtype = torch.float32
def test_init(self):
self.assertTrue(self.moe_method.transpose_weight)
def test_get_weight(self):
weights = self.moe_method.get_weight(
num_experts=self.num_experts,
intermediate_size_per_partition=self.intermediate_size,
hidden_sizes=self.hidden_size,
params_dtype=self.dtype)
assert "w13_weight" in weights, f"w13_weight not in {weights}"
assert "w2_weight" in weights, f"w2_weight not in {weights}"
self.assertEqual(
weights["w13_weight"].shape,
(self.num_experts, 2 * self.intermediate_size, self.hidden_size))
self.assertEqual(
weights["w2_weight"].shape,
(self.num_experts, self.hidden_size, self.intermediate_size))
self.assertEqual(weights["w13_weight"].dtype, torch.int8)
self.assertEqual(weights["w2_weight"].dtype, torch.int8)
self.assertFalse(weights["w13_weight"].requires_grad)
self.assertFalse(weights["w2_weight"].requires_grad)
def test_get_dynamic_quant_param(self):
quant_params = self.moe_method.get_dynamic_quant_param(
num_experts=self.num_experts,
intermediate_size_per_partition=self.intermediate_size,
hidden_sizes=self.hidden_size,
params_dtype=self.dtype)
expected_params = [
"w13_weight_scale", "w13_weight_offset", "w2_weight_scale",
"w2_weight_offset", "w2_deq_scale", "w13_deq_scale",
"w2_input_scale", "w13_input_scale", "w2_input_offset",
"w13_input_offset", "quant_bias"
]
for param in expected_params:
assert param in quant_params, f"{param} not in {quant_params}"
# Check some sample shapes
self.assertEqual(quant_params["w13_weight_scale"].shape,
(self.num_experts, 2 * self.intermediate_size, 1))
self.assertEqual(quant_params["w2_input_offset"].shape,
(self.num_experts, 1))
self.assertEqual(quant_params["quant_bias"].shape,
(self.num_experts, self.hidden_size))
@patch('vllm_ascend.quantization.w8a8.select_experts')
@patch('vllm_ascend.quantization.w8a8.fused_experts')
def test_apply_with_other_expert_count(self, mock_fused_experts,
mock_select_experts):
# Setup
mock_layer = MagicMock()
x = torch.randn(32, self.hidden_size)
router_logits = torch.randn(32, 128) # 128 experts
top_k = 2
# Mock return values
mock_select_experts.return_value = (torch.randn(32, top_k),
torch.randint(0, 128, (32, top_k)))
mock_fused_experts.return_value = torch.randn(32, self.hidden_size)
# Test
result = self.moe_method.apply(layer=mock_layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=True,
global_num_experts=128)
# Assertions
mock_select_experts.assert_called_once()
mock_fused_experts.assert_called_once()
self.assertEqual(result.shape, (32, self.hidden_size))
@patch('vllm_ascend.quantization.w8a8.get_ascend_device_type',
return_value=AscendDeviceType._310P)
@patch('vllm_ascend.quantization.w8a8.select_experts')
@patch('vllm_ascend.quantization.w8a8.fused_experts_310p')
def test_apply_is_310p(self, mock_fused_experts_310p, mock_select_experts,
mock_soc_version):
# Setup
mock_layer = MagicMock()
x = torch.randn(32, self.hidden_size)
router_logits = torch.randn(32, 128) # 128 experts
top_k = 2
# Mock return values
mock_select_experts.return_value = (torch.randn(32, top_k),
torch.randint(0, 128, (32, top_k)))
mock_fused_experts_310p.return_value = torch.randn(
32, self.hidden_size)
# Test
result = self.moe_method.apply(layer=mock_layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=True,
global_num_experts=128)
# Assertions
mock_select_experts.assert_called_once()
mock_fused_experts_310p.assert_called_once()
self.assertEqual(result.shape, (32, self.hidden_size))
class TestAscendC8KVCacheMethod(TestBase):
def setUp(self):
self.layer = MagicMock()
self.layer.num_kv_heads = 4
self.layer.head_size = 64
self.layer.num_heads = 8
self.layer._k_scale_float = 1.0
self.layer._v_scale_float = 1.0
self.method = AscendC8KVCacheMethod()
self.attention_type = MagicMock()
self.attention_type.DECODER = "decoder"
self.attention_type.ENCODER = "encoder"
def test_create_weights(self):
AscendC8KVCacheMethod.create_weights(self.layer)
self.layer.register_parameter.assert_any_call("key_antiquant_scale",
unittest.mock.ANY)
self.layer.register_parameter.assert_any_call("value_antiquant_scale",
unittest.mock.ANY)
calls = self.layer.register_parameter.call_args_list
for call in calls:
args, kwargs = call
param = kwargs.get('parameter', args[1] if len(args) > 1 else None)
expected_shape = (self.layer.num_kv_heads * self.layer.head_size, )
self.assertEqual(param.shape, expected_shape)
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType.A3)
def test_process_weights_after_loading_not_310p(self, mock_soc_version):
key_data = torch.ones(4 * 64)
value_data = torch.ones(4 * 64) * 2
self.layer.key_antiquant_scale.data = key_data
self.layer.value_antiquant_scale.data = value_data
self.method.process_weights_after_loading(self.layer)
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
@patch('vllm_ascend.utils.get_ascend_device_type',
return_value=AscendDeviceType._310P)
def test_process_weights_after_loading_is_310p(self, mock_soc_version):
key_data = torch.ones(4 * 64)
value_data = torch.ones(4 * 64) * 2
self.layer.key_antiquant_scale.data = key_data
self.layer.value_antiquant_scale.data = value_data
self.method.process_weights_after_loading(self.layer)
self.assertEqual(self.method.antiquant_scale_comb.shape, (2, 256))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[0] == 1))
self.assertTrue(torch.all(self.method.antiquant_scale_comb[1] == 2))
@patch('torch_npu.npu_scatter_nd_update_')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_decode_only(self, mock_quant, mock_scatter):
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
attn_metadata.seq_lens = [10, 10]
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
attn_metadata.slot_mapping = torch.tensor([0, 1])
attn_metadata.attn_mask = None
block_size = 16
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
kv_cache = (key_cache, value_cache)
mock_quant.side_effect = [key, value]
self.layer.key_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.layer.value_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.method.process_weights_after_loading(self.layer)
expected_output = torch.randn(
num_tokens, self.layer.num_heads * self.layer.head_size)
with patch('torch_npu.npu_incre_flash_attention',
return_value=expected_output):
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata,
self.attention_type.DECODER, 1.0,
output)
self.assertEqual(mock_quant.call_count, 2)
self.assertEqual(mock_scatter.call_count, 2)
self.assertTrue(torch.equal(result, expected_output))
@patch('torch_npu.npu_scatter_nd_update_')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_attn_metadata_without_decode(self, mock_quant,
mock_scatter):
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock(spec=[
'attn_state', 'seq_lens', 'block_tables', 'slot_mapping',
'attn_mask'
])
attn_metadata.attn_state = AscendAttentionState.DecodeOnly
attn_metadata.seq_lens = [10, 10]
attn_metadata.block_tables = torch.tensor([[0, 1], [1, 2]])
attn_metadata.slot_mapping = torch.tensor([0, 1])
attn_metadata.attn_mask = None
block_size = 16
key_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
value_cache = torch.empty(2, block_size, self.layer.num_kv_heads,
self.layer.head_size)
kv_cache = (key_cache, value_cache)
mock_quant.side_effect = [key, value]
self.layer.key_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.layer.value_antiquant_scale.data = torch.ones(
self.layer.num_kv_heads * self.layer.head_size)
self.method.process_weights_after_loading(self.layer)
expected_output = torch.randn(
num_tokens, self.layer.num_heads * self.layer.head_size)
with patch('torch_npu.npu_incre_flash_attention',
return_value=expected_output):
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata,
self.attention_type.DECODER, 1.0,
output)
self.assertEqual(mock_quant.call_count, 2)
self.assertEqual(mock_scatter.call_count, 2)
self.assertTrue(torch.equal(result, expected_output))
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('torch_npu._npu_flash_attention')
def test_apply_prefill_no_cache(self, mock_flash, mock_quant):
"""Test apply method in prefill no-cache mode"""
num_tokens = 2
query = torch.randn(num_tokens,
self.layer.num_heads * self.layer.head_size)
key = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(num_tokens,
self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
attn_metadata.seq_lens = [10, 10]
attn_metadata.attn_mask = torch.ones(2, 2)
kv_cache = (torch.tensor([]), torch.tensor([]))
mock_quant.return_value = key
result = self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata, self.attention_type.DECODER,
1.0, output)
# Check that flash attention was called
mock_flash.assert_called_once()
# Check output shape
self.assertEqual(
result.shape,
(num_tokens, self.layer.num_heads * self.layer.head_size))
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_unsupported_attention_type(self, mock_quant):
query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
mock_quant.return_value = key
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.PrefillNoCache
with self.assertRaises(NotImplementedError) as cm:
self.method.apply(self.layer, query, key, value, (None, None),
attn_metadata, self.attention_type.ENCODER, 1.0,
output)
assert "Encoder self-attention" in str(
cm.exception), f"Encoder self-attention not in {str(cm.exception)}"
assert "not implemented" in str(
cm.exception), f"not implemented not in{str(cm.exception)}"
mock_quant.assert_not_called()
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
def test_apply_unsupported_attention_state(self, mock_quant):
"""Test apply with unsupported attention state"""
query = torch.randn(1, self.layer.num_heads * self.layer.head_size)
key = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
value = torch.randn(1, self.layer.num_kv_heads * self.layer.head_size)
output = torch.empty_like(query)
attn_metadata = MagicMock()
attn_metadata.attn_state = AscendAttentionState.PrefillCacheHit
mock_quant.return_value = key
kv_cache = (torch.tensor([]), torch.tensor([]))
with self.assertRaises(NotImplementedError):
self.method.apply(self.layer, query, key, value, kv_cache,
attn_metadata, self.attention_type.DECODER, 1.0,
output)
class TestFusedExperts(TestBase):
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_moe_init_routing_v2')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
@patch('torch_npu.npu_moe_finalize_routing')
def test_fused_experts_with_expert_map(self, mock_finalize, mock_swiglu,
mock_group_matmul,
mock_init_routing,
mock_get_ep_group,
mock_quant_per_tensor):
num_tokens = 32
hidden_size = 128
intermediate_size = 256
num_experts = 4
top_k = 2
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w1_scale = torch.tensor([0.1])
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
w1_input_offset = torch.tensor([0])
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
w2_scale = torch.tensor([0.1])
w2_input_scale = torch.tensor([0.2])
w2_input_offset = torch.tensor([0])
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
expert_map = torch.arange(num_experts)
mock_get_ep_group.return_value.world_size = 8
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_init_routing.return_value = (torch.randn(num_tokens * top_k,
hidden_size),
torch.arange(num_tokens * top_k),
torch.tensor([num_tokens // 2] * 2),
torch.tensor(1.0))
mock_group_matmul.side_effect = [[
torch.randn(num_tokens * top_k, intermediate_size * 2)
], [torch.randn(num_tokens * top_k, hidden_size)]]
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
expected_output = torch.randn(num_tokens, hidden_size)
mock_finalize.return_value = expected_output
output = fused_experts(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w1_input_scale=w1_input_scale,
w1_input_offset=w1_input_offset,
w2=w2,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
w2_input_offset=w2_input_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=expert_map,
)
mock_init_routing.assert_called_once()
self.assertEqual(mock_group_matmul.call_count, 2)
self.assertEqual(output.shape, (num_tokens, hidden_size))
mock_finalize.assert_called_once()
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_grouped_matmul')
@patch('torch_npu.npu_swiglu')
def test_fused_experts_without_expert_map(self, mock_swiglu,
mock_group_matmul,
mock_get_ep_group,
mock_quant_per_tensor):
num_tokens = 16
hidden_size = 64
intermediate_size = 128
num_experts = 8
top_k = 1
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
mock_get_ep_group.return_value.world_size = 8
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_group_matmul.side_effect = [[
torch.randn(num_tokens * top_k, intermediate_size * 2)
], [torch.randn(num_tokens * top_k, hidden_size)]]
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
with self.assertRaises(NotImplementedError):
fused_experts(
hidden_states=hidden_states,
w1=w1,
w1_scale=torch.tensor([0.1]),
w1_input_scale=torch.tensor([[0.2, 0.2], [0.2, 0.2]]),
w1_input_offset=torch.tensor([0]),
w2=w2,
w2_scale=torch.tensor([0.1]),
w2_input_scale=torch.tensor([0.1]),
w2_input_offset=torch.tensor([0]),
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=None,
)
class TestFusedExperts310(TestBase):
@patch('torch_npu.npu_quant_grouped_matmul_dequant')
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch('vllm_ascend.quantization.w8a8.get_ep_group')
@patch('torch_npu.npu_swiglu')
def test_fused_experts_310p_with_expert_map(self, mock_swiglu,
mock_get_ep_group,
mock_quant_per_tensor,
mock_matmul_dequant):
num_tokens = 32
hidden_size = 128
intermediate_size = 256
num_experts = 4
top_k = 1
hidden_states = torch.randn(num_tokens, hidden_size)
w1 = torch.randn(num_experts, intermediate_size * 2, hidden_size)
w1_scale = torch.tensor([0.1])
w1_input_scale = torch.tensor([[0.2, 0.2], [0.2, 0.2]])
w2 = torch.randn(num_experts, hidden_size, intermediate_size)
w2_scale = torch.tensor([0.1])
w2_input_scale = torch.tensor([0.2])
topk_weights = torch.rand(num_tokens, top_k)
topk_ids = torch.randint(0, num_experts, (num_tokens, top_k))
expert_map = torch.arange(num_experts)
mock_get_ep_group.return_value.world_size = 1
mock_quant_per_tensor.return_value = torch.randint(-128,
127,
hidden_states.shape,
dtype=torch.int8)
mock_swiglu.return_value = torch.randn(num_tokens * top_k,
intermediate_size)
mock_matmul_dequant.return_value = hidden_states
output = fused_experts_310p(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w1_input_scale=w1_input_scale,
w2=w2,
w2_scale=w2_scale,
w2_input_scale=w2_input_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=num_experts,
expert_map=expert_map,
)
self.assertEqual(output.shape, (num_tokens, hidden_size))
self.assertEqual(mock_matmul_dequant.call_count, 2)
class TestSelectExperts(TestBase):
def setUp(self):
# Common test data
self.num_tokens = 10
self.hidden_size = 32
self.num_experts = 8
self.top_k = 2
self.hidden_states = torch.randn(self.num_tokens, self.hidden_size)
self.router_logits = torch.randn(self.num_tokens, self.num_experts)
"""Mock custom routing"""
self.mock_custom_routing = MagicMock()
self.mock_custom_routing.return_value = (torch.ones(
self.num_tokens, self.top_k),
torch.zeros(
self.num_tokens,
self.top_k,
dtype=torch.int32))
self.mock_ctx = MagicMock()
self.mock_ctx.weight_prefetch_method = MagicMock()
patcher = patch(
'vllm_ascend.ops.fused_moe.experts_selector.get_forward_context',
return_value=self.mock_ctx)
self.addCleanup(patcher.stop)
patcher.start()
@patch('torch_npu.npu_moe_gating_top_k')
def test_softmax_scoring(self, mock_topk):
"""Test softmax scoring function"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="softmax")
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
def test_sigmoid_scoring(self):
"""Test sigmoid scoring function"""
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="sigmoid",
custom_routing_function=self.mock_custom_routing)
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
def test_invalid_scoring_func(self):
"""Test invalid scoring function raises ValueError"""
with self.assertRaises(ValueError):
select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid_func")
@patch('torch.topk')
def test_grouped_topk(self, mock_topk):
"""Test grouped topk functionality"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long))
weights, ids = select_experts(hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2)
mock_topk.assert_called()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk')
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
"""Test grouped topk with expert score correction bias"""
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
self.num_experts)
e_score_correction_bias = torch.randn(self.num_experts)
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=False,
topk_group=4,
num_expert_group=2,
e_score_correction_bias=e_score_correction_bias)
mock_grouped_topk.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
def test_custom_routing_function(self):
"""Test custom routing function"""
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
custom_routing_function=self.mock_custom_routing)
self.mock_custom_routing.assert_called_once()
self.assertEqual(weights.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)
@patch('torch_npu.npu_moe_gating_top_k')
def test_renormalize(self, mock_topk):
"""Test renormalization"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.long),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=True,
)
# Check if weights are normalized (sum to 1 for each token)
sums = weights.sum(dim=-1)
self.assertTrue(torch.allclose(sums, torch.ones_like(sums)))
@patch('torch_npu.npu_moe_gating_top_k')
def test_output_dtypes(self, mock_topk):
"""Test output dtypes"""
mock_topk.return_value = (torch.ones(self.num_tokens, self.top_k),
torch.zeros(self.num_tokens,
self.top_k,
dtype=torch.int32),
torch.arange(0,
self.num_tokens * self.top_k,
dtype=torch.int32).view(
self.top_k,
-1).permute(1,
0).contiguous())
weights, ids = select_experts(
hidden_states=self.hidden_states,
router_logits=self.router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=False,
)
self.assertEqual(weights.dtype, self.hidden_states.dtype)
self.assertEqual(ids.dtype, torch.int32)
class TestNativeGroupedTopkPartialMock(TestBase):
def test_basic_group_selection(self):
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1],
[0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3],
[0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4]],
dtype=torch.float32)
expected_topk_indices = torch.tensor([[0, 1], [1, 0], [0, 1], [0, 1]])
with patch('torch.topk',
return_value=(None, expected_topk_indices)) as mock_topk:
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
mock_topk.assert_called_once()
expected_result = topk_weights
self.assertTrue(torch.allclose(result, expected_result))
def test_partial_group_selection(self):
topk_weights = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6],
[0.6, 0.4, 0.7, 0.3, 0.8, 0.2, 0.9, 0.1]])
expected_topk_indices = torch.tensor([[0], [1]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
expected_result = torch.tensor(
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.8, 0.2, 0.9, 0.1]])
self.assertTrue(torch.allclose(result, expected_result))
def test_single_group(self):
topk_weights = torch.tensor([[0.1, 0.9, 0.2], [0.8, 0.3, 0.7]])
expected_topk_indices = torch.tensor([[0], [0]])
with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
self.assertTrue(result.numel() > 0)