rm router logits all_gather: improve TPOT by 3ms (#1407)

### What this PR does / why we need it?

The previous code was:

```python
router_logits, _ = self.gate(hidden_states)
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
```

This PR collapses the two all_gathers into one, saving one all_gather communication, by gathering hidden_states first and computing the gate on the gathered tensor:

```python
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits, _ = self.gate(hidden_states)
```
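
For intuition, here is a minimal sketch of why the reorder is safe: the gate is a row-wise linear map, so gathering before or after it yields the same logits. All names, shapes, and the fake `all_gather` below are illustrative stand-ins, not the real vLLM-Ascend objects:

```python
import torch
import torch.nn as nn

dp_size = 4
num_tokens, hidden_dim, num_experts = 16, 1024, 256
gate = nn.Linear(hidden_dim, num_experts, bias=False)

def all_gather(t: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Stand-in for get_dp_group().all_gather: every rank contributes a
    # shard; here a single process fakes it by replicating one shard.
    # (The real self.gate also returns a (logits, aux) tuple, not a tensor.)
    return torch.cat([t] * dp_size, dim=dim)

hidden_states = torch.randn(num_tokens, hidden_dim)

# Before: gate first, then gather both tensors (two collectives).
logits_before = all_gather(gate(hidden_states), 0)

# After: gather hidden_states once, recompute the cheap gate locally.
logits_after = gate(all_gather(hidden_states, 0))

# Row-wise linearity means the results match; one all_gather is saved.
torch.testing.assert_close(logits_before, logits_after)
```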

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```bash
bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh
```

GSM8K accuracy verification:
<img width="1809" alt="Screenshot 2025-06-24 21 53 24"
src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b"
/>



- vLLM version: v0.9.2
- vLLM main: 77f77a951e

---------

Signed-off-by: ttanzhiqiang <389825161@qq.com>
Author: ttanzhiqiang
Date: 2025-07-11 08:53:17 +08:00
Committed by: GitHub
Parent: 0fc9b56d40
Commit: 9d16c9982e
3 changed files with 43 additions and 9 deletions


```diff
@@ -45,7 +45,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
                                get_all_reduce_merge_state, get_fused_moe_state,
-                               is_310p, npu_stream_switch, npu_wait_tensor)
+                               get_rm_router_logits_state, is_310p,
+                               npu_stream_switch, npu_wait_tensor)

 MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
```
```diff
@@ -1148,6 +1149,8 @@ class AscendFusedMoE(FusedMoE):
         self.global_redundant_expert_num = 0

         is_deepseek_v3_r1 = self.global_num_experts == 256
+        self.rm_router_logits = get_rm_router_logits_state(
+            self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
         self.all_reduce_merge = get_all_reduce_merge_state(
             self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
```
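
The diff threads `ep_size`, `dp_size`, and a DeepSeek-V3/R1 flag into `get_rm_router_logits_state`. The real helper lives in `vllm_ascend.utils` and is not shown in this diff; a plausible sketch of its shape, with the gating condition purely assumed, would be:

```python
def get_rm_router_logits_state(ep_size: int, dp_size: int,
                               is_deepseek_v3_r1: bool) -> bool:
    # Assumed logic, not the actual implementation: eliding the
    # router_logits all_gather only matters when data parallelism
    # performs a gather at all, so gate on dp_size > 1. The other
    # arguments presumably restrict the optimization to configurations
    # where the recompute is known profitable (e.g. DeepSeek V3/R1
    # with expert parallelism).
    return dp_size > 1 and (ep_size > 1 or is_deepseek_v3_r1)
```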
```diff
@@ -1240,7 +1243,9 @@ class AscendFusedMoE(FusedMoE):
                 enable_force_load_balance: bool = False,
                 top_k: Optional[int] = None,
                 shared_experts: Optional[Any] = None,
+                gate=None,
                 replace_allreduce: bool = False):

         assert self.quant_method is not None

         if top_k:
```
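
The new optional `gate` argument is the hook that lets `forward` recompute router logits after the gather. A toy, self-contained illustration of that contract (the class and names are hypothetical, not the AscendFusedMoE API):

```python
import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    """Toy stand-in showing the caller-supplied `gate` contract."""

    def __init__(self, rm_router_logits: bool):
        super().__init__()
        self.rm_router_logits = rm_router_logits

    def forward(self, hidden_states, router_logits=None, gate=None):
        if self.rm_router_logits:
            # When the router_logits all_gather is elided, the caller
            # must hand down its gating module so logits can be rebuilt
            # from the (already gathered) hidden states.
            assert gate is not None
            router_logits = gate(hidden_states)
        return router_logits

gate = nn.Linear(16, 4, bias=False)
moe = ToyMoE(rm_router_logits=True)
print(moe(torch.randn(3, 16), gate=gate).shape)  # torch.Size([3, 4])
```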
```diff
@@ -1277,6 +1282,7 @@ class AscendFusedMoE(FusedMoE):
             tp_rank = get_tensor_model_parallel_rank()
             hidden_states = chunk_hidden_states[tp_rank]
             router_logits = chunk_router_logits[tp_rank]
+
         if self.dp_size > 1:
             if fused_moe_state == FusedMoEState.AllGather:
                 # NOTE: When in torchair graph, it has been padded in model_runner_v1
```
```diff
@@ -1289,19 +1295,27 @@ class AscendFusedMoE(FusedMoE):
                         hidden_states,
                         (0, 0, 0,
                          max_num_tokens_across_dp - num_tokens))
-                    router_logits = nn.functional.pad(
-                        router_logits,
-                        (0, 0, 0,
-                         max_num_tokens_across_dp - num_tokens))
+                    if not self.rm_router_logits:
+                        router_logits = nn.functional.pad(
+                            router_logits,
+                            (0, 0, 0,
+                             max_num_tokens_across_dp - num_tokens))
                 hidden_states = get_dp_group().all_gather(hidden_states, 0)
-                router_logits = get_dp_group().all_gather(router_logits, 0)
+                if self.rm_router_logits:
+                    router_logits, _ = gate(hidden_states)
+                else:
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
             elif fused_moe_state == FusedMoEState.NaiveMulticast:
                 cu_tokens_across_dp_cpu = get_forward_context(
                 ).dp_metadata.cu_tokens_across_dp_cpu
                 hidden_states = self.naive_multicast(hidden_states,
                                                      cu_tokens_across_dp_cpu)
-                router_logits = self.naive_multicast(router_logits,
-                                                     cu_tokens_across_dp_cpu)
+                if self.rm_router_logits:
+                    router_logits, _ = gate(hidden_states)
+                else:
+                    router_logits = self.naive_multicast(
+                        router_logits, cu_tokens_across_dp_cpu)

         # Matrix multiply.
         e_hidden_states = self.quant_method.apply(
```
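
The AllGather branch pads each rank's tokens up to the DP-wide maximum before the collective, since an all_gather needs equal shapes on every rank. A standalone sketch of that padding call (the sizes below are made up):

```python
import torch
import torch.nn as nn

max_num_tokens_across_dp = 8      # DP-wide maximum (made-up value)
num_tokens, hidden_dim = 5, 16    # this rank's actual token count
hidden_states = torch.randn(num_tokens, hidden_dim)

# For a 2-D tensor, pad takes (left, right, top, bottom), last dim
# first: (0, 0) leaves the hidden dim alone, (0, pad) appends token rows.
padded = nn.functional.pad(
    hidden_states, (0, 0, 0, max_num_tokens_across_dp - num_tokens))
assert padded.shape == (max_num_tokens_across_dp, hidden_dim)
```

With `rm_router_logits` enabled, only `hidden_states` needs this treatment; `router_logits` is derived after the gather, so its padding and its collective both disappear from the critical path.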