support fused_moe_allgather_ep (#1335)
### What this PR does / why we need it?
Support fused_moe_allgather_ep.

### How was this patch tested?
It was tested by UT.

Signed-off-by: lyj-jjj <liuyingjun5@huawei.com>
This commit is contained in:
82
tests/multicard/test_fused_moe_allgather_ep.py
Normal file
82
tests/multicard/test_fused_moe_allgather_ep.py
Normal file
@@ -0,0 +1,82 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Execute the inference of fused_moe_allgather_ep and fused_moe_alltoall_ep.
|
||||
|
||||
Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
from vllm import SamplingParams
|
||||
|
||||
from tests.conftest import VllmRunner
|
||||
|
||||
|
||||
@patch.dict(
    os.environ, {
        "VLLM_USE_V1": "1",
        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
        "TASK_QUEUE_ENABLE": "1",
        "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
    })
def test_generate_with_allgather():
    """Smoke-test generation with VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP on.

    The environment is patched for the duration of the test only. No
    assertion is made on the generated text: the test passes as long as
    model start-up and ``generate`` complete without raising.
    """
    prompts = ["Hello, my name is"]
    sampling = SamplingParams(temperature=0.0, max_tokens=100)

    # Ascend scheduler enabled, chunked prefill explicitly turned off.
    scheduler_options = {
        "enabled": True,
        "chunked_prefill_enabled": False,
    }
    extra_options = {
        "ascend_scheduler_config": scheduler_options,
        "expert_tensor_parallel_size": 1
    }

    model_dir = snapshot_download("vllm-ascend/DeepSeek-V3-Pruning")
    runner = VllmRunner(model_dir,
                        tensor_parallel_size=16,
                        enforce_eager=True,
                        max_model_len=1024,
                        dtype="auto",
                        enable_expert_parallel=True,
                        additional_config=extra_options)
    with runner as vllm_model:
        vllm_model.generate(prompts, sampling)
|
||||
|
||||
|
||||
@patch.dict(
    os.environ, {
        "VLLM_USE_V1": "1",
        "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
        "TASK_QUEUE_ENABLE": "1",
        # patch.dict only adds/overrides keys; pin the allgather-EP switch
        # off so a value inherited from the caller's shell cannot silently
        # turn this test into a second allgather-EP run.
        "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "0"
    })
def test_generate_with_alltoall():
    """Smoke-test generation with the allgather-EP switch disabled.

    Uses the same model and engine settings as the allgather test in this
    file, but with VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP forced to "0".
    No assertion is made on the generated text: the test passes as long as
    model start-up and ``generate`` complete without raising.
    """
    example_prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(max_tokens=100, temperature=0.0)

    with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"),
                    tensor_parallel_size=16,
                    enforce_eager=True,
                    max_model_len=1024,
                    dtype="auto",
                    enable_expert_parallel=True,
                    additional_config={
                        "ascend_scheduler_config": {
                            "enabled": True,
                            "chunked_prefill_enabled": False,
                        },
                        "expert_tensor_parallel_size": 1
                    }) as vllm_model:
        vllm_model.generate(example_prompts, sampling_params)
|
||||
@@ -99,6 +99,11 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# Whether to enable the trace recompiles from pytorch.
|
||||
"VLLM_ASCEND_TRACE_RECOMPILES":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
|
||||
# Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and
|
||||
# GroupedMatmulFinalizeRouting operators are combined to implement EP.
|
||||
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0'))
|
||||
),
|
||||
"VLLM_ASCEND_ENABLE_DBO":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))),
|
||||
# Whether to enable the model execute time observe profile. Disable it when
|
||||
|
||||
@@ -988,8 +988,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk当前写8
|
||||
@@ -1025,7 +1026,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
||||
is_prefill)
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
@@ -1219,15 +1220,17 @@ class AscendFusedMoE(FusedMoE):
|
||||
real_top_k = self.top_k
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
||||
is_prefill)
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
if shared_experts:
|
||||
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
|
||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||
and fused_moe_state != FusedMoEState.AllGatherEP):
|
||||
if num_tokens < tp_size:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
||||
@@ -1285,7 +1288,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
|
||||
if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather
|
||||
and fused_moe_state != FusedMoEState.AllGatherEP):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
@@ -1303,7 +1307,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
|
||||
if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather
|
||||
or fused_moe_state == FusedMoEState.AllGatherEP):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
|
||||
@@ -22,12 +22,13 @@ import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.distributed import GroupCoordinator
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||
get_fused_moe_state, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
|
||||
dispose_tensor, get_fused_moe_state,
|
||||
npu_stream_switch, npu_wait_tensor)
|
||||
|
||||
|
||||
def apply_mlp(hidden_states: torch.Tensor,
|
||||
@@ -346,6 +347,95 @@ def fused_experts_with_all2all(
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_allgather(hidden_states: torch.Tensor,
                                 w1: torch.Tensor,
                                 w1_scale: torch.Tensor,
                                 w2: torch.Tensor,
                                 w2_scale: torch.Tensor,
                                 topk_weights: torch.Tensor,
                                 topk_ids: torch.Tensor,
                                 top_k: int,
                                 expert_map: torch.Tensor = None):
    """Run the quantized MoE experts owned by this expert-parallel rank.

    Pipeline: dynamic per-token quantization -> routing restricted to this
    rank's expert range -> grouped matmul with ``w1`` -> dequant/SwiGLU/
    re-quant -> grouped matmul with ``w2`` fused with the finalize-routing
    step.

    Args:
        hidden_states: 2-D or 3-D activations; a 3-D input is flattened to
            2-D for the computation and the result is viewed back to the
            original shape.
        w1: weights for the first grouped matmul.
        w1_scale: scale for ``w1``; cast to float32 before use.
        w2: weights for the second grouped matmul.
        w2_scale: scale for ``w2``; cast to float32 before use.
        topk_weights: routing weights; cast to ``hidden_states.dtype``.
        topk_ids: selected expert ids per token.
        top_k: number of experts per token, used to size ``active_num``.
        expert_map: required despite the ``None`` default. Only its length
            is read, as the global number of experts.

    Returns:
        A bfloat16 tensor with this rank's expert contribution, shaped like
        the input ``hidden_states``.

    Raises:
        ValueError: if ``expert_map`` is ``None``.
    """
    # Fail fast with an actionable message instead of an opaque
    # "object of type 'NoneType' has no len()" after the collectives below.
    if expert_map is None:
        raise ValueError(
            "fused_experts_with_allgather requires expert_map to derive the "
            "global number of experts, but got None")

    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    # The first dimension is both the token count fed to routing and the
    # output batch size of the finalize step.
    num_tokens, hidden_size = hidden_states.shape
    topk_weights = topk_weights.to(hidden_states.dtype)

    ep_group = get_ep_group().device_group
    ep_rank = torch.distributed.get_rank(group=ep_group)
    ep_size = torch.distributed.get_world_size(ep_group)

    global_num_experts = len(expert_map)
    local_num_experts = global_num_experts // ep_size

    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

    # Only tokens routed to experts in [first_expert, last_expert) are kept.
    first_expert = ep_rank * local_num_experts
    last_expert = (ep_rank + 1) * local_num_experts
    hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
        hidden_states,
        topk_ids,
        scale=pertoken_scale,
        offset=None,
        active_num=num_tokens * top_k,
        expert_num=global_num_experts,
        expert_tokens_num_type=1,
        expert_tokens_num_flag=True,
        active_expert_range=[first_expert, last_expert],
        quant_mode=-1,
        row_idx_type=1)
    group_list_type = 1

    # Reorder the routing weights to follow the expanded (sorted) rows and
    # recover, for every expanded row, the index of its source token.
    sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0,
                                            expanded_x_idx)
    row_index = expanded_x_idx // topk_ids.shape[-1]
    row_index = row_index.to(torch.int64)
    # All-zero shared input: nothing extra is added in the finalize step.
    share_input = torch.zeros((num_tokens, hidden_size),
                              dtype=torch.bfloat16,
                              device="npu")

    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=torch.int32)[0]

    # act_fn: swiglu
    hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=w1_scale.to(torch.float32),
        activation_scale=pertoken_scale,
        bias=None,
        quant_scale=None,
        quant_offset=None,
        group_index=expert_tokens,
        activate_left=True,
        quant_mode=1,
    )

    final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing(
        hidden_states,
        w2,
        scale=w2_scale.to(torch.float32),
        bias=None,
        pertoken_scale=pertoken_scale.view(-1),
        group_list=expert_tokens,
        shared_input=share_input,
        logit=sorted_topk_weight.to(torch.float32),
        row_index=row_index,
        output_bs=num_tokens).to(torch.bfloat16)

    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)

    return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
@@ -623,8 +713,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
assert router_logits.shape[
|
||||
1] == global_num_experts, "Number of global experts mismatch"
|
||||
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
|
||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||
if global_num_experts == 256:
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk当前写8
|
||||
@@ -661,8 +753,19 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
||||
is_prefill)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
if fused_moe_state == FusedMoEState.AllGatherEP:
|
||||
return fused_experts_with_allgather(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
elif fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -713,6 +816,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
if envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
|
||||
torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||
|
||||
@@ -394,11 +394,18 @@ class FusedMoEState(Enum):
|
||||
AllGather = 0
|
||||
All2All = 1
|
||||
MC2 = 2
|
||||
AllGatherEP = 3
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def get_fused_moe_state(ep_size: int, with_prefill: bool):
|
||||
if ep_size == 1:
|
||||
def get_fused_moe_state(ep_size: int, with_prefill: bool,
|
||||
is_deepseek_v3_r1: bool):
|
||||
# the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep
|
||||
# only supports deepseek v3/r1
|
||||
if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1
|
||||
and is_deepseek_v3_r1):
|
||||
return FusedMoEState.AllGatherEP
|
||||
elif ep_size == 1:
|
||||
return FusedMoEState.AllGather
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
|
||||
Reference in New Issue
Block a user