[Refactor] Rename n_share_experts_fusion as num_fused_shared_experts (#6735)

This commit is contained in:
Cheng Wan
2025-06-03 17:48:24 -07:00
committed by GitHub
parent b6d0ce9f78
commit 8a5480528d
14 changed files with 82 additions and 93 deletions

View File

@@ -103,7 +103,7 @@ def grouped_topk(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -128,10 +128,10 @@ def grouped_topk(
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if n_share_experts_fusion:
if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
@@ -141,7 +141,7 @@ def grouped_topk(
if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
@@ -160,7 +160,7 @@ def biased_grouped_topk_impl(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
n_share_experts_fusion: int = 0,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -192,10 +192,10 @@ def biased_grouped_topk_impl(
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_ids)
if n_share_experts_fusion:
if num_fused_shared_experts:
topk_ids[:, -1] = torch.randint(
low=num_experts,
high=num_experts + n_share_experts_fusion,
high=num_experts + num_fused_shared_experts,
size=(topk_ids.size(0),),
dtype=topk_ids.dtype,
device=topk_ids.device,
@@ -205,7 +205,7 @@ def biased_grouped_topk_impl(
if renormalize:
topk_weights_sum = (
topk_weights.sum(dim=-1, keepdim=True)
if n_share_experts_fusion == 0
if num_fused_shared_experts == 0
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
)
topk_weights = topk_weights / topk_weights_sum
@@ -239,7 +239,7 @@ def biased_grouped_topk(
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
n_share_experts_fusion: int = 0,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
@@ -247,7 +247,7 @@ def biased_grouped_topk(
assert (
routed_scaling_factor is not None
), "routed_scaling_factor is required for biased_grouped_topk"
# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
# TODO: moe_fused_gate kernel is not supported for num_fused_shared_experts > 0 now.
if (
_is_cuda
and gating_output.shape[1] // num_expert_group
@@ -260,7 +260,7 @@ def biased_grouped_topk(
num_expert_group,
topk_group,
topk,
n_share_experts_fusion,
num_fused_shared_experts,
routed_scaling_factor,
)
# TODO merge into kernel for this branch
@@ -288,7 +288,7 @@ def biased_grouped_topk(
renormalize,
num_expert_group,
topk_group,
n_share_experts_fusion=n_share_experts_fusion,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
@@ -310,7 +310,7 @@ def select_experts(
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
@@ -332,7 +332,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
@@ -346,7 +346,7 @@ def select_experts(
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
n_share_experts_fusion=n_share_experts_fusion,
num_fused_shared_experts=num_fused_shared_experts,
routed_scaling_factor=routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,