[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)

### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.
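
The resulting control flow is sketched below. This is a simplified, hypothetical illustration of the pattern rather than code from this PR; `run_moe_layer` and `grouped_mlp` are placeholder names (the real expert MLP uses `npu_grouped_matmul`, and the real signatures appear in `moe_comm_method.py` further down).

```python
# Hypothetical sketch of the strategy pattern: the MoE layer depends only on
# the abstract interface, so AllGather / MC2 / AllToAll variants can be
# swapped without touching the layer itself.
from abc import ABC, abstractmethod


class MoECommMethod(ABC):

    @abstractmethod
    def _pre_process(self, hidden_states, topk_ids, topk_weights, expert_map,
                     num_experts):
        """Return (permuted_hidden_states, expert_tokens, group_list_type)."""

    @abstractmethod
    def _post_process(self, mlp_output, hidden_states):
        """Un-permute the MLP output and write it back into hidden_states."""


def run_moe_layer(comm: MoECommMethod, hidden_states, topk_ids, topk_weights,
                  expert_map, num_experts, grouped_mlp):
    # grouped_mlp stands in for the expert MLP computation.
    permuted, expert_tokens, group_list_type = comm._pre_process(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
    mlp_output = grouped_mlp(permuted, expert_tokens, group_list_type)
    comm._post_process(mlp_output, hidden_states)
    return hidden_states
```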

Plan / Roadmap

1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.

Other notes

* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.

- vLLM version: v0.10.0
- vLLM main:
f7ad6a1eb3

---------

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
yiz-liu
2025-08-12 21:10:20 +08:00
committed by GitHub
parent 1a70564e7c
commit 992271b027
7 changed files with 764 additions and 26 deletions

View File

@@ -0,0 +1,153 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
from types import SimpleNamespace
import pytest
import torch
from transformers import PretrainedConfig
from vllm import forward_context
from vllm_ascend.distributed import moe_comm_method
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
NativeAllGatherCommImpl)
@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("ep_rank", [0, 1])
def test_all_gather_comm_impl(
num_tokens,
hidden_size,
global_num_experts,
top_k_num,
dtype,
num_local_experts,
ep_rank,
):
"""
Tests the AllGatherCommImpl against the NativeAllGatherCommImpl.
This test compares the outputs of the NPU-optimized AllGatherCommImpl
with a native PyTorch implementation (NativeAllGatherCommImpl) to ensure
correctness across various configurations.
"""
if top_k_num > global_num_experts:
pytest.skip("top_k_num cannot be greater than global_num_experts")
if num_local_experts > global_num_experts:
pytest.skip(
"num_local_experts cannot be greater than global_num_experts")
device = torch.device("npu")
hf_config = PretrainedConfig(
num_experts_per_tok=top_k_num,
num_experts=global_num_experts,
)
# Instantiate implementations
native_impl = NativeAllGatherCommImpl(device, dtype, hf_config)
all_gather_impl = AllGatherCommImpl(device, dtype, hf_config)
# TODO: Find out if this is the correct way to mock the forward context and ep group
# Mock get_forward_context to return an object with moe_comm_method
forward_context._forward_context = SimpleNamespace(
moe_comm_method=all_gather_impl)
# Mock get_ep_group to return a fake group with the specified ep_rank
fake_ep_group = SimpleNamespace(rank_in_group=ep_rank)
moe_comm_method.get_ep_group = lambda: fake_ep_group
# --- Input Data ---
hidden_states = torch.randn(num_tokens,
hidden_size,
device=device,
dtype=dtype)
topk_ids = torch.randint(0,
global_num_experts, (num_tokens, top_k_num),
device=device,
dtype=torch.int32)
topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype)
topk_weights = torch.nn.functional.softmax(topk_weights, dim=1)
num_experts = global_num_experts
expert_map = None
if num_local_experts < global_num_experts:
# Create a map where some experts are local and some are not
expert_map = torch.full((global_num_experts, ), -1, device=device)
expert_map[ep_rank * num_local_experts:(ep_rank + 1) *
num_local_experts] = torch.arange(num_local_experts,
device=device)
num_experts = num_local_experts
# --- Run Native Implementation (Golden Reference) ---
native_hidden_states_out = hidden_states.clone()
(
native_permuted_hidden,
native_expert_tokens,
_,
) = native_impl._pre_process(hidden_states, topk_ids, topk_weights,
expert_map, num_experts)
# Simulate MLP output
native_mlp_output = torch.randn_like(native_permuted_hidden)
native_impl._post_process(native_mlp_output, native_hidden_states_out)
# --- Run AllGather Implementation ---
all_gather_hidden_states_out = hidden_states.clone()
(
all_gather_permuted_hidden,
all_gather_expert_tokens,
_,
) = torch.ops.vllm.moe_comm_pre_process(hidden_states, topk_ids,
topk_weights, expert_map,
num_experts)
# Use the same simulated MLP output for a fair comparison
all_gather_mlp_output = native_mlp_output.clone()
torch.ops.vllm.moe_comm_post_process(all_gather_mlp_output,
all_gather_hidden_states_out)
# --- Assertions ---
# Define tolerance based on dtype
atol = 1e-3 if dtype == torch.float16 else 1e-2
rtol = 1e-3 if dtype == torch.float16 else 1e-2
# 1. Compare expert_tokens from pre_process
assert torch.allclose(native_expert_tokens.to(
all_gather_expert_tokens.device),
all_gather_expert_tokens,
atol=atol,
rtol=rtol), "Expert tokens do not match."
# 2. Compare permuted_hidden_states from pre_process
num_valid_tokens = native_expert_tokens.sum()
assert torch.allclose(native_permuted_hidden[:num_valid_tokens].to(
all_gather_permuted_hidden.device),
all_gather_permuted_hidden[:num_valid_tokens],
atol=atol,
rtol=rtol), "Permuted hidden states do not match."
# 3. Compare final hidden_states from post_process
assert torch.allclose(native_hidden_states_out.to(
all_gather_hidden_states_out.device),
all_gather_hidden_states_out,
atol=atol,
rtol=rtol), "Final hidden states do not match."
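
For reference, the `expert_map` built in the test above maps global expert IDs to local slots on the current EP rank. A standalone numeric illustration (not part of the test file), assuming `global_num_experts=8`, `num_local_experts=4`, `ep_rank=1`:

```python
# ep_rank 1 owns global experts 4..7; all other experts are marked -1.
import torch

expert_map = torch.full((8, ), -1)
expert_map[1 * 4:(1 + 1) * 4] = torch.arange(4)
print(expert_map.tolist())  # [-1, -1, -1, -1, 0, 1, 2, 3]
# Global experts 0..3 are not local to this rank (-1); global experts 4..7
# map to local expert slots 0..3.
```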

View File

@@ -5,11 +5,12 @@ from typing import Any, Optional
import torch
from vllm.config import VllmConfig
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context, set_forward_context
import vllm_ascend.envs as envs
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
class FusedMoEState(Enum):
@@ -54,6 +55,8 @@ def set_ascend_forward_context(
num_tokens_across_dp: Optional[torch.Tensor] = None,
with_prefill: bool = True,
in_profile_run: bool = False,
reserved_mc2_mask: Optional[torch.Tensor] = None,
moe_comm_method: Optional[MoECommMethod] = None,
num_actual_tokens: Optional[int] = None,
):
"""A context manager that stores the current forward context,
@@ -66,6 +69,7 @@ def set_ascend_forward_context(
num_tokens=num_tokens,
num_tokens_across_dp=num_tokens_across_dp):
forward_context = get_forward_context()
forward_context.moe_comm_method = moe_comm_method
forward_context.with_prefill = with_prefill
ep_size = (get_ep_group().world_size if
vllm_config.parallel_config.enable_expert_parallel else 1)
@@ -97,16 +101,17 @@ def set_ascend_forward_context(
if num_tokens is not None:
if num_actual_tokens is None:
num_actual_tokens = num_tokens
tp_world_size = get_tp_group().world_size
tp_world_size = get_tensor_model_parallel_world_size()
# NOTE: the token count to pad up to when using MC2
forward_context.padded_num_tokens = math.ceil(
max_tokens_across_dp / tp_world_size) * tp_world_size
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
dtype=torch.bool,
device=NPUPlatform.device_type)
mc2_mask[:num_actual_tokens] = True
forward_context.mc2_mask = mc2_mask
if reserved_mc2_mask is not None:
mc2_mask = reserved_mc2_mask[:forward_context.
padded_num_tokens]
mc2_mask[:num_actual_tokens] = True
mc2_mask[num_actual_tokens:] = False
forward_context.mc2_mask = mc2_mask
try:
yield
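
To make the new `reserved_mc2_mask` path concrete, here is a standalone sketch of the padding arithmetic with made-up numbers. The buffer size (512) mirrors the model-runner allocation shown later in this commit; the motivation noted in the comment (keeping a stable preallocated buffer rather than allocating a fresh mask each step) is an inference, not something the diff states explicitly.

```python
# Standalone sketch of the mc2_mask handling above (sample numbers only).
import math

import torch

tp_world_size = 4
max_tokens_across_dp = 37
num_actual_tokens = 30

# Pad the token count up to a multiple of the TP world size, as in the diff.
padded_num_tokens = math.ceil(
    max_tokens_across_dp / tp_world_size) * tp_world_size
assert padded_num_tokens == 40

# The new path slices a preallocated boolean buffer (reserved_mc2_mask,
# sized 512 in the model runner) instead of creating a new mask each step.
reserved_mc2_mask = torch.zeros(512, dtype=torch.bool)
mc2_mask = reserved_mc2_mask[:padded_num_tokens]
mc2_mask[:num_actual_tokens] = True
mc2_mask[num_actual_tokens:] = False
```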

View File

@@ -0,0 +1,449 @@
from abc import ABC, abstractmethod
import torch
import torch_npu
from transformers.configuration_utils import PretrainedConfig
from vllm.distributed.parallel_state import get_ep_group, get_tp_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(
self,
device: torch.device,
dtype: torch.dtype,
hf_config: PretrainedConfig,
):
self.device = device
self.dtype = dtype
self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
# global_num_experts may be called num_experts or n_routed_experts in different models.
possible_keys = ["num_experts", "n_routed_experts"]
for key in possible_keys:
if hasattr(hf_config, key):
self.global_num_experts = getattr(hf_config, key)
break
else:
self.global_num_experts = 0
@abstractmethod
def _pre_process(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Pre-process before MLP.
Args:
hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
Mapping from global expert IDs to local expert IDs.
num_experts (int): Number of local experts (experts on this device).
Returns:
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
- permuted_hidden_states (torch.Tensor): Tensor of shape
(num_tokens * top_k_num, hidden_size) after permuting
hidden_states based on topk_ids.
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
Number of tokens assigned to each expert.
- group_list_type (int): Type of group list, 0 for `cumsum`
and 1 for `count`. This is mainly for `npu_grouped_matmul`
to determine how to handle the output.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
"""
pass
@abstractmethod
def _post_process(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
"""Post-process after MLP.
Args:
mlp_output (torch.Tensor): Tensor of shape
(num_tokens * top_k_num, hidden_size) after MLP.
hidden_states (torch.Tensor): Tensor of shape
(num_tokens, hidden_size) to be updated with the final output.
"""
pass
class DummyCommImpl(MoECommMethod):
def _pre_process(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""Dummy implementation, see moe_comm_pre_process_fake for details."""
return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
expert_map, num_experts)
def _post_process(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
"""Dummy implementation that does nothing."""
pass
class NativeAllGatherCommImpl(MoECommMethod):
"""This implementation should be compatible with all scenarios.
Note that this implementation consists purely of native PyTorch ops and does
not use any NPU-specific ops, so its performance may not be optimal. It is,
however, a useful fallback when NPU-specific ops are not available.
"""
def _pre_process(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
num_tokens = hidden_states.shape[0]
# Generate token indices and flatten
token_indices = torch.arange(num_tokens,
device=self.device,
dtype=torch.int64)
token_indices = (token_indices.unsqueeze(1).expand(
-1, self.top_k_num).reshape(-1))
# Flatten token-to-expert mappings and map to local experts
weights_flat = topk_weights.view(-1)
experts_flat = topk_ids.view(-1)
local_experts_flat = (expert_map[experts_flat]
if expert_map is not None else experts_flat)
# Filter valid token-expert pairs
mask = local_experts_flat != -1
# FIXME: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
filtered_weights = torch.where(mask, weights_flat,
torch.zeros_like(weights_flat)).to(
self.dtype)
filtered_experts = torch.where(
mask,
local_experts_flat,
torch.full_like(local_experts_flat, num_experts),
).to(topk_ids.dtype)
# Sort by local expert IDs
sort_indices = torch.argsort(filtered_experts.view(torch.float32))
self.sorted_token_indices = token_indices[sort_indices]
self.sorted_weights = filtered_weights[sort_indices]
# Compute token counts with minlength of num_experts
# This is equivalent to but faster than:
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
token_counts = torch.zeros(num_experts + 1,
device=self.device,
dtype=torch.int64)
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
expert_tokens = token_counts[:num_experts]
# Rearrange hidden_states
permuted_hidden_states = hidden_states[self.sorted_token_indices]
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, group_list_type
def _post_process(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)
final_hidden_states = torch.zeros_like(hidden_states)
final_hidden_states.index_add_(0, self.sorted_token_indices,
mlp_output)
hidden_states[:] = final_hidden_states
class AllGatherCommImpl(MoECommMethod):
"""This implementation is the same as NativeAllGatherCommImpl,
but uses NPU-specific ops for better performance.
This implementation should be compatible with all scenarios, and
thus it is the default implementation for MoE communication methods.
It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
and `torch_npu.npu_moe_token_unpermute` for post-processing
to handle the token-to-expert mapping and communication efficiently.
NOTE(Yizhou): TBH, it is really weird that we were supposed to use
`torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
for pre-processing and post-processing, respectively.
But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
use `torch_npu.npu_moe_token_unpermute` instead.
This is a workaround and should be removed after the issue is fixed.
"""
def _pre_process(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor, # noqa: F841
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
num_tokens = hidden_states.shape[0]
self.topk_weights = topk_weights
self.topk_ids = topk_ids
first_expert_idx = 0
if expert_map is not None:
# FIXME: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...]
# So we need to filter out invalid tokens by zeroing their weights.
# This is a workaround and should be removed after the issue is fixed
mask = expert_map[topk_ids] != -1
# NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
# but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
self.topk_weights = torch.where(mask, topk_weights, 0.0)
first_expert_idx = get_ep_group().rank_in_group * num_experts
last_expert_idx = first_expert_idx + num_experts
permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
torch_npu.npu_moe_init_routing_v2(
hidden_states,
topk_ids,
active_num=num_tokens * self.top_k_num,
expert_num=self.global_num_experts,
expert_tokens_num_type=1, # Only support `count` mode now
expert_tokens_num_flag=True, # Output `expert_tokens`
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=-1,
))
self.expanded_row_idx = expanded_row_idx
permuted_hidden_states = permuted_hidden_states
group_list_type = 1 # `count` mode
return permuted_hidden_states, expert_tokens, group_list_type
def _post_process(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
hidden_states[:] = torch_npu.npu_moe_token_unpermute(
permuted_tokens=mlp_output,
sorted_indices=self.expanded_row_idx,
probs=self.topk_weights)
class MC2CommImpl(MoECommMethod):
"""This implementation is for the scenarios listed below:
1. `enable_expert_parallel=True` (`enable_expert_parallel=False` is not supported).
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
def __init__(
self,
device: torch.device,
dtype: torch.dtype,
hf_config: PretrainedConfig,
):
super().__init__(device, dtype, hf_config)
# Shared communication configurations
ep_group = get_mc2_group()
self.ep_rank_id = ep_group.rank_in_group
self.ep_world_size = ep_group.world_size
self.tp_world_size = get_tp_group().world_size
device_group = ep_group.device_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
# Feature flags
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
self.need_extra_args = self.is_ascend_a3 # or is_torchair
# Intermediate tensors to be passed from pre_process to post_process
self.topk_ids = None
self.topk_weights = None
self.mc2_mask = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.tp_recv_counts = None
def _pre_process(
self,
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
# Store tensors needed for post_process
self.topk_ids = topk_ids
self.topk_weights = topk_weights.to(torch.float32)
self.mc2_mask = get_forward_context().mc2_mask
dispatch_kwargs = {
"x": hidden_states,
"expert_ids": self.topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": self.global_num_experts,
"global_bs": 0,
"scales": None,
"quant_mode": 0,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
}
if self.need_extra_args:
dispatch_kwargs.update({
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.is_ascend_a3 and self.enable_dispatch_v2:
dispatch_kwargs.update({
"x_active_mask": self.mc2_mask,
})
dispatch = torch_npu.npu_moe_distribute_dispatch_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch
(
permuted_hidden_states,
_, # dynamic_scale is not used
self.assist_info_for_combine,
expert_tokens,
self.ep_recv_counts,
self.tp_recv_counts,
) = dispatch(**dispatch_kwargs)[:6]
group_list_type = 1
return permuted_hidden_states, expert_tokens, group_list_type
def _post_process(self, mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
combine_kwargs = {
"expand_x": mlp_output,
"expert_ids": self.topk_ids,
"expert_scales": self.topk_weights,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": self.global_num_experts,
"global_bs": 0,
"ep_send_counts": self.ep_recv_counts,
"group_ep": self.moe_all_to_all_group_name,
"ep_world_size": self.ep_world_size,
"ep_rank_id": self.ep_rank_id,
}
if self.enable_dispatch_v2:
combine_kwargs[
"assist_info_for_combine"] = self.assist_info_for_combine
else:
combine_kwargs["expand_idx"] = self.assist_info_for_combine
if self.need_extra_args:
combine_kwargs.update({
"tp_send_counts": self.tp_recv_counts,
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if self.is_ascend_a3 and self.enable_dispatch_v2:
combine_kwargs.update({
"x_active_mask": self.mc2_mask,
})
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
hidden_states[:] = combine(**combine_kwargs)
def moe_comm_pre_process(
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""This function is a wrapper for the pre_process method of the
MoECommMethod instance stored in the ForwardContext. So it can be
used as a custom op in the vllm framework.
"""
forward_context: ForwardContext = get_forward_context()
self = forward_context.moe_comm_method
return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map,
num_experts)
def moe_comm_pre_process_fake(
hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
expert_map: torch.Tensor,
num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
"""This is a fake implementation of the pre_process method.
torch.compile will use this implementation to generate the FX graph.
"""
top_k_num = topk_ids.shape[1]
permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
expert_tokens = torch.zeros((num_experts, ),
dtype=torch.int64,
device=hidden_states.device)
group_list_type = 0
return permuted_hidden_states, expert_tokens, group_list_type
def moe_comm_post_process(mlp_output: torch.Tensor,
hidden_states: torch.Tensor) -> None:
"""This function is a wrapper for the post_process method of the
MoECommMethod instance stored in the ForwardContext. So it can be
used as a custom op in the vllm framework.
"""
forward_context: ForwardContext = get_forward_context()
self = forward_context.moe_comm_method
self._post_process(mlp_output, hidden_states)
return
direct_register_custom_op(
op_name="moe_comm_pre_process",
op_func=moe_comm_pre_process,
mutates_args=[],
fake_impl=moe_comm_pre_process_fake,
dispatch_key="PrivateUse1",
)
direct_register_custom_op(
op_name="moe_comm_post_process",
op_func=moe_comm_post_process,
mutates_args=["hidden_states"],
fake_impl=lambda x, y: None, # No-op for fake implementation
dispatch_key="PrivateUse1",
)
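
The `group_list_type` convention mentioned in the docstrings above (0 for `cumsum`, 1 for `count`) can be illustrated with a standalone snippet; the token counts below are made up.

```python
# Two encodings of the same grouping, for 4 local experts receiving
# 3, 0, 5, and 2 permuted tokens respectively.
import torch

expert_tokens_count = torch.tensor([3, 0, 5, 2])              # type 1: "count"
expert_tokens_cumsum = torch.cumsum(expert_tokens_count, 0)   # type 0: "cumsum"
print(expert_tokens_count.tolist())   # [3, 0, 5, 2]
print(expert_tokens_cumsum.tolist())  # [3, 3, 8, 10]
# Both describe the same grouping of the permuted tokens; the flag only tells
# npu_grouped_matmul which convention the pre-process step produced.
```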

View File

@@ -19,12 +19,13 @@ from typing import Callable, Optional
import torch
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
select_experts)
from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
unified_fused_experts)
from vllm_ascend.utils import is_310p
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -95,20 +96,18 @@ def forward_oot(
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input)
# If use aclgraph, we need to set max_num_tokens to make
# the input shape of `npu_moe_init_routing` fixed
max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
moe_comm_method = get_forward_context().moe_comm_method
return fused_experts(
return unified_fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
max_num_tokens=max_num_tokens)
moe_comm_method=moe_comm_method,
)
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func

View File

@@ -43,6 +43,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
@@ -57,6 +58,62 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
def unified_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
moe_comm_method: Optional[MoECommMethod] = None,
# For TorchAir graph
is_torchair: bool = False,
# For Cube/Vector parallel
shared_experts: Optional[Any] = None,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
# For load balance
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
) -> torch.Tensor:
# Check constraints
assert hidden_states.shape[1] == w1.shape[2], (
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert moe_comm_method is not None, "Missing communication context"
num_experts = w1.shape[0]
permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
mlp_output = apply_mlp(
permuted_hidden_states,
w1,
w2,
expert_tokens,
group_list_type=group_list_type,
)
torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)
return hidden_states
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
max_row_per_ep_rank: int, num_tokens: int,
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:

View File

@@ -205,8 +205,15 @@ class NPUPlatform(Platform):
register_ascend_customop()
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, use_mla):
def get_attn_backend_cls(cls,
selected_backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_v1,
use_mla,
has_sink=False):
if not use_v1:
raise ValueError("vLLM Ascend does not support V0 engine.")

View File

@@ -26,7 +26,7 @@ import types
import weakref
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
import numpy as np
import numpy.typing as npt
@@ -43,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.forward_context import DPMetadata, get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -79,6 +79,9 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
DummyCommImpl,
MoECommMethod)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@@ -335,7 +338,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.use_aclgraph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
and not self.model_config.enforce_eager and
not ascend_config.torchair_graph_config.enabled)
self.aclgraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
@@ -375,6 +379,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
self.reserved_mc2_mask = torch.zeros(
512,
dtype=torch.bool,
device=self.device,
)
self.moe_comm_method = AllGatherCommImpl
def check_batch_sizes_consistency(self) -> None:
if not dist.is_initialized():
return
@@ -1003,6 +1015,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
mm_embeds.append(mm_embeds_item)
return mm_embeds
def get_dp_padding(self,
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
"""This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
Please note that vLLM may refactor or modify this function over time;
at present, we are using the version introduced in PR #18935.
"""
dp_size = self.vllm_config.parallel_config.data_parallel_size
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
# For DP: Don't pad when setting enforce_eager.
# This lets us set enforce_eager on the prefiller in a P/D setup and
# still use ACL graphs (enabled by this padding) on the decoder.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
dp_size,
device="cpu",
dtype=torch.int32)
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
def _process_reqs(
self,
scheduler_output: "SchedulerOutput",
@@ -1025,6 +1063,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
# Padding for DP
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
num_input_tokens)
num_input_tokens += num_pad
modified_batch = self.attn_metadata_builder.reorder_batch(
self.input_batch, scheduler_output)
if modified_batch:
@@ -1250,13 +1293,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
for k, v in self.intermediate_tensors.items()
})
moe_comm_method = self.moe_comm_method
# NOTE: Currently this padding logic is really messy,
# MC2 may not be available in eager mode
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
if self.use_aclgraph:
num_tokens_across_dp = num_tokens_across_dp_native
else:
num_input_tokens = padded_num_tokens_across_dp
# Run forward pass
with set_ascend_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=padded_num_tokens_across_dp,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method(self.device, self.dtype,
self.model_config.hf_config),
num_actual_tokens=total_num_scheduled_tokens):
with ProfileExecuteDuration().capture_async("forward"):
self.maybe_setup_kv_connector(scheduler_output)
@@ -1865,6 +1921,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
skip_attn: bool = True,
with_prefill: bool = False,
is_torchair_compile: bool = False,
moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
) -> torch.Tensor:
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
@@ -1932,6 +1989,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_tokens_across_dp=num_tokens_across_dp,
with_prefill=with_prefill,
in_profile_run=self.in_profile_run,
reserved_mc2_mask=self.reserved_mc2_mask,
moe_comm_method=moe_comm_method(
self.device, self.dtype, self.model_config.hf_config),
num_actual_tokens=0,
):
hidden_states = self._generate_dummy_run_hidden_states(
@@ -2328,13 +2388,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Trigger ACL graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
with graph_capture(device=self.device):
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
for num_tokens in reversed(self.aclgraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups):
self._dummy_run(num_tokens)
self._dummy_run(num_tokens)
self._dummy_run(
num_tokens,
skip_attn=skip_attn,
moe_comm_method=self.moe_comm_method,
)
self._dummy_run(
num_tokens,
skip_attn=skip_attn,
moe_comm_method=self.moe_comm_method,
)
def capture_model(self) -> None:
start_time = time.perf_counter()
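
For reference, the DP padding performed by `get_dp_padding` above can be illustrated with made-up numbers; this is a standalone sketch, not code from the commit.

```python
# DP padding sketch with dp_size=4 and illustrative per-rank token counts.
import torch

num_tokens_across_dp = [13, 7, 21, 16]            # tokens on each DP rank
max_tokens_across_dp = max(num_tokens_across_dp)  # 21

# Every rank pads its batch up to the max and reports the same padded count,
# so collectives across DP ranks see uniform shapes. For rank 1:
num_pad_rank1 = max_tokens_across_dp - num_tokens_across_dp[1]  # 14
num_tokens_after_padding = torch.tensor([max_tokens_across_dp] * 4,
                                        dtype=torch.int32)
print(num_pad_rank1, num_tokens_after_padding.tolist())  # 14 [21, 21, 21, 21]
```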