rm router logits: improve TPOT by 3ms (#1407)

### What this PR does / why we need it?

The previous code was:

```python
router_logits, _ = self.gate(hidden_states)
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
```

This PR merges the two all_gather calls into one, saving one all_gather communication, by moving the gate computation after the gather:

```python
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits, _ = self.gate(hidden_states)
```
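
Why the reorder is safe: the gate is a per-token linear projection, so applying it to the gathered hidden states produces exactly the logits that the second all_gather used to return. A standalone toy check (my own sketch, not the vLLM code, using `torch.cat` to stand in for `all_gather` across two simulated DP ranks):

```python
import torch

torch.manual_seed(0)
gate = torch.nn.Linear(16, 4, bias=False)  # hidden_size=16, n_experts=4
rank0 = torch.randn(3, 16)                 # 3 tokens on DP rank 0
rank1 = torch.randn(3, 16)                 # 3 tokens on DP rank 1

# Old order: gate on each rank, then gather the logits.
logits_old = torch.cat([gate(rank0), gate(rank1)], dim=0)
# New order: gather the hidden states once, then gate.
logits_new = gate(torch.cat([rank0, rank1], dim=0))

assert torch.allclose(logits_old, logits_new, atol=1e-6)
```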

### Does this PR introduce _any_ user-facing change?

No. This is an internal communication optimization; model outputs are unchanged.

### How was this patch tested?
```bash
bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh
```

gsm8k accuracy verification:
<img width="1809" alt="Screenshot 2025-06-24 21 53 24" src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b" />



- vLLM version: v0.9.2
- vLLM main: 77f77a951e

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
Author: ttanzhiqiang
Date: 2025-07-11 08:53:17 +08:00
Committed by: GitHub
Parent: 0fc9b56d40
Commit: 9d16c9982e
3 changed files with 43 additions and 9 deletions


```diff
@@ -367,6 +367,7 @@ class CustomDeepseekV2MoE(nn.Module):
         self.ep_group = get_ep_group()
         self.params_dtype = torch.get_default_dtype()
+        self.rm_router_logits = self.experts.rm_router_logits
 
     def forward(self,
                 hidden_states: torch.Tensor,
@@ -389,7 +390,9 @@ class CustomDeepseekV2MoE(nn.Module):
                 is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
 
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = None
+        if not self.rm_router_logits:
+            router_logits, _ = self.gate(hidden_states)
 
         experts_hidden_states = self.experts(
             hidden_states=hidden_states,
@@ -398,6 +401,7 @@ class CustomDeepseekV2MoE(nn.Module):
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
             shared_experts=self.shared_experts,
+            gate=self.gate,
             replace_allreduce=replace_allreduce)
 
         hidden_states = (
```
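
For context, here is a minimal sketch (hypothetical helper names, not the actual vllm-ascend implementation) of the pattern the new `gate=self.gate` argument enables: when `rm_router_logits` is set, the experts layer computes the router logits itself after the data-parallel gather, so only the hidden states cross the network:

```python
def experts_forward_with_deferred_gate(hidden_states, gate, all_gather,
                                       rm_router_logits=True):
    # Sketch only: `all_gather` stands in for get_dp_group().all_gather;
    # `gate` returns a (logits, bias) tuple like vLLM's ReplicatedLinear.
    hidden_states = all_gather(hidden_states, 0)  # one collective, not two
    router_logits = None
    if rm_router_logits:
        # Logits come from the already-gathered tokens, so the second
        # all_gather on router_logits is no longer needed.
        router_logits, _ = gate(hidden_states)
    return hidden_states, router_logits
```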