Speed up when having padding tokens in DeepEP (#6175)
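This change threads an optional num_token_non_padded tensor through the top-k routing path (grouped_topk, biased_grouped_topk_impl, biased_grouped_topk, select_experts). When DeepEP pads a batch to a fixed length, the expert ids of the padded rows are set to -1 before dispatch, so the MoE kernels skip work for padding tokens.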
@@ -31,7 +31,6 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -99,6 +98,7 @@ def grouped_topk(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -138,7 +138,9 @@ def grouped_topk(
         )
         topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def biased_grouped_topk_impl(
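In the hunk above, the early return of the casted tensors is replaced by a cast-then-mask-then-return sequence: _mask_topk_ids_padded_region marks the padded rows of topk_ids as -1 before grouped_topk returns. The same change is applied to biased_grouped_topk_impl below.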
@@ -151,6 +153,7 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,13 +200,25 @@ def biased_grouped_topk_impl(
         )
         topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def is_power_of_two(n):
     return n > 0 and math.log2(n).is_integer()


+def _mask_topk_ids_padded_region(
+    topk_ids: torch.Tensor,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+):
+    if num_token_non_padded is None:
+        return
+    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+    topk_ids[indices >= num_token_non_padded, :] = -1
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
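A minimal, self-contained sketch of what the new _mask_topk_ids_padded_region helper does, using toy values (the variable names mirror the diff; the tensors are made up for illustration):

import torch

topk_ids = torch.tensor([[1, 3], [0, 2], [5, 7], [4, 6]])  # 4 rows, 2 experts per token
num_token_non_padded = torch.tensor(2)  # only the first 2 rows are real tokens

# Same logic as the helper: mark every expert id in the padded region as -1.
indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
topk_ids[indices >= num_token_non_padded, :] = -1

print(topk_ids)
# tensor([[ 1,  3],
#         [ 0,  2],
#         [-1, -1],
#         [-1, -1]])

Because num_token_non_padded stays a tensor, the comparison runs on the device without forcing a host synchronization.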
@@ -215,6 +230,7 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert (
         routed_scaling_factor is not None
@@ -226,7 +242,7 @@ def biased_grouped_topk(
         <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
-        return moe_fused_gate(
+        topk_weights, topk_ids = moe_fused_gate(
             gating_output,
             correction_bias,
             num_expert_group,
@@ -235,6 +251,11 @@ def biased_grouped_topk(
             n_share_experts_fusion,
             routed_scaling_factor,
         )
+        # TODO will fuse this into kernel, thus use slow manual operation now
+        torch.compile(
+            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+        )(topk_ids, num_token_non_padded)
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
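Per the TODO in the hunk above, the eager masking after the moe_fused_gate fast path is a stopgap until it is fused into the kernel; wrapping it in torch.compile with dynamic=True lets the arange/compare/assignment be compiled together and avoids recompiling when the batch size changes. A rough sketch of the same wrapping pattern, with PyTorch's default backend standing in for sglang's get_compiler_backend():

import torch
from typing import Optional

def _mask_topk_ids_padded_region(
    topk_ids: torch.Tensor,
    num_token_non_padded: Optional[torch.Tensor] = None,
):
    if num_token_non_padded is None:
        return
    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1

# dynamic=True: do not specialize the compiled graph on topk_ids.shape[0].
masked_fn = torch.compile(_mask_topk_ids_padded_region, dynamic=True)
topk_ids = torch.randint(0, 8, (6, 2))
masked_fn(topk_ids, torch.tensor(4))  # rows 4 and 5 are now all -1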
@@ -253,6 +274,7 @@ def biased_grouped_topk(
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
         )

@@ -268,6 +290,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeepSeek V2/V3/R1 series models use grouped_top_k
@@ -284,6 +307,7 @@ def select_experts(
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
                 routed_scaling_factor=routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +320,12 @@ def select_experts(
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
                 routed_scaling_factor=routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
             )
     elif torch_native and custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk_native"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -305,6 +333,9 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk"
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -312,6 +343,9 @@ def select_experts(
             renormalize=renormalize,
         )
     else:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in custom_routing_function"
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
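The remaining routing paths (fused_topk_native, fused_topk, and a user-supplied custom_routing_function) do not implement the masking yet, so the hunks above guard them with assertions that num_token_non_padded is None.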