support fused_moe_allgather_ep (#1335)

### What this PR does / why we need it?
support fused_moe_allgather_ep

### How was this patch tested?
It was tested by UT.

Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
This commit is contained in:
lyj-jjj
2025-06-23 22:03:38 +08:00
committed by GitHub
parent 917c6b71af
commit 5177bef87a
5 changed files with 218 additions and 14 deletions

View File

@@ -0,0 +1,82 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Execute the inference of fused_moe_allgather_ep and fused_moe_alltoall_ep.
Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
"""
import os
from unittest.mock import patch
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from tests.conftest import VllmRunner
@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1",
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
})
def test_generate_with_allgather():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=16,
enforce_eager=True,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"
})
def test_generate_with_alltoall():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
tensor_parallel_size=16,
enforce_eager=True,
max_model_len=1024,
dtype="auto",
enable_expert_parallel=True,
additional_config={
"ascend_scheduler_config": {
"enabled": True,
"chunked_prefill_enabled": False,
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)

View File

@@ -99,6 +99,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Whether to enable the trace recompiles from pytorch.
"VLLM_ASCEND_TRACE_RECOMPILES":
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
),
"VLLM_ASCEND_ENABLE_DBO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
# Whether to enable the model execute time observe profile. Disable it when

View File

@@ -988,8 +988,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
**kwargs,
) -> torch.Tensor:
is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
@@ -1025,7 +1026,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
is_prefill, is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
@@ -1219,15 +1220,17 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
is_deepseek_v3_r1 = self.global_num_experts == 256
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
is_prefill)
is_prefill, is_deepseek_v3_r1)
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size()
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
if num_tokens < tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
@@ -1285,7 +1288,8 @@ class AscendFusedMoE(FusedMoE):
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
and fused_moe_state != FusedMoEState.AllGatherEP):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
@@ -1303,7 +1307,8 @@ class AscendFusedMoE(FusedMoE):
else:
final_hidden_states = e_hidden_states
if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
or fused_moe_state == FusedMoEState.AllGatherEP):
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)

View File

@@ -22,12 +22,13 @@ import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator
import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, npu_stream_switch,
npu_wait_tensor)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
dispose_tensor, get_fused_moe_state,
npu_stream_switch, npu_wait_tensor)
def apply_mlp(hidden_states: torch.Tensor,
@@ -346,6 +347,95 @@ def fused_experts_with_all2all(
return final_hidden_states
def fused_experts_with_allgather(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
batch_size, hidden_size = hidden_states.shape
topk_weights = topk_weights.to(hidden_states.dtype)
ep_group = get_ep_group().device_group
ep_rank = torch.distributed.get_rank(group=ep_group)
ep_size = torch.distributed.get_world_size(ep_group)
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_size
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
scale=pertoken_scale,
offset=None,
active_num=num_tokens * top_k,
expert_num=global_num_experts,
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[
ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
],
quant_mode=-1,
row_idx_type=1)
group_list_type = 1
sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
expanded_x_idx)
row_index = expanded_x_idx // topk_ids.shape[-1]
row_index = row_index.to(torch.int64)
share_input = torch.zeros((batch_size, hidden_size),
dtype=torch.bfloat16,
device="npu")
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
group_list_type=group_list_type,
group_type=0,
group_list=expert_tokens,
output_dtype=torch.int32)[0]
# act_fn: swiglu
hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale.to(torch.float32),
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
)
final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
hidden_states,
w2,
scale=w2_scale.to(torch.float32),
bias=None,
pertoken_scale=pertoken_scale.view(-1),
group_list=expert_tokens,
shared_input=share_input,
logit=sorted_topk_weight.to(torch.float32),
row_index=row_index,
output_bs=batch_size).to(torch.bfloat16)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -623,8 +713,10 @@ class AscendW8A8DynamicFusedMoEMethod:
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk当前写8
@@ -661,8 +753,19 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype)
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
if fused_moe_state == FusedMoEState.MC2:
is_prefill, is_deepseek_v3_r1)
if fused_moe_state == FusedMoEState.AllGatherEP:
return fused_experts_with_allgather(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -713,6 +816,8 @@ class AscendW8A8DynamicFusedMoEMethod:
1, 2).contiguous()
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
if envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(

View File

@@ -394,11 +394,18 @@ class FusedMoEState(Enum):
AllGather = 0
All2All = 1
MC2 = 2
AllGatherEP = 3
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool):
if ep_size == 1:
def get_fused_moe_state(ep_size: int, with_prefill: bool,
is_deepseek_v3_r1: bool):
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
# only supports deepseek v3/r1
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
and is_deepseek_v3_r1):
return FusedMoEState.AllGatherEP
elif ep_size == 1:
return FusedMoEState.AllGather
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
elif ep_size < 16 or with_prefill: