Speed up when having padding tokens in DeepEP (#6175)
@@ -31,7 +31,6 @@ if _is_cuda:
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -99,6 +98,7 @@ def grouped_topk(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -138,7 +138,9 @@ def grouped_topk(
         )
         topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def biased_grouped_topk_impl(
@@ -151,6 +153,7 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,13 +200,25 @@ def biased_grouped_topk_impl(
         )
         topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def is_power_of_two(n):
     return n > 0 and math.log2(n).is_integer()


+def _mask_topk_ids_padded_region(
+    topk_ids: torch.Tensor,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+):
+    if num_token_non_padded is None:
+        return
+    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+    topk_ids[indices >= num_token_non_padded, :] = -1
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
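Note: the new _mask_topk_ids_padded_region above is the core of this change. For every padded row (row index >= num_token_non_padded) it overwrites the selected expert ids with -1, which, per the PR title, is what lets the DeepEP MoE path skip work for padding tokens. A minimal, self-contained sketch of that behaviour on toy data (the helper name and shapes here are illustrative, not the sglang API):

import torch

def mask_padded_topk_ids(topk_ids: torch.Tensor, num_token_non_padded: torch.Tensor) -> None:
    # Rows at index >= num_token_non_padded are padding; point them at expert -1.
    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1

topk_ids = torch.randint(0, 8, (6, 2))           # 6 token rows, top-2 expert ids each
mask_padded_topk_ids(topk_ids, torch.tensor(4))  # only the first 4 rows are real tokens
print(topk_ids[4:])                              # the two padded rows are now all -1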
@@ -215,6 +230,7 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     assert (
         routed_scaling_factor is not None
@@ -226,7 +242,7 @@ def biased_grouped_topk(
         <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
-        return moe_fused_gate(
+        topk_weights, topk_ids = moe_fused_gate(
             gating_output,
             correction_bias,
             num_expert_group,
@@ -235,6 +251,11 @@ def biased_grouped_topk(
             n_share_experts_fusion,
             routed_scaling_factor,
         )
+        # TODO will fuse this into kernel, thus use slow manual operation now
+        torch.compile(
+            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+        )(topk_ids, num_token_non_padded)
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
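In the moe_fused_gate branch above, the masking is applied through torch.compile(..., dynamic=True, backend=get_compiler_backend()); the TODO notes this is a stop-gap until the masking is fused into the kernel itself. A rough sketch of that wrapping pattern, assuming the default backend and toy inputs:

import torch

def _mask_topk_ids_padded_region(topk_ids, num_token_non_padded):
    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
    topk_ids[indices >= num_token_non_padded, :] = -1

# dynamic=True asks for compiled code that tolerates varying row counts, so
# calls with different padded batch sizes can reuse the same compiled artifact.
compiled_mask = torch.compile(_mask_topk_ids_padded_region, dynamic=True)
compiled_mask(torch.randint(0, 8, (6, 2)), torch.tensor(4))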
@@ -253,6 +274,7 @@ def biased_grouped_topk(
         topk_group,
         n_share_experts_fusion=n_share_experts_fusion,
         routed_scaling_factor=routed_scaling_factor,
+        num_token_non_padded=num_token_non_padded,
     )

@@ -268,6 +290,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeepSeek V2/V3/R1 series models use grouped_top_k
@@ -284,6 +307,7 @@ def select_experts(
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
                 routed_scaling_factor=routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +320,12 @@ def select_experts(
                 topk_group=topk_group,
                 n_share_experts_fusion=n_share_experts_fusion,
                 routed_scaling_factor=routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
             )
     elif torch_native and custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk_native"
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -305,6 +333,9 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk"
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -312,6 +343,9 @@ def select_experts(
             renormalize=renormalize,
         )
     else:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in custom_routing_function"
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
@@ -240,6 +240,7 @@ class CudaGraphRunner:
         self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
+        self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)

         # pipeline parallelism
         if self.pp_size > 1:
@@ -403,6 +404,7 @@ class CudaGraphRunner:
         else:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
+        self.num_token_non_padded[...] = num_tokens

         # pipeline parallelism
         if self.pp_size > 1:
@@ -461,6 +463,7 @@ class CudaGraphRunner:
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
+            num_token_non_padded=self.num_token_non_padded,
         )

         if lora_paths is not None:
@@ -556,6 +559,7 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
+        self.num_token_non_padded[...] = len(forward_batch.input_ids)
         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
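The CUDA graph runner keeps num_token_non_padded as a preallocated (1,)-shaped int32 tensor and refreshes its value in place ([...] = ...) at capture and replay time, which matches the usual CUDA-graph requirement that captured kernels keep reading from fixed storage. A simplified sketch of that buffer-update pattern, with names shortened and no actual graph capture:

import torch

# One static buffer, allocated once before graph capture.
num_token_non_padded = torch.zeros((1,), dtype=torch.int32)

def prepare_for_replay(actual_num_tokens: int) -> None:
    # In-place write: the storage address never changes, so kernels recorded
    # in the captured graph see the fresh value on every replay.
    num_token_non_padded[...] = actual_num_tokens

prepare_for_replay(37)
print(num_token_non_padded)  # tensor([37], dtype=torch.int32)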
@@ -247,6 +247,7 @@ class ForwardBatch:

     # For padding
     padded_static_len: int = -1  # -1 if not padded
+    num_token_non_padded: Optional[torch.Tensor] = None  # scalar tensor

     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
@@ -290,6 +291,9 @@ class ForwardBatch:
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
+            num_token_non_padded=torch.tensor(
+                len(batch.input_ids), dtype=torch.int32
+            ).to(device, non_blocking=True),
         )

         # For DP attention
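For batches built outside the CUDA graph path, the ForwardBatch construction above materialises the count as a device tensor rather than a plain Python int, so the same masking code can consume it either way. A tiny sketch of that line in isolation; the device string and token list are placeholders:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = [101, 7592, 2088, 102]  # example token ids
num_token_non_padded = torch.tensor(len(input_ids), dtype=torch.int32).to(device, non_blocking=True)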
@@ -165,7 +165,7 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
+    def forward(self, x, forward_batch: Optional[ForwardBatch] = None):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -287,12 +287,12 @@ class DeepseekV2MoE(nn.Module):
         )

     def forward(
-        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
         if not global_server_args_dict["enable_deepep_moe"]:
             return self.forward_normal(hidden_states)
         else:
-            return self.forward_deepep(hidden_states, forward_mode)
+            return self.forward_deepep(hidden_states, forward_batch)

     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
@@ -309,8 +309,9 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states

     def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         shared_output = None
         if (
             forward_mode is not None
@@ -330,6 +331,7 @@ class DeepseekV2MoE(nn.Module):
                 num_expert_group=self.num_expert_group,
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
             )
         else:
             topk_idx = torch.full(
@@ -1339,7 +1341,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             and (not self.info.is_sparse)
             and hidden_states.shape[0] == 0
         ):
-            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+            hidden_states = self.mlp(hidden_states, forward_batch)

         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
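The DeepseekV2 hunks change the MLP/MoE forward signatures from taking a ForwardMode to taking the whole ForwardBatch, since forward_deepep now needs forward_batch.num_token_non_padded in addition to the mode (which it re-derives via forward_batch.forward_mode). A simplified sketch of the resulting call shape, using toy stand-in classes rather than the real sglang types:

from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ToyForwardBatch:
    forward_mode: str                                   # stand-in for sglang's ForwardMode
    num_token_non_padded: Optional[torch.Tensor] = None

class ToyMoE:
    def forward(self, hidden_states: torch.Tensor, forward_batch: ToyForwardBatch) -> torch.Tensor:
        return self.forward_deepep(hidden_states, forward_batch)

    def forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ToyForwardBatch) -> torch.Tensor:
        forward_mode = forward_batch.forward_mode                  # previously the only extra argument
        num_token_non_padded = forward_batch.num_token_non_padded  # now also available for top-k masking
        _ = (forward_mode, num_token_non_padded)
        return hidden_states

out = ToyMoE().forward(torch.randn(4, 8), ToyForwardBatch("decode", torch.tensor(3)))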