[refactor] Refactoring AscendFusedMoE (#1229)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147):
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->

### Does this PR introduce _any_ user-facing change?
1. This PR removes the env `VLLM_ENABLE_MC2`, because I consider it
unnecessary: the same judgment can be made from the current scenario
without this env, so keeping it only increases complexity.
2. This PR removes the env `USING_LCCL_COM`, because this env is
already obsolete.
3. `additional_config.expert_tensor_parallel_size` is already obsolete;
we now use the `enable_expert_parallel` parameter as well, consistent
with vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-06-17 17:49:03 +08:00
committed by GitHub
parent 05dec7eda9
commit 23ca68d0c8
9 changed files with 150 additions and 204 deletions

View File

@@ -21,11 +21,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch_npu
from torch import nn
from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator,
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.distributed.parallel_state import get_dp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
determine_expert_map)
@@ -36,10 +38,10 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, npu_stream_switch,
npu_wait_tensor)
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -845,8 +847,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
super().__init__(moe=moe)
vllm_config = get_current_vllm_config()
ep_group = get_ep_group()
self.ep_size = ep_group.world_size
self.ep_group = get_ep_group()
self.ep_size = self.ep_group.world_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
self.max_model_len = vllm_config.model_config.max_model_len
@@ -855,7 +857,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = ep_group.device_group
device_group = self.ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
@@ -931,7 +933,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
if VLLM_ENABLE_MC2 and not is_prefill:
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -942,7 +946,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts)
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
@@ -1022,9 +1026,6 @@ class AscendFusedMoE(FusedMoE):
get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config))
self.moe_parallel_config.ep_size = get_ep_group().world_size
self.moe_parallel_config.tp_size = get_etp_group().world_size
self.top_k = top_k
self.num_experts = num_experts
self.global_num_experts = num_experts
@@ -1066,10 +1067,9 @@ class AscendFusedMoE(FusedMoE):
self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts)
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
@@ -1109,6 +1109,8 @@ class AscendFusedMoE(FusedMoE):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.ep_group = get_ep_group()
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
def forward(self,
@@ -1125,25 +1127,45 @@ class AscendFusedMoE(FusedMoE):
else:
real_top_k = self.top_k
# MC2 ag/rs broadcast/all_reduce
# prefill_req x x √
# decode_req √ x √
# graph_mode x
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.torchair_graph_enabled:
if USING_LCCL_COM: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
router_logits, 0, False)
elif self.torchair_graph_enabled and not is_prefill:
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
else:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
num_tokens, hidden_size = hidden_states.shape
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
is_prefill)
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size()
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if num_tokens < tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
if not self.torchair_graph_enabled or is_prefill:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
if num_tokens < max_num_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
# Matrix multiply.
e_hidden_states = self.quant_method.apply(
@@ -1167,36 +1189,36 @@ class AscendFusedMoE(FusedMoE):
shared_experts=shared_experts,
)
if shared_experts is not None:
# Provide dummy implementation of "non-separated" shared experts.
if not isinstance(e_hidden_states, tuple):
return e_hidden_states, shared_experts(hidden_states)
else:
return e_hidden_states
if shared_experts:
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.torchair_graph_enabled:
if USING_LCCL_COM: # type: ignore
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
elif self.torchair_graph_enabled and not is_prefill:
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
else:
e_hidden_states = get_ep_group().combine(e_hidden_states)
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < tp_size:
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else:
final_hidden_states = e_hidden_states
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states)
if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return e_hidden_states
if shared_experts:
return final_hidden_states, shared_hidden_states
else:
return final_hidden_states
# ----------------------------------------- TBO-related --------------------------------------------