[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly (see the sketch after this list).
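
For reference, a minimal sketch of the request-object shape, reconstructed from the field names used in the updated unit tests below; the actual definitions live in `vllm_ascend/ops/fused_moe/moe_runtime_args.py` and may differ in detail:

```python
# Hypothetical sketch: field names come from the updated UTs; the real
# definitions in vllm_ascend/ops/fused_moe/moe_runtime_args.py may differ.
from dataclasses import dataclass
from typing import Optional

import torch

from vllm_ascend.quantization.quant_type import QuantType


@dataclass
class MoERoutingParams:
    expert_map: Optional[torch.Tensor]
    global_redundant_expert_num: int
    mc2_mask: Optional[torch.Tensor]
    apply_router_weight_on_input: bool


@dataclass
class MoEQuantParams:
    quant_type: QuantType  # e.g. QuantType.NONE or QuantType.W8A8


@dataclass
class MoETokenDispatchInput:
    hidden_states: torch.Tensor
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    routing: MoERoutingParams
    quant: MoEQuantParams


# Dispatchers then take a single typed request instead of **kwargs:
#     dispatcher.token_dispatch(token_dispatch_input=MoETokenDispatchInput(...))
```

Each stage consumes exactly one such object, so its inputs are explicit at the stage boundary instead of being threaded through `kwargs.get(...)`.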

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
linfeng-yuan authored on 2026-03-20 23:23:57 +08:00, committed by GitHub
parent c860535246 · commit 88d03a783f
33 changed files with 2146 additions and 947 deletions


```diff
@@ -245,6 +245,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 @patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
 @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
+@wait_until_npu_memory_free()
 def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
     short_example_prompts = [
         "Hello ",
@@ -272,6 +273,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 @patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
 @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
+@wait_until_npu_memory_free()
 def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
     short_example_prompts = [
         "Hello ",
```


```diff
@@ -28,11 +28,17 @@ import torch
 import torch_npu
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm_ascend.ops.fused_moe.experts_selector import (
-    check_npu_moe_gating_top_k, select_experts)
+from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k, select_experts
 from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
-from vllm_ascend.ops.fused_moe.token_dispatcher import \
-    TokenDispatcherWithAllGather
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    build_fused_experts_input,
+    build_mlp_compute_input,
+    MoEQuantParams,
+    MoERoutingParams,
+    MoETokenDispatchInput,
+)
+from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather
+from vllm_ascend.quantization.quant_type import QuantType
 NUM_EXPERTS = [8, 64]
 EP_SIZE = [1]
@@ -83,10 +89,8 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            out[mask] = SiluAndMul()(
-                a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
-    return (out.view(B, -1, w2.shape[1]) *
-            topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
+            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
+    return (out.view(B, -1, w2.shape[1]) * topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
 @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@@ -129,36 +133,41 @@ def test_token_dispatcher_with_all_gather(
     dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
     apply_router_weight_on_input = False
-    dispatch_output = dispatcher.token_dispatch(
-        hidden_states=a,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        expert_map=expert_map,
-        apply_router_weight_on_input=apply_router_weight_on_input)
+    token_dispatch_output = dispatcher.token_dispatch(
+        token_dispatch_input=MoETokenDispatchInput(
+            hidden_states=a,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            routing=MoERoutingParams(
+                expert_map=expert_map,
+                global_redundant_expert_num=0,
+                mc2_mask=None,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            ),
+            quant=MoEQuantParams(quant_type=QuantType.NONE),
+        )
+    )
-    sorted_hidden_states = dispatch_output.hidden_states
-    group_list = dispatch_output.group_list
-    group_list_type = dispatch_output.group_list_type
-    context_metadata = dispatch_output.context_metadata
+    sorted_hidden_states = token_dispatch_output.hidden_states
+    group_list = token_dispatch_output.group_list
+    group_list_type = token_dispatch_output.group_list_type
+    combine_metadata = token_dispatch_output.combine_metadata
-    expert_output = apply_mlp(hidden_states=sorted_hidden_states,
-                              w1=w1_local,
-                              w2=w2_local,
-                              group_list=group_list,
-                              group_list_type=group_list_type)
+    expert_output = apply_mlp(
+        hidden_states=sorted_hidden_states,
+        w1=w1_local,
+        w2=w2_local,
+        group_list=group_list,
+        group_list_type=group_list_type,
+    )
     combined_output = dispatcher.token_combine(
-        hidden_states=expert_output,
-        context_metadata=context_metadata,
-        bias=None)
+        hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
+    )
-    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
-                             expert_map)
+    torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map)
-    torch.testing.assert_close(combined_output.routed_out,
-                               torch_output,
-                               atol=4e-2,
-                               rtol=1)
+    torch.testing.assert_close(combined_output, torch_output, atol=4e-2, rtol=1)
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
@@ -184,8 +193,7 @@ def test_token_dispatcher_with_all_gather_quant(
 ):
     context_mock = MagicMock()
     context_mock.fused_moe_state = 0
-    with patch("vllm_ascend.ascend_forward_context.get_forward_context",
-               return_value=context_mock):
+    with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context", return_value=context_mock):
         a = torch.randn((m, k), device=device, dtype=dtype) / 10
         w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
         w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
@@ -208,34 +216,44 @@ def test_token_dispatcher_with_all_gather_quant(
         dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
         apply_router_weight_on_input = False
-        dispatch_output = dispatcher.token_dispatch(
-            hidden_states=a,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=True)
+        token_dispatch_output = dispatcher.token_dispatch(
+            token_dispatch_input=MoETokenDispatchInput(
+                hidden_states=a,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                routing=MoERoutingParams(
+                    expert_map=expert_map,
+                    global_redundant_expert_num=0,
+                    mc2_mask=None,
+                    apply_router_weight_on_input=apply_router_weight_on_input,
+                ),
+                quant=MoEQuantParams(quant_type=QuantType.W8A8),
+            )
+        )
-        sorted_hidden_states = dispatch_output.hidden_states
-        group_list = dispatch_output.group_list
-        group_list_type = dispatch_output.group_list_type
-        dynamic_scale = dispatch_output.dynamic_scale
-        context_metadata = dispatch_output.context_metadata
+        combine_metadata = token_dispatch_output.combine_metadata
-        expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
-                                          w1=w1,
-                                          w1_scale=w1_scale,
-                                          w2=w2,
-                                          w2_scale=w2_scale,
-                                          group_list=group_list,
-                                          group_list_type=group_list_type,
-                                          dynamic_scale=dynamic_scale,
-                                          with_quant=True)
+        mlp_compute_input = build_mlp_compute_input(
+            fused_experts_input=build_fused_experts_input(
+                hidden_states=a,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                w1=w1,
+                w2=w2,
+                quant_type=QuantType.W8A8,
+                dynamic_eplb=False,
+                expert_map=expert_map,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            ),
+            token_dispatch_output=token_dispatch_output,
+            use_fusion_ops=False,
+        )
+        expert_output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
         combined_output = dispatcher.token_combine(
-            hidden_states=expert_output,
-            context_metadata=context_metadata,
-            bias=None)
-        assert combined_output.routed_out.shape == (m, k)
+            hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
+        )
+        assert combined_output.shape == (m, k)
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
@@ -271,25 +289,20 @@ def test_select_experts(
     hidden_states = torch.randn(m, n, device=device, dtype=dtype)
     router_logits = torch.randn(m, e, device=device, dtype=dtype)
-    e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
-                               if with_e_correction else None)
+    e_score_correction_bias = torch.randn(e, device=device, dtype=dtype) if with_e_correction else None
     custom_routing_function = None
     if custom_routing:
         custom_routing_function = MagicMock()
         mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
-        mock_ids = torch.randint(0,
-                                 e, (m, topk),
-                                 device=device,
-                                 dtype=torch.int32)
+        mock_ids = torch.randint(0, e, (m, topk), device=device, dtype=torch.int32)
         custom_routing_function.return_value = (mock_weights, mock_ids)
-    with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
-               ) as mock_native_grouped_topk, \
-            patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
-                  return_value=MagicMock()):
-        mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
-            x)
+    with (
+        patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk") as mock_native_grouped_topk,
+        patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
+    ):
+        mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(x)
         topk_weights, topk_ids = select_experts(
             hidden_states=hidden_states,
@@ -305,8 +318,8 @@ def test_select_experts(
         )
         call_moe_gatingtopk = check_npu_moe_gating_top_k(
-            hidden_states, topk, renormalize, topk_group, num_expert_group,
-            scoring_func, custom_routing_function)
+            hidden_states, topk, renormalize, topk_group, num_expert_group, scoring_func, custom_routing_function
+        )
         if not call_moe_gatingtopk and use_grouped_topk:
             mock_native_grouped_topk.assert_called_once()
         else:
@@ -323,16 +336,18 @@ def test_select_experts(
 @pytest.mark.parametrize("device", DEVICE)
 def test_select_experts_invalid_scoring_func(device: str):
-    with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
-               return_value=MagicMock()), \
-        pytest.raises(ValueError,
-                      match="Unsupported scoring function: invalid"):
-        select_experts(hidden_states=torch.randn(1, 128, device=device),
-                       router_logits=torch.randn(1, 8, device=device),
-                       top_k=2,
-                       use_grouped_topk=False,
-                       renormalize=False,
-                       scoring_func="invalid")
+    with (
+        patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
+        pytest.raises(ValueError, match="Unsupported scoring function: invalid"),
+    ):
+        select_experts(
+            hidden_states=torch.randn(1, 128, device=device),
+            router_logits=torch.randn(1, 8, device=device),
+            top_k=2,
+            use_grouped_topk=False,
+            renormalize=False,
+            scoring_func="invalid",
+        )
     gc.collect()
     torch.npu.empty_cache()
     torch.npu.reset_peak_memory_stats()
```