[main] adapt usage of npu_moe_gating_top_k_softmax and remove envs.SELECT_GATING_TOPK_SOTFMAX_EXPERTS (#2112)

backport of v0.9.1-dev:
https://github.com/vllm-project/vllm-ascend/pull/1902

original npu_moe_gating_top_k_softmax PR on main:
https://github.com/vllm-project/vllm-ascend/pull/1355

- vLLM version: v0.10.0
- vLLM main:
055bd3978e

Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
huangxialu
2025-07-31 21:05:56 +08:00
committed by GitHub
parent e8660d7978
commit 9c9a7cd90b
5 changed files with 146 additions and 89 deletions

View File

@@ -23,11 +23,13 @@ Run `pytest tests/ops/test_fused_moe.py`.
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa
from unittest.mock import MagicMock, patch
import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul
from vllm_ascend.ops.fused_moe import fused_experts
from vllm_ascend.ops.fused_moe import fused_experts, select_experts
NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
@@ -98,3 +100,97 @@ def test_fused_experts(
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
torch.npu.empty_cache()
@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
    m: int,
    n: int,
    e: int,
    topk: int,
    scoring_func: str,
    use_grouped_topk: bool,
    renormalize: bool,
    with_e_correction: bool,
    custom_routing: bool,
    dtype: torch.dtype,
    device: str,
):
    """Check the output contract of ``select_experts`` over routing options.

    ``native_grouped_topk`` is patched out; the test asserts that it is
    called exactly once when ``use_grouped_topk`` is set and never otherwise,
    and that the returned weights and ids both have shape ``(m, topk)`` with
    int32 ids.
    """
    # Group parameters are only supplied on the grouped top-k path.
    if use_grouped_topk:
        topk_group = 4
        num_expert_group = e // 4
    else:
        topk_group = None
        num_expert_group = None

    hidden_states = torch.randn(m, n, device=device, dtype=dtype)
    router_logits = torch.randn(m, e, device=device, dtype=dtype)

    e_score_correction_bias = None
    if with_e_correction:
        e_score_correction_bias = torch.randn(e, device=device, dtype=dtype)

    custom_routing_function = None
    if custom_routing:
        # Stand-in routing function that returns a fixed (weights, ids) pair.
        fake_weights = torch.randn(m, topk, device=device, dtype=dtype)
        fake_ids = torch.randint(0,
                                 e, (m, topk),
                                 device=device,
                                 dtype=torch.int32)
        custom_routing_function = MagicMock()
        custom_routing_function.return_value = (fake_weights, fake_ids)

    def _fake_grouped_topk(x, num_groups, k):
        # Replacement for the patched helper: random tensor shaped like x.
        return torch.randn_like(x)

    with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
               ) as grouped_topk_mock:
        grouped_topk_mock.side_effect = _fake_grouped_topk

        topk_weights, topk_ids = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=topk,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        if use_grouped_topk:
            grouped_topk_mock.assert_called_once()
        else:
            grouped_topk_mock.assert_not_called()

    expected_shape = (m, topk)
    assert topk_weights.shape == expected_shape
    assert topk_ids.shape == expected_shape
    assert topk_ids.dtype == torch.int32
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
    """An unknown ``scoring_func`` must make ``select_experts`` raise ValueError."""
    expected_message = "Unsupported scoring function: invalid"
    with pytest.raises(ValueError, match=expected_message):
        # Inputs are built inside the context, as in the original call.
        hidden_states = torch.randn(1, 128, device=device)
        router_logits = torch.randn(1, 8, device=device)
        select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=False,
            renormalize=False,
            scoring_func="invalid",
        )
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
    """Grouped top-k without group arguments must raise AssertionError.

    The call sets ``use_grouped_topk=True`` but passes neither
    ``topk_group`` nor ``num_expert_group``.
    """
    with pytest.raises(AssertionError):
        # Inputs are built inside the context, as in the original call.
        hidden_states = torch.randn(1, 128, device=device)
        router_logits = torch.randn(1, 64, device=device)
        select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=True,
            renormalize=False,
            scoring_func="softmax",
        )

View File

@@ -297,9 +297,8 @@ class TestAscendUnquantizedFusedMoEMethod:
assert not layer.w13_weight.requires_grad
assert not layer.w2_weight.requires_grad
@pytest.mark.parametrize(
"others_param",
[[256, 4, False], [128, 1, False], [128, 1, True], [128, 4, False]])
@pytest.mark.parametrize("others_param",
[[256, 4], [128, 1], [128, 1], [128, 4]])
def test_apply_without_expert_map(self, moe_method, mock_dist_env,
mock_moe_env, others_param):
"""
@@ -308,15 +307,13 @@ class TestAscendUnquantizedFusedMoEMethod:
3 test use select_gating_topk_softmax_experts and fused_experts
4 test use select_experts and fused_experts_with_all2all_buffer
"""
global_num_experts, ep_size, select_softmax = others_param
global_num_experts, ep_size = others_param
is_prefill = False
is_deepseek_v3_r1 = global_num_experts == 256
forward_context = MagicMock(fused_moe_state=get_fused_moe_state(
ep_size, is_prefill, is_deepseek_v3_r1))
with patch(
"vllm_ascend.ops.fused_moe.SELECT_GATING_TOPK_SOTFMAX_EXPERTS",
select_softmax), \
patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context):
with patch("vllm_ascend.ops.fused_moe.get_forward_context",
return_value=forward_context):
moe_method.ep_size = ep_size
x = torch.randn(8, 2, 2)
router_logits = torch.randn(8, 8)