rm router logits Improve TTOP 3ms (#1407)
### What this PR does / why we need it?
The previous code is
router_logits, _ = self.gate(hidden_states)
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
I want to change the two all_gathers to one, reduce one all_gather
communication, and make it
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits, _ = self.gate(hidden_states)
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh
gsm8k accuracy verification
<img width="1809" alt="截屏2025-06-24 21 53 24"
src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b"
/>
- vLLM version: v0.9.2
- vLLM main:
77f77a951e
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -367,6 +367,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
self.rm_router_logits = self.experts.rm_router_logits
|
||||
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -389,7 +390,9 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits = None
|
||||
if not self.rm_router_logits:
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
experts_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -398,6 +401,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=self.shared_experts,
|
||||
gate=self.gate,
|
||||
replace_allreduce=replace_allreduce)
|
||||
|
||||
hidden_states = (
|
||||
|
||||
@@ -45,7 +45,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||
get_all_reduce_merge_state, get_fused_moe_state,
|
||||
is_310p, npu_stream_switch, npu_wait_tensor)
|
||||
get_rm_router_logits_state, is_310p,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
@@ -1148,6 +1149,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.global_redundant_expert_num = 0
|
||||
|
||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
||||
self.rm_router_logits = get_rm_router_logits_state(
|
||||
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
||||
self.all_reduce_merge = get_all_reduce_merge_state(
|
||||
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
||||
|
||||
@@ -1240,7 +1243,9 @@ class AscendFusedMoE(FusedMoE):
|
||||
enable_force_load_balance: bool = False,
|
||||
top_k: Optional[int] = None,
|
||||
shared_experts: Optional[Any] = None,
|
||||
gate=None,
|
||||
replace_allreduce: bool = False):
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
if top_k:
|
||||
@@ -1277,6 +1282,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
|
||||
if self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
@@ -1289,19 +1295,27 @@ class AscendFusedMoE(FusedMoE):
|
||||
hidden_states,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
cu_tokens_across_dp_cpu = get_forward_context(
|
||||
).dp_metadata.cu_tokens_across_dp_cpu
|
||||
hidden_states = self.naive_multicast(hidden_states,
|
||||
cu_tokens_across_dp_cpu)
|
||||
router_logits = self.naive_multicast(router_logits,
|
||||
cu_tokens_across_dp_cpu)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
else:
|
||||
router_logits = self.naive_multicast(
|
||||
router_logits, cu_tokens_across_dp_cpu)
|
||||
|
||||
# Matrix multiply.
|
||||
e_hidden_states = self.quant_method.apply(
|
||||
|
||||
@@ -439,6 +439,22 @@ class FusedMoEState(Enum):
|
||||
NaiveMulticast = 4
|
||||
|
||||
|
||||
# TODO(ttanzhiqiang): rm_router_logits
|
||||
# dp>1 will trigger
|
||||
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
||||
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if dp_size > 1:
|
||||
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
and is_deepseek_v3_r1):
|
||||
return True
|
||||
elif ep_size == 1 and is_deepseek_v3_r1:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# TODO(ttanzhiqiang): all_reduce merge
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
||||
|
||||
Reference in New Issue
Block a user