rm router logits Improve TTOP 3ms (#1407)
### What this PR does / why we need it?
The previous code is
router_logits, _ = self.gate(hidden_states)
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
I want to change the two all_gathers to one, reduce one all_gather
communication, and make it
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits, _ = self.gate(hidden_states)
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh
gsm8k accuracy verification
<img width="1809" alt="截屏2025-06-24 21 53 24"
src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b"
/>
- vLLM version: v0.9.2
- vLLM main:
77f77a951e
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -367,6 +367,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
self.rm_router_logits = self.experts.rm_router_logits
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -389,7 +390,9 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits = None
|
||||
if not self.rm_router_logits:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
experts_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -398,6 +401,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=self.shared_experts,
|
||||
gate=self.gate,
|
||||
replace_allreduce=replace_allreduce)
|
||||
|
||||
hidden_states = (
|
||||
|
||||
Reference in New Issue
Block a user