rm router logits Improve TTOP 3ms (#1407)
### What this PR does / why we need it?
The previous code is
router_logits, _ = self.gate(hidden_states)
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
I want to change the two all_gathers to one, reduce one all_gather
communication, and make it
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits, _ = self.gate(hidden_states)
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
bash examples/run_dp_attention_etp16.sh
bash examples/run_dp_attention_etp16_benmark.sh
gsm8k accuracy verification
<img width="1809" alt="截屏2025-06-24 21 53 24"
src="https://github.com/user-attachments/assets/47eace3b-a86b-41b4-9de8-773f57fea33b"
/>
- vLLM version: v0.9.2
- vLLM main:
77f77a951e
---------
Signed-off-by: ttanzhiqiang <389825161@qq.com>
This commit is contained in:
@@ -367,6 +367,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
self.ep_group = get_ep_group()
|
self.ep_group = get_ep_group()
|
||||||
|
|
||||||
self.params_dtype = torch.get_default_dtype()
|
self.params_dtype = torch.get_default_dtype()
|
||||||
|
self.rm_router_logits = self.experts.rm_router_logits
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
@@ -389,7 +390,9 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||||
|
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits = None
|
||||||
|
if not self.rm_router_logits:
|
||||||
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
experts_hidden_states = self.experts(
|
experts_hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
@@ -398,6 +401,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
top_k=CustomDeepseekV2MoE.top_k,
|
top_k=CustomDeepseekV2MoE.top_k,
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
shared_experts=self.shared_experts,
|
shared_experts=self.shared_experts,
|
||||||
|
gate=self.gate,
|
||||||
replace_allreduce=replace_allreduce)
|
replace_allreduce=replace_allreduce)
|
||||||
|
|
||||||
hidden_states = (
|
hidden_states = (
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
|||||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||||
get_all_reduce_merge_state, get_fused_moe_state,
|
get_all_reduce_merge_state, get_fused_moe_state,
|
||||||
is_310p, npu_stream_switch, npu_wait_tensor)
|
get_rm_router_logits_state, is_310p,
|
||||||
|
npu_stream_switch, npu_wait_tensor)
|
||||||
|
|
||||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||||
|
|
||||||
@@ -1148,6 +1149,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
self.global_redundant_expert_num = 0
|
self.global_redundant_expert_num = 0
|
||||||
|
|
||||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
is_deepseek_v3_r1 = self.global_num_experts == 256
|
||||||
|
self.rm_router_logits = get_rm_router_logits_state(
|
||||||
|
self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1)
|
||||||
self.all_reduce_merge = get_all_reduce_merge_state(
|
self.all_reduce_merge = get_all_reduce_merge_state(
|
||||||
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
self.moe_parallel_config.ep_size, is_deepseek_v3_r1)
|
||||||
|
|
||||||
@@ -1240,7 +1243,9 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
top_k: Optional[int] = None,
|
top_k: Optional[int] = None,
|
||||||
shared_experts: Optional[Any] = None,
|
shared_experts: Optional[Any] = None,
|
||||||
|
gate=None,
|
||||||
replace_allreduce: bool = False):
|
replace_allreduce: bool = False):
|
||||||
|
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
if top_k:
|
if top_k:
|
||||||
@@ -1277,6 +1282,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
tp_rank = get_tensor_model_parallel_rank()
|
tp_rank = get_tensor_model_parallel_rank()
|
||||||
hidden_states = chunk_hidden_states[tp_rank]
|
hidden_states = chunk_hidden_states[tp_rank]
|
||||||
router_logits = chunk_router_logits[tp_rank]
|
router_logits = chunk_router_logits[tp_rank]
|
||||||
|
|
||||||
if self.dp_size > 1:
|
if self.dp_size > 1:
|
||||||
if fused_moe_state == FusedMoEState.AllGather:
|
if fused_moe_state == FusedMoEState.AllGather:
|
||||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||||
@@ -1289,19 +1295,27 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
(0, 0, 0,
|
(0, 0, 0,
|
||||||
max_num_tokens_across_dp - num_tokens))
|
max_num_tokens_across_dp - num_tokens))
|
||||||
router_logits = nn.functional.pad(
|
if not self.rm_router_logits:
|
||||||
router_logits,
|
router_logits = nn.functional.pad(
|
||||||
(0, 0, 0,
|
router_logits,
|
||||||
max_num_tokens_across_dp - num_tokens))
|
(0, 0, 0,
|
||||||
|
max_num_tokens_across_dp - num_tokens))
|
||||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
if self.rm_router_logits:
|
||||||
|
router_logits, _ = gate(hidden_states)
|
||||||
|
else:
|
||||||
|
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||||
|
|
||||||
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
elif fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||||
cu_tokens_across_dp_cpu = get_forward_context(
|
cu_tokens_across_dp_cpu = get_forward_context(
|
||||||
).dp_metadata.cu_tokens_across_dp_cpu
|
).dp_metadata.cu_tokens_across_dp_cpu
|
||||||
hidden_states = self.naive_multicast(hidden_states,
|
hidden_states = self.naive_multicast(hidden_states,
|
||||||
cu_tokens_across_dp_cpu)
|
cu_tokens_across_dp_cpu)
|
||||||
router_logits = self.naive_multicast(router_logits,
|
if self.rm_router_logits:
|
||||||
cu_tokens_across_dp_cpu)
|
router_logits, _ = gate(hidden_states)
|
||||||
|
else:
|
||||||
|
router_logits = self.naive_multicast(
|
||||||
|
router_logits, cu_tokens_across_dp_cpu)
|
||||||
|
|
||||||
# Matrix multiply.
|
# Matrix multiply.
|
||||||
e_hidden_states = self.quant_method.apply(
|
e_hidden_states = self.quant_method.apply(
|
||||||
|
|||||||
@@ -439,6 +439,22 @@ class FusedMoEState(Enum):
|
|||||||
NaiveMulticast = 4
|
NaiveMulticast = 4
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(ttanzhiqiang): rm_router_logits
|
||||||
|
# dp>1 will trigger
|
||||||
|
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
|
||||||
|
def get_rm_router_logits_state(ep_size: int, dp_size: int,
|
||||||
|
is_deepseek_v3_r1: bool):
|
||||||
|
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||||
|
# only supports deepseek v3/r1
|
||||||
|
if dp_size > 1:
|
||||||
|
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||||
|
and is_deepseek_v3_r1):
|
||||||
|
return True
|
||||||
|
elif ep_size == 1 and is_deepseek_v3_r1:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
# TODO(ttanzhiqiang): all_reduce merge
|
# TODO(ttanzhiqiang): all_reduce merge
|
||||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||||
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
# Currently, all_reduce_merge is enabled by default in the AllGather, AllGatherEP and NaiveMulticast scenarios of the deepseek model.
|
||||||
|
|||||||
Reference in New Issue
Block a user