[Refactor] Rename n_share_experts_fusion as num_fused_shared_experts (#6735)

Author: Cheng Wan
Date: 2025-06-03 17:48:24 -07:00
Committed by: GitHub
Parent: b6d0ce9f78
Commit: 8a5480528d
14 changed files with 82 additions and 93 deletions


@@ -19,15 +19,15 @@ from sglang.srt.layers.moe.topk import biased_grouped_topk
         (512, 16, 8, 16),
     ],
 )
-@pytest.mark.parametrize("n_share_experts_fusion", [0, 1, 8, 16])
-def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusion):
+@pytest.mark.parametrize("num_fused_shared_experts", [0, 1])
+def test_moe_fused_gate_combined(seq_length, dtype, params, num_fused_shared_experts):
     num_experts, num_expert_group, topk_group, topk = params
     torch.manual_seed(seq_length)
     tensor = torch.rand((seq_length, num_experts)).to(dtype).cuda()
     scores = tensor.clone()
     bias = torch.rand(num_experts).to(dtype).cuda()
-    topk = topk + min(1, n_share_experts_fusion)
+    topk = topk + min(1, num_fused_shared_experts)
     output, indices = moe_fused_gate(
         tensor,
@@ -35,7 +35,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         topk=topk,
-        n_share_experts_fusion=n_share_experts_fusion,
+        num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
     )
     ref_output, ref_indices = biased_grouped_topk(
@@ -47,12 +47,12 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         num_expert_group=num_expert_group,
         topk_group=topk_group,
         compiled=False,
-        n_share_experts_fusion=n_share_experts_fusion,
+        num_fused_shared_experts=num_fused_shared_experts,
         routed_scaling_factor=2.5,
     )
-    # When n_share_experts_fusion > 0, ignore the comparison of the last topk dimension
-    if n_share_experts_fusion > 0:
+    # When num_fused_shared_experts > 0, ignore the comparison of the last topk dimension
+    if num_fused_shared_experts > 0:
         original_indices = indices.clone()
         original_ref_indices = ref_indices.clone()
@@ -60,7 +60,7 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
         ref_indices = ref_indices[:, :-1]
         valid_min = num_experts
-        valid_max = num_experts + n_share_experts_fusion
+        valid_max = num_experts + num_fused_shared_experts
         shared_indices = original_indices[:, -1]
         shared_ref_indices = original_ref_indices[:, -1]
         if shared_indices is not None:
@@ -87,11 +87,11 @@ def test_moe_fused_gate_combined(seq_length, dtype, params, n_share_experts_fusi
     assert idx_check, (
         f"Indices mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
+        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
     )
     assert output_check, (
         f"Output mismatch at seq_length {seq_length}, dtype {dtype}, "
-        f"params {params}, n_share_experts_fusion {n_share_experts_fusion}"
+        f"params {params}, num_fused_shared_experts {num_fused_shared_experts}"
     )
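
The rename itself is mechanical, but the test logic above is easier to follow with the shared-expert bookkeeping spelled out: topk is increased by min(1, num_fused_shared_experts) because fusing shared experts appends one extra routing slot, and that slot's expert id must fall in the range from num_experts up to num_experts + num_fused_shared_experts. Below is a minimal sketch of that range check, assuming a half-open interval; the helper name shared_slot_in_range is hypothetical and not part of this PR.

import torch

def shared_slot_in_range(indices, num_experts, num_fused_shared_experts):
    # Sketch only: when shared experts are fused, the last column of the routing
    # indices is assumed to point at a fused shared expert, whose ids live in the
    # half-open range [num_experts, num_experts + num_fused_shared_experts).
    if num_fused_shared_experts == 0:
        return True  # no extra slot appended, nothing to validate
    shared_slot = indices[:, -1]
    valid_min = num_experts
    valid_max = num_experts + num_fused_shared_experts
    return bool(((shared_slot >= valid_min) & (shared_slot < valid_max)).all())

# Example: 16 routed experts plus one fused shared expert -> the only valid shared id is 16.
indices = torch.tensor([[3, 7, 1, 16], [5, 2, 9, 16]])
assert shared_slot_in_range(indices, num_experts=16, num_fused_shared_experts=1)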