Speed up when having padding tokens in DeepEP (#6175)

fzyzcjy
2025-05-18 07:44:05 +08:00
committed by GitHub
parent e3bed74afb
commit 2716830802
4 changed files with 53 additions and 9 deletions


@@ -31,7 +31,6 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
 
-
 expert_distribution_recorder = ExpertDistributionRecorder()
@@ -99,6 +98,7 @@ def grouped_topk(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -138,7 +138,9 @@ def grouped_topk(
         )
         topk_weights = topk_weights / topk_weights_sum
 
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids
 
 
 def biased_grouped_topk_impl(
@@ -151,6 +153,7 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
@@ -197,13 +200,25 @@ def biased_grouped_topk_impl(
         )
         topk_weights = topk_weights / topk_weights_sum
 
-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids
 
 
 def is_power_of_two(n):
     return n > 0 and math.log2(n).is_integer()
 
 
+def _mask_topk_ids_padded_region(
+    topk_ids: torch.Tensor,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+):
+    if num_token_non_padded is None:
+        return
+    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+    topk_ids[indices >= num_token_non_padded, :] = -1
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
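
For context, the new _mask_topk_ids_padded_region helper simply invalidates the routing results of padding rows: every row of topk_ids at index >= num_token_non_padded is set to -1, which the expert dispatch treats as "no expert selected" and can skip. A minimal standalone sketch of that behavior with illustrative shapes (not the library code):

    import torch

    topk_ids = torch.tensor([[0, 3], [1, 2], [4, 5], [6, 7]])  # 4 tokens, top-2 experts
    num_token_non_padded = torch.tensor(3)  # only the first 3 rows are real tokens

    # Same operation as _mask_topk_ids_padded_region in the diff above.
    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1

    # topk_ids is now [[0, 3], [1, 2], [4, 5], [-1, -1]]; row 3 (padding) is masked.
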
@@ -215,6 +230,7 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert (
         routed_scaling_factor is not None
@@ -226,7 +242,7 @@ def biased_grouped_topk(
         <= 32  # The moe_fused_gate kernel ensures that num_experts/num_expert_group does not exceed MAX_VPT=32 for now. When the kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
-        return moe_fused_gate(
+        topk_weights, topk_ids = moe_fused_gate(
             gating_output,
             correction_bias,
             num_expert_group,
@@ -235,6 +251,11 @@ def biased_grouped_topk(
             n_share_experts_fusion,
             routed_scaling_factor,
         )
+        # TODO: fuse this into the kernel; use the slow manual masking for now.
+        torch.compile(
+            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+        )(topk_ids, num_token_non_padded)
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
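
Since the fast moe_fused_gate path returns straight from the kernel, which does not yet know about padding, the commit applies the masking helper afterwards and wraps it in torch.compile with dynamic=True so that varying batch sizes reuse one compiled artifact. A sketch of the same pattern, using torch.compile's default backend as a stand-in for sglang's internal get_compiler_backend() (an assumption for this sketch); calling torch.compile at the call site like this is cheap after the first invocation because Dynamo caches compiled code on the function's code object:

    from typing import Optional

    import torch

    def _mask_topk_ids_padded_region(
        topk_ids: torch.Tensor,
        num_token_non_padded: Optional[torch.Tensor] = None,
    ):
        if num_token_non_padded is None:
            return
        indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
        topk_ids[indices >= num_token_non_padded, :] = -1

    # dynamic=True compiles with symbolic shapes, avoiding a recompile per batch size.
    masked = torch.compile(_mask_topk_ids_padded_region, dynamic=True)
    masked(torch.zeros(8, 2, dtype=torch.int32), torch.tensor(5))
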
@@ -253,6 +274,7 @@ def biased_grouped_topk(
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
         )
@@ -268,6 +290,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeepSeek V2/V3/R1 series models use grouped_top_k
@@ -284,6 +307,7 @@ def select_experts(
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +320,12 @@ def select_experts(
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
         )
     elif torch_native and custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk_native"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -305,6 +333,9 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk"
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -312,6 +343,9 @@ def select_experts(
             renormalize=renormalize,
         )
     else:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in custom_routing_function"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
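
Taken together, only the grouped_topk and biased_grouped_topk paths used by DeepSeek-style models honor the new argument; the remaining paths assert it is unset. A hypothetical call-site sketch follows (tensor sizes are illustrative, and keyword names beyond those visible in the diff are assumed from the surrounding select_experts signature): the batch is padded, e.g. to a CUDA-graph-friendly length, and the real token count is passed as a device tensor, presumably so its value can be updated in place between graph replays, after which every row past it comes back with topk_ids == -1:

    num_token_real, padded_len = 100, 128  # batch padded, e.g. for CUDA graph capture

    hidden_states = torch.randn(padded_len, 7168, device="cuda")
    router_logits = torch.randn(padded_len, 256, device="cuda")

    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=8,
        use_grouped_topk=True,
        renormalize=True,
        num_expert_group=8,
        topk_group=4,
        num_token_non_padded=torch.tensor(num_token_real, device="cuda"),
    )
    # topk_ids[num_token_real:] is all -1, so the DeepEP dispatch skips padding tokens.
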