simplify the control logic for using shared experts fusion (#5504)
This commit is contained in:
@@ -98,6 +98,7 @@ def grouped_topk(
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -127,9 +128,7 @@ def grouped_topk(
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
topk_weights[:, -1] = (
|
||||
topk_weights[:, :-1].sum(dim=-1) / 2.5
|
||||
) # 2.5 is the routed_scaling_factor.
|
||||
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
|
||||
|
||||
if renormalize:
|
||||
topk_weights_sum = (
|
||||
@@ -151,6 +150,7 @@ def biased_grouped_topk_impl(
|
||||
num_expert_group: int = 0,
|
||||
topk_group: int = 0,
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
):
|
||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||
|
||||
@@ -187,9 +187,7 @@ def biased_grouped_topk_impl(
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device,
|
||||
)
|
||||
topk_weights[:, -1] = (
|
||||
topk_weights[:, :-1].sum(dim=-1) / 2.5
|
||||
) # 2.5 is the routed_scaling_factor.
|
||||
topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
|
||||
|
||||
if renormalize:
|
||||
topk_weights_sum = (
|
||||
@@ -216,13 +214,16 @@ def biased_grouped_topk(
|
||||
topk_group: int = 0,
|
||||
compiled: bool = True,
|
||||
n_share_experts_fusion: int = 0,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
):
|
||||
assert (
|
||||
routed_scaling_factor is not None
|
||||
), "routed_scaling_factor is required for biased_grouped_topk"
|
||||
# TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
|
||||
if (
|
||||
_is_cuda
|
||||
and gating_output.shape[1] // num_expert_group
|
||||
<= 32 # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
|
||||
and n_share_experts_fusion == 0
|
||||
and is_power_of_two(correction_bias.shape[0])
|
||||
):
|
||||
return moe_fused_gate(
|
||||
@@ -231,6 +232,8 @@ def biased_grouped_topk(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
topk,
|
||||
n_share_experts_fusion,
|
||||
routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
biased_grouped_topk_fn = (
|
||||
@@ -249,6 +252,7 @@ def biased_grouped_topk(
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
|
||||
@@ -263,10 +267,9 @@ def select_experts(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
torch_native: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
):
|
||||
n_share_experts_fusion = 0
|
||||
if global_server_args_dict["n_share_experts_fusion"] is not None:
|
||||
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||
# DeekSeek V2/V3/R1 serices models uses grouped_top_k
|
||||
if use_grouped_topk:
|
||||
assert topk_group is not None
|
||||
@@ -280,6 +283,7 @@ def select_experts(
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = biased_grouped_topk(
|
||||
@@ -291,6 +295,7 @@ def select_experts(
|
||||
num_expert_group=num_expert_group,
|
||||
topk_group=topk_group,
|
||||
n_share_experts_fusion=n_share_experts_fusion,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
elif torch_native and custom_routing_function is None:
|
||||
topk_weights, topk_ids = fused_topk_native(
|
||||
|
||||
Reference in New Issue
Block a user