From 5177bef87a21331dcca11159d3d1438075cbd74e Mon Sep 17 00:00:00 2001 From: lyj-jjj Date: Mon, 23 Jun 2025 22:03:38 +0800 Subject: [PATCH] support fused_moe_allgather_ep (#1335) ### What this PR does / why we need it? Add a fused-experts AllGather expert-parallel (EP) path, `fused_experts_with_allgather`, to the W8A8 dynamic quantization MoE method. It is enabled by setting `VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP=1` and is selected only when the EP size is greater than 1 and the model has 256 global experts (DeepSeek V3/R1). ### How was this patch tested? It was tested by the tests added in `tests/multicard/test_fused_moe_allgather_ep.py`. Signed-off-by: lyj-jjj --- .../multicard/test_fused_moe_allgather_ep.py | 82 ++++++++++++ vllm_ascend/envs.py | 5 + vllm_ascend/ops/fused_moe.py | 17 ++- vllm_ascend/quantization/w8a8_dynamic.py | 117 +++++++++++++++++- vllm_ascend/utils.py | 11 +- 5 files changed, 218 insertions(+), 14 deletions(-) create mode 100644 tests/multicard/test_fused_moe_allgather_ep.py diff --git a/tests/multicard/test_fused_moe_allgather_ep.py b/tests/multicard/test_fused_moe_allgather_ep.py new file mode 100644 index 0000000..1e63878 --- /dev/null +++ b/tests/multicard/test_fused_moe_allgather_ep.py @@ -0,0 +1,82 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Execute the inference of fused_moe_allgather_ep and fused_moe_alltoall_ep. + +Run 'pytest tests/multicard/test_fused_moe_allgather_ep.py'. 
+""" + +import os +from unittest.mock import patch + +from modelscope import snapshot_download # type: ignore +from vllm import SamplingParams + +from tests.conftest import VllmRunner + + +@patch.dict( + os.environ, { + "VLLM_USE_V1": "1", + "VLLM_WORKER_MULTIPROC_METHOD": "spawn", + "TASK_QUEUE_ENABLE": "1", + "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1" + }) +def test_generate_with_allgather(): + example_prompts = ["Hello, my name is"] + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + + with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"), + tensor_parallel_size=16, + enforce_eager=True, + max_model_len=1024, + dtype="auto", + enable_expert_parallel=True, + additional_config={ + "ascend_scheduler_config": { + "enabled": True, + "chunked_prefill_enabled": False, + }, + "expert_tensor_parallel_size": 1 + }) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + +@patch.dict( + os.environ, { + "VLLM_USE_V1": "1", + "VLLM_WORKER_MULTIPROC_METHOD": "spawn", + "TASK_QUEUE_ENABLE": "1" + }) +def test_generate_with_alltoall(): + example_prompts = ["Hello, my name is"] + sampling_params = SamplingParams(max_tokens=100, temperature=0.0) + + with VllmRunner(snapshot_download("vllm-ascend/DeepSeek-V3-Pruning"), + tensor_parallel_size=16, + enforce_eager=True, + max_model_len=1024, + dtype="auto", + enable_expert_parallel=True, + additional_config={ + "ascend_scheduler_config": { + "enabled": True, + "chunked_prefill_enabled": False, + }, + "expert_tensor_parallel_size": 1 + }) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) \ No newline at end of file diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 02ecd66..6599241 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -99,6 +99,11 @@ env_variables: Dict[str, Callable[[], Any]] = { # Whether to enable the trace recompiles from pytorch. 
"VLLM_ASCEND_TRACE_RECOMPILES": lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))), + # Whether to enable fused_experts_allgather_ep. MoeInitRoutingV3 and + # GroupedMatmulFinalizeRouting operators are combined to implement EP. + "VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": + lambda: bool(int(os.getenv("VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP", '0')) + ), "VLLM_ASCEND_ENABLE_DBO": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_DBO", '0'))), # Whether to enable the model execute time observe profile. Disable it when diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index c1c865b..d65f12c 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -988,8 +988,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): **kwargs, ) -> torch.Tensor: + is_deepseek_v3_r1 = global_num_experts == 256 # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: + if is_deepseek_v3_r1: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( router_logits, k=top_k, # topk当前写8 @@ -1025,7 +1026,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill) + is_prefill, is_deepseek_v3_r1) if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, @@ -1219,15 +1220,17 @@ class AscendFusedMoE(FusedMoE): real_top_k = self.top_k num_tokens, hidden_size = hidden_states.shape + is_deepseek_v3_r1 = self.global_num_experts == 256 fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, - is_prefill) + is_prefill, is_deepseek_v3_r1) if shared_experts: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: shared_hidden_states = shared_experts(hidden_states) tp_size = get_tensor_model_parallel_world_size() - if tp_size > 1 and fused_moe_state != 
FusedMoEState.AllGather: + if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather + and fused_moe_state != FusedMoEState.AllGatherEP): if num_tokens < tp_size: hidden_states = nn.functional.pad( hidden_states, (0, 0, 0, tp_size - num_tokens)) @@ -1285,7 +1288,8 @@ class AscendFusedMoE(FusedMoE): if isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states - if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + if (tp_size > 1 and fused_moe_state != FusedMoEState.AllGather + and fused_moe_state != FusedMoEState.AllGatherEP): dist.all_gather(list(chunk_hidden_states), e_hidden_states, self.tp_group) final_hidden_states = torch.cat(chunk_hidden_states, dim=0) @@ -1303,7 +1307,8 @@ class AscendFusedMoE(FusedMoE): else: final_hidden_states = e_hidden_states - if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + if tp_size > 1 and (fused_moe_state == FusedMoEState.AllGather + or fused_moe_state == FusedMoEState.AllGatherEP): final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 372c29b..fb328b7 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -22,12 +22,13 @@ import torch.distributed as dist import torch_npu from vllm.distributed import GroupCoordinator +import vllm_ascend.envs as envs from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts -from vllm_ascend.utils import (FusedMoEState, dispose_tensor, - get_fused_moe_state, npu_stream_switch, - npu_wait_tensor) +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState, + dispose_tensor, get_fused_moe_state, + npu_stream_switch, npu_wait_tensor) def apply_mlp(hidden_states: torch.Tensor, @@ -346,6 +347,95 @@ def fused_experts_with_all2all( return 
final_hidden_states +def fused_experts_with_allgather(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + num_tokens = hidden_states.shape[0] + batch_size, hidden_size = hidden_states.shape + topk_weights = topk_weights.to(hidden_states.dtype) + + ep_group = get_ep_group().device_group + ep_rank = torch.distributed.get_rank(group=ep_group) + ep_size = torch.distributed.get_world_size(ep_group) + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_size + + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) + + hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2( + hidden_states, + topk_ids, + scale=pertoken_scale, + offset=None, + active_num=num_tokens * top_k, + expert_num=global_num_experts, + expert_tokens_num_type=1, + expert_tokens_num_flag=True, + active_expert_range=[ + ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts + ], + quant_mode=-1, + row_idx_type=1) + group_list_type = 1 + + sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0, + expanded_x_idx) + row_index = expanded_x_idx // topk_ids.shape[-1] + row_index = row_index.to(torch.int64) + share_input = torch.zeros((batch_size, hidden_size), + dtype=torch.bfloat16, + device="npu") + + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=expert_tokens, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale.to(torch.float32), + 
activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_tokens, + activate_left=True, + quant_mode=1, + ) + + final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing( + hidden_states, + w2, + scale=w2_scale.to(torch.float32), + bias=None, + pertoken_scale=pertoken_scale.view(-1), + group_list=expert_tokens, + shared_input=share_input, + logit=sorted_topk_weight.to(torch.float32), + row_index=row_index, + output_bs=batch_size).to(torch.bfloat16) + + if len(original_shape) == 3: + final_hidden_states = final_hidden_states.view(original_shape) + + return final_hidden_states + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w1_scale: torch.Tensor, @@ -623,8 +713,10 @@ class AscendW8A8DynamicFusedMoEMethod: assert router_logits.shape[ 1] == global_num_experts, "Number of global experts mismatch" + is_deepseek_v3_r1 = global_num_experts == 256 + # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern - if global_num_experts == 256: + if is_deepseek_v3_r1: topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k( router_logits, k=top_k, # topk当前写8 @@ -661,8 +753,19 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) fused_moe_state = get_fused_moe_state(self.ep_group.world_size, - is_prefill) - if fused_moe_state == FusedMoEState.MC2: + is_prefill, is_deepseek_v3_r1) + if fused_moe_state == FusedMoEState.AllGatherEP: + return fused_experts_with_allgather( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + top_k=top_k, + expert_map=expert_map) + elif fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -713,6 +816,8 @@ class AscendW8A8DynamicFusedMoEMethod: 1, 2).contiguous() layer.w2_weight.data = layer.w2_weight.data.transpose( 1, 2).contiguous() + if 
envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP: + torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ) layer.w13_weight_scale.data = layer.w13_weight_scale.data.view( layer.w13_weight_scale.data.shape[0], -1) layer.w13_weight_offset.data = layer.w13_weight_offset.data.view( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 1a59036..4764526 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -394,11 +394,18 @@ class FusedMoEState(Enum): AllGather = 0 All2All = 1 MC2 = 2 + AllGatherEP = 3 # TODO(zzzzwwjj): add soc_version to choose branch -def get_fused_moe_state(ep_size: int, with_prefill: bool): - if ep_size == 1: +def get_fused_moe_state(ep_size: int, with_prefill: bool, + is_deepseek_v3_r1: bool): + # the fusion operator torch_npu.npu_grouped_matmul_finalize_routing called by allgather ep + # only supports deepseek v3/r1 + if (envs.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP and ep_size > 1 + and is_deepseek_v3_r1): + return FusedMoEState.AllGatherEP + elif ep_size == 1: return FusedMoEState.AllGather # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. elif ep_size < 16 or with_prefill: