[Refactor] Rename n_share_experts_fusion to num_fused_shared_experts (#6735)
This commit is contained in:
@@ -103,7 +103,7 @@ def grouped_topk(
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
n_share_experts_fusion: int = 0,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
@@ -128,10 +128,10 @@ def grouped_topk(
|
||||
) # [n, e]
|
||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
if n_share_experts_fusion:
|
||||
if num_fused_shared_experts:
|
||||
topk_ids[:, -1] = torch.randint(
|
||||
low=num_experts,
|
||||
high=num_experts + n_share_experts_fusion,
|
||||
high=num_experts + num_fused_shared_experts,
|
||||
size=(topk_ids.size(0),),
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device,
|
||||
@@ -141,7 +141,7 @@ def grouped_topk(
|
||||
if renormalize:
|
||||
topk_weights_sum = (
|
||||
topk_weights.sum(dim=-1, keepdim=True)
|
||||
if n_share_experts_fusion == 0
|
||||
if num_fused_shared_experts == 0
|
||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
@@ -160,7 +160,7 @@ def biased_grouped_topk_impl(
|
||||
renormalize: bool,
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
n_share_experts_fusion: int = 0,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
@@ -192,10 +192,10 @@ def biased_grouped_topk_impl(
|
||||
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||
topk_weights = scores.gather(1, topk_ids)
|
||||
|
||||
if n_share_experts_fusion:
|
||||
if num_fused_shared_experts:
|
||||
topk_ids[:, -1] = torch.randint(
|
||||
low=num_experts,
|
||||
high=num_experts + n_share_experts_fusion,
|
||||
high=num_experts + num_fused_shared_experts,
|
||||
size=(topk_ids.size(0),),
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device,
|
||||
@@ -205,7 +205,7 @@ def biased_grouped_topk_impl(
|
||||
if renormalize:
|
||||
topk_weights_sum = (
|
||||
topk_weights.sum(dim=-1, keepdim=True)
|
||||
if n_share_experts_fusion == 0
|
||||
if num_fused_shared_experts == 0
|
||||
else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
|
||||
)
|
||||
topk_weights = topk_weights / topk_weights_sum
|
||||
@@ -239,7 +239,7 @@ def biased_grouped_topk(
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
compiled: bool = True,
|
||||
n_share_experts_fusion: int = 0,
|
||||
num_fused_shared_experts: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
@@ -247,7 +247,7 @@ def biased_grouped_topk(
|
||||
assert (
|
||||
routed_scaling_factor is not None
|
||||
), "routed_scaling_factor is required for biased_grouped_topk"
|
||||
# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
|
||||
# TODO: the moe_fused_gate kernel does not yet support num_fused_shared_experts > 0.
|
||||
if (
|
||||
_is_cuda
|
||||
and gating_output.shape[1] // num_expert_group
|
||||
@@ -260,7 +260,7 @@ def biased_grouped_topk(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
num_fused_shared_experts,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
# TODO merge into kernel for this branch
|
||||
@@ -288,7 +288,7 @@ def biased_grouped_topk(
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
@@ -310,7 +310,7 @@ def select_experts(
|
||||
num_token_non_padded: Optional[torch.Tensor] = None,
|
||||
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
|
||||
):
|
||||
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
num_fused_shared_experts = global_server_args_dict["num_fused_shared_experts"]
|
||||
|
||||
router_logits, correction_bias = (
|
||||
expert_location_dispatch.transform_select_experts_inputs(
|
||||
@@ -332,7 +332,7 @@ def select_experts(
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
@@ -346,7 +346,7 @@ def select_experts(
|
||||
renormalize=renormalize,
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
num_token_non_padded=num_token_non_padded,
|
||||
expert_location_dispatch_info=expert_location_dispatch_info,
|
||||
|
||||
Reference in New Issue
Block a user