[main][Feature]Moe alltoallv communication optimization for unquantized RL training sence (#2088)

It comes from 0.9.1dev
[0.9.1][Feature]Moe alltoallv communication optimization for unquantized
RL training sence & alltoallv support dpo (#1547)

- vLLM version: v0.10.0
- vLLM main:
97608dc276

---------

Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
Signed-off-by: taoxudonghaha <justsheldon@163.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: curryliu <99582471+Irving11-BKN@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: TaoYu Chen <ctynb@qq.com>
Co-authored-by: taoxudonghaha <justsheldon@163.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
weijinqian0
2025-08-02 09:49:10 +08:00
committed by GitHub
parent f0c1f0c828
commit 6e00aed4d5
14 changed files with 1265 additions and 17 deletions

View File

@@ -16,12 +16,14 @@
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -35,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.config import \
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
@@ -45,6 +48,8 @@ from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
@@ -273,11 +278,13 @@ def fused_experts_with_mc2(
return hidden_states, shared_hidden_states
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1) -> torch.Tensor:
def apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj
@@ -299,9 +306,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
hidden_states: output hidden states after MLP.
"""
assert len(hidden_states_wrapper) == 1
hidden_states = hidden_states_wrapper.pop()
w1 = w1.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -329,6 +333,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
return hidden_states
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -543,10 +549,7 @@ def fused_experts_with_all2all_buffer(
hidden_states = hidden_states[sorted_idx]
group_list_type = 0
hidden_states_wrapper = [hidden_states]
del hidden_states
hidden_states = apply_mlp(hidden_states_wrapper,
hidden_states = apply_mlp(hidden_states,
w1,
w2,
expert_tokens,
@@ -682,6 +685,24 @@ def fused_experts_moge(
return final_hidden_states
def fused_experts_with_all2allv(
token_dispatcher,
probs,
routing_map,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
):
# Enable moe alltoallv, it's a balanced policy for precision and efficiency.
(share_experts_output, dispatched_input,
tokens_per_expert) = (token_dispatcher.token_permutation(
hidden_states, probs, routing_map))
expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert)
output, mlp_bias = token_dispatcher.token_unpermutation(expert_output)
return output
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -1124,6 +1145,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
global_batch_size=self.global_batch_size,
expert_map=expert_map,
ep_group=get_ep_group())
elif fused_moe_state == FusedMoEState.All2AllSeq:
token_dispatcher = kwargs.get("token_dispatcher")
return fused_experts_with_all2allv(
token_dispatcher=token_dispatcher,
probs=topk_weights,
routing_map=topk_ids,
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
)
else:
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
@@ -1275,6 +1306,25 @@ class AscendFusedMoE(FusedMoE):
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.token_dispatcher = None
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
self.quant_method, AscendUnquantizedFusedMoEMethod):
self.reduce_results = False
moe_dispatcher_config = (
MoEDispatcherConfig().set_num_moe_experts(
self.global_num_experts).set_num_local_experts(
self.local_num_experts).set_moe_router_topk(
top_k).set_group_topk(topk_group).
set_num_groups(num_expert_group).set_expert_bias(
e_score_correction_bias).set_scaling_factor(1.0).build())
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
moe_dispatcher_config)
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
moe_dispatcher_config)
self.token_dispatchers = [
self.token_dispatcher, token_dispatcher1
]
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -1414,6 +1464,7 @@ class AscendFusedMoE(FusedMoE):
shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None,
mc2_mask=mc2_mask,
token_dispatcher=self.token_dispatcher,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)
@@ -1430,11 +1481,11 @@ class AscendFusedMoE(FusedMoE):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
elif self.dp_size > 1:
if fused_moe_state == FusedMoEState.NaiveMulticast:
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
@@ -1491,6 +1542,83 @@ class AscendFusedMoE(FusedMoE):
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance)
enable_force_load_balance=enable_force_load_balance,
)
return hidden_states
class AscendSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = (
ascend_config.torchair_graph_config.enable_multistream_moe)
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)
self.top_k = config.num_experts_per_tok
self.dp_size = get_dp_group().world_size
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.params_dtype = torch.get_default_dtype()
def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
)
return hidden_states