[main][Feature] MoE alltoallv communication optimization for unquantized RL training scenario (#2088)
It comes from 0.9.1dev
[0.9.1][Feature] MoE alltoallv communication optimization for unquantized
RL training scenario & alltoallv support dpo (#1547)
- vLLM version: v0.10.0
- vLLM main:
97608dc276
---------
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Signed-off-by: whx-sjtu <2952154980@qq.com>
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
Signed-off-by: wangli <wangli858794774@gmail.com>
Signed-off-by: ChenTaoyu-SJTU <ctynb@qq.com>
Signed-off-by: taoxudonghaha <justsheldon@163.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: Shanshan Shen <87969357+shen-shanshan@users.noreply.github.com>
Signed-off-by: leo-pony <nengjunma@outlook.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: whx <56632993+whx-sjtu@users.noreply.github.com>
Co-authored-by: curryliu <99582471+Irving11-BKN@users.noreply.github.com>
Co-authored-by: Li Wang <wangli858794774@gmail.com>
Co-authored-by: TaoYu Chen <ctynb@qq.com>
Co-authored-by: taoxudonghaha <justsheldon@163.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: leo-pony <nengjunma@outlook.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -16,12 +16,14 @@
|
||||
# Adapted from vllm/tests/kernels/test_moe.py
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -35,6 +37,7 @@ from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
@@ -45,6 +48,8 @@ from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
@@ -273,11 +278,13 @@ def fused_experts_with_mc2(
|
||||
return hidden_states, shared_hidden_states
|
||||
|
||||
|
||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
@@ -299,9 +306,6 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
assert len(hidden_states_wrapper) == 1
|
||||
hidden_states = hidden_states_wrapper.pop()
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -329,6 +333,8 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -543,10 +549,7 @@ def fused_experts_with_all2all_buffer(
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
|
||||
hidden_states_wrapper = [hidden_states]
|
||||
del hidden_states
|
||||
|
||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||
hidden_states = apply_mlp(hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
expert_tokens,
|
||||
@@ -682,6 +685,24 @@ def fused_experts_moge(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_all2allv(
    token_dispatcher,
    probs,
    routing_map,
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
):
    """Run the expert MLPs through the alltoallv token dispatcher.

    The dispatcher permutes the tokens, ``apply_mlp`` runs the grouped
    expert MLP on the dispatched tokens, and the dispatcher then restores
    the original token layout.

    Args:
        token_dispatcher: object exposing ``token_permutation`` and
            ``token_unpermutation``.
        probs: routing weights forwarded to ``token_permutation``.
        routing_map: routing indices forwarded to ``token_permutation``.
        hidden_states: input activations to be dispatched.
        w1: first expert weight, forwarded to ``apply_mlp``.
        w2: second expert weight, forwarded to ``apply_mlp``.

    Returns:
        The first element of the pair returned by
        ``token_dispatcher.token_unpermutation``.
    """
    # MoE alltoallv is a balanced trade-off between precision and efficiency.
    # The first element of the permutation result (shared-experts output) is
    # not consumed here.
    _shared_out, expert_input, expert_token_counts = (
        token_dispatcher.token_permutation(hidden_states, probs, routing_map))

    # group_list_type is left at apply_mlp's default.
    mlp_out = apply_mlp(expert_input, w1, w2, expert_token_counts)

    # The second element (named mlp_bias upstream) is discarded.
    combined, _mlp_bias = token_dispatcher.token_unpermutation(mlp_out)
    return combined
|
||||
|
||||
|
||||
def fused_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -1124,6 +1145,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
global_batch_size=self.global_batch_size,
|
||||
expert_map=expert_map,
|
||||
ep_group=get_ep_group())
|
||||
elif fused_moe_state == FusedMoEState.All2AllSeq:
|
||||
token_dispatcher = kwargs.get("token_dispatcher")
|
||||
return fused_experts_with_all2allv(
|
||||
token_dispatcher=token_dispatcher,
|
||||
probs=topk_weights,
|
||||
routing_map=topk_ids,
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
)
|
||||
else:
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -1275,6 +1306,25 @@ class AscendFusedMoE(FusedMoE):
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
|
||||
self.quant_method, AscendUnquantizedFusedMoEMethod):
|
||||
self.reduce_results = False
|
||||
moe_dispatcher_config = (
|
||||
MoEDispatcherConfig().set_num_moe_experts(
|
||||
self.global_num_experts).set_num_local_experts(
|
||||
self.local_num_experts).set_moe_router_topk(
|
||||
top_k).set_group_topk(topk_group).
|
||||
set_num_groups(num_expert_group).set_expert_bias(
|
||||
e_score_correction_bias).set_scaling_factor(1.0).build())
|
||||
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
self.token_dispatchers = [
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
def naive_multicast(self, x: torch.Tensor,
|
||||
cu_tokens_across_dp_cpu: torch.Tensor):
|
||||
@@ -1414,6 +1464,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||
and self.enable_multistream_moe and not is_prefill else None,
|
||||
mc2_mask=mc2_mask,
|
||||
token_dispatcher=self.token_dispatcher,
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
)
|
||||
@@ -1430,11 +1481,11 @@ class AscendFusedMoE(FusedMoE):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if num_tokens < padding_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.NaiveMulticast:
|
||||
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
|
||||
@@ -1491,6 +1542,83 @@ class AscendFusedMoE(FusedMoE):
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance)
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AscendSparseMoeBlock(nn.Module):
    """Sparse MoE block: a replicated router gate followed by AscendFusedMoE.

    The gate produces per-token router logits, and the fused MoE layer turns
    those logits plus the hidden states into the block output.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the gate and expert layers from ``config``.

        Args:
            config: model config; reads ``num_experts``, ``hidden_size``,
                ``num_experts_per_tok``, ``moe_intermediate_size`` and
                ``norm_topk_prob``.
            quant_config: quantization config passed to the experts only;
                the gate is always built with ``quant_config=None``.
            prefix: parameter-name prefix for the sub-layers.

        Raises:
            ValueError: if the tensor-parallel world size exceeds
                ``config.num_experts``.
        """
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        graph_cfg = get_ascend_config().torchair_graph_config
        self.torchair_graph_enabled = graph_cfg.enabled
        self.enable_multistream_moe = graph_cfg.enable_multistream_moe

        # Router: hidden_size -> num_experts, no bias, never quantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # The experts are constructed with reduce_results=False.
        self.experts = AscendFusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

        self.top_k = config.num_experts_per_tok

        self.dp_size = get_dp_group().world_size

        tp_coordinator = get_tp_group()
        self.tp_group = tp_coordinator.device_group
        self.tp_rank = tp_coordinator.rank_in_group
        self.ep_group = get_ep_group()

        self.params_dtype = torch.get_default_dtype()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Route ``hidden_states`` through the gate and the fused experts.

        Args:
            hidden_states: input activations.
            attn_metadata: attention metadata; taken from the forward
                context when omitted. It is not used further in this method.

        Returns:
            The output of ``self.experts``.
        """
        if attn_metadata is None:
            attn_metadata = get_forward_context().attn_metadata

        forward_ctx = get_forward_context()
        # During a profile run, force load-balanced tokens across experts to
        # avoid high memory consumption on a single rank.
        force_balance = forward_ctx.in_profile_run
        prefill = forward_ctx.with_prefill

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)

        return self.experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            is_prefill=prefill,
            top_k=self.top_k,
            enable_force_load_balance=force_balance,
            shared_experts=None,
        )
|
||||
|
||||
Reference in New Issue
Block a user