[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it? Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries. - Prepare, dispatch, MLP, and quant stages now have clearer ownership. - Main MoE path no longer depends on business `kwargs.get(...)` lookups. - Comm and dispatcher interfaces are request-only on the main path. - UTs can assert stage-level fields directly instead of inferring behavior indirectly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -245,6 +245,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model):
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
|
||||
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
||||
@wait_until_npu_memory_free()
|
||||
def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
|
||||
short_example_prompts = [
|
||||
"Hello ",
|
||||
@@ -272,6 +273,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep():
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
|
||||
@patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"})
|
||||
@patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"})
|
||||
@wait_until_npu_memory_free()
|
||||
def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep():
|
||||
short_example_prompts = [
|
||||
"Hello ",
|
||||
|
||||
@@ -28,11 +28,17 @@ import torch
|
||||
import torch_npu
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (
|
||||
check_npu_moe_gating_top_k, select_experts)
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k, select_experts
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import \
|
||||
TokenDispatcherWithAllGather
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||
build_fused_experts_input,
|
||||
build_mlp_compute_input,
|
||||
MoEQuantParams,
|
||||
MoERoutingParams,
|
||||
MoETokenDispatchInput,
|
||||
)
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather
|
||||
from vllm_ascend.quantization.quant_type import QuantType
|
||||
|
||||
NUM_EXPERTS = [8, 64]
|
||||
EP_SIZE = [1]
|
||||
@@ -83,10 +89,8 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map):
|
||||
for i in range(w1.shape[0]):
|
||||
mask = topk_ids == i
|
||||
if mask.sum():
|
||||
out[mask] = SiluAndMul()(
|
||||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) *
|
||||
topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1)
|
||||
return (out.view(B, -1, w2.shape[1]) * topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
|
||||
@@ -129,36 +133,41 @@ def test_token_dispatcher_with_all_gather(
|
||||
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
||||
|
||||
apply_router_weight_on_input = False
|
||||
dispatch_output = dispatcher.token_dispatch(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
token_dispatch_output = dispatcher.token_dispatch(
|
||||
token_dispatch_input=MoETokenDispatchInput(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
routing=MoERoutingParams(
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=0,
|
||||
mc2_mask=None,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
),
|
||||
quant=MoEQuantParams(quant_type=QuantType.NONE),
|
||||
)
|
||||
)
|
||||
|
||||
sorted_hidden_states = dispatch_output.hidden_states
|
||||
group_list = dispatch_output.group_list
|
||||
group_list_type = dispatch_output.group_list_type
|
||||
context_metadata = dispatch_output.context_metadata
|
||||
sorted_hidden_states = token_dispatch_output.hidden_states
|
||||
group_list = token_dispatch_output.group_list
|
||||
group_list_type = token_dispatch_output.group_list_type
|
||||
combine_metadata = token_dispatch_output.combine_metadata
|
||||
|
||||
expert_output = apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1_local,
|
||||
w2=w2_local,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type)
|
||||
expert_output = apply_mlp(
|
||||
hidden_states=sorted_hidden_states,
|
||||
w1=w1_local,
|
||||
w2=w2_local,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
)
|
||||
|
||||
combined_output = dispatcher.token_combine(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
|
||||
)
|
||||
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk,
|
||||
expert_map)
|
||||
torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map)
|
||||
|
||||
torch.testing.assert_close(combined_output.routed_out,
|
||||
torch_output,
|
||||
atol=4e-2,
|
||||
rtol=1)
|
||||
torch.testing.assert_close(combined_output, torch_output, atol=4e-2, rtol=1)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
@@ -184,8 +193,7 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
):
|
||||
context_mock = MagicMock()
|
||||
context_mock.fused_moe_state = 0
|
||||
with patch("vllm_ascend.ascend_forward_context.get_forward_context",
|
||||
return_value=context_mock):
|
||||
with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context", return_value=context_mock):
|
||||
a = torch.randn((m, k), device=device, dtype=dtype) / 10
|
||||
w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8)
|
||||
w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype)
|
||||
@@ -208,34 +216,44 @@ def test_token_dispatcher_with_all_gather_quant(
|
||||
dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs)
|
||||
|
||||
apply_router_weight_on_input = False
|
||||
dispatch_output = dispatcher.token_dispatch(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=True)
|
||||
token_dispatch_output = dispatcher.token_dispatch(
|
||||
token_dispatch_input=MoETokenDispatchInput(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
routing=MoERoutingParams(
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=0,
|
||||
mc2_mask=None,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
),
|
||||
quant=MoEQuantParams(quant_type=QuantType.W8A8),
|
||||
)
|
||||
)
|
||||
|
||||
sorted_hidden_states = dispatch_output.hidden_states
|
||||
group_list = dispatch_output.group_list
|
||||
group_list_type = dispatch_output.group_list_type
|
||||
dynamic_scale = dispatch_output.dynamic_scale
|
||||
context_metadata = dispatch_output.context_metadata
|
||||
combine_metadata = token_dispatch_output.combine_metadata
|
||||
|
||||
expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
dynamic_scale=dynamic_scale,
|
||||
with_quant=True)
|
||||
mlp_compute_input = build_mlp_compute_input(
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=a,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
quant_type=QuantType.W8A8,
|
||||
dynamic_eplb=False,
|
||||
expert_map=expert_map,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
),
|
||||
token_dispatch_output=token_dispatch_output,
|
||||
use_fusion_ops=False,
|
||||
)
|
||||
expert_output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||
combined_output = dispatcher.token_combine(
|
||||
hidden_states=expert_output,
|
||||
context_metadata=context_metadata,
|
||||
bias=None)
|
||||
assert combined_output.routed_out.shape == (m, k)
|
||||
hidden_states=expert_output, combine_metadata=combine_metadata, bias=None
|
||||
)
|
||||
assert combined_output.shape == (m, k)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
@@ -271,25 +289,20 @@ def test_select_experts(
|
||||
hidden_states = torch.randn(m, n, device=device, dtype=dtype)
|
||||
router_logits = torch.randn(m, e, device=device, dtype=dtype)
|
||||
|
||||
e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
|
||||
if with_e_correction else None)
|
||||
e_score_correction_bias = torch.randn(e, device=device, dtype=dtype) if with_e_correction else None
|
||||
|
||||
custom_routing_function = None
|
||||
if custom_routing:
|
||||
custom_routing_function = MagicMock()
|
||||
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
|
||||
mock_ids = torch.randint(0,
|
||||
e, (m, topk),
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
mock_ids = torch.randint(0, e, (m, topk), device=device, dtype=torch.int32)
|
||||
custom_routing_function.return_value = (mock_weights, mock_ids)
|
||||
|
||||
with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk"
|
||||
) as mock_native_grouped_topk, \
|
||||
patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
||||
return_value=MagicMock()):
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
|
||||
x)
|
||||
with (
|
||||
patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk") as mock_native_grouped_topk,
|
||||
patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
|
||||
):
|
||||
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(x)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -305,8 +318,8 @@ def test_select_experts(
|
||||
)
|
||||
|
||||
call_moe_gatingtopk = check_npu_moe_gating_top_k(
|
||||
hidden_states, topk, renormalize, topk_group, num_expert_group,
|
||||
scoring_func, custom_routing_function)
|
||||
hidden_states, topk, renormalize, topk_group, num_expert_group, scoring_func, custom_routing_function
|
||||
)
|
||||
if not call_moe_gatingtopk and use_grouped_topk:
|
||||
mock_native_grouped_topk.assert_called_once()
|
||||
else:
|
||||
@@ -323,16 +336,18 @@ def test_select_experts(
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICE)
|
||||
def test_select_experts_invalid_scoring_func(device: str):
|
||||
with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method',
|
||||
return_value=MagicMock()), \
|
||||
pytest.raises(ValueError,
|
||||
match="Unsupported scoring function: invalid"):
|
||||
select_experts(hidden_states=torch.randn(1, 128, device=device),
|
||||
router_logits=torch.randn(1, 8, device=device),
|
||||
top_k=2,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid")
|
||||
with (
|
||||
patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()),
|
||||
pytest.raises(ValueError, match="Unsupported scoring function: invalid"),
|
||||
):
|
||||
select_experts(
|
||||
hidden_states=torch.randn(1, 128, device=device),
|
||||
router_logits=torch.randn(1, 8, device=device),
|
||||
top_k=2,
|
||||
use_grouped_topk=False,
|
||||
renormalize=False,
|
||||
scoring_func="invalid",
|
||||
)
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
Reference in New Issue
Block a user