[1/N][Feat] Support MoE models with ACL Graph and refactor MoE communication logic (#2125)
### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.
Plan / Roadmap
1. Introduce `MoECommMethod`, implement `AllGatherImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.
Other notes
* Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.
- vLLM version: v0.10.0
- vLLM main: f7ad6a1eb3
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
153
tests/e2e/multicard/moe/test_moe_comm.py
Normal file
153
tests/e2e/multicard/moe/test_moe_comm.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
from vllm import forward_context
|
||||
|
||||
from vllm_ascend.distributed import moe_comm_method
|
||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
NativeAllGatherCommImpl)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", [16, 128])
@pytest.mark.parametrize("hidden_size", [64, 128])
@pytest.mark.parametrize("global_num_experts", [8, 16])
@pytest.mark.parametrize("top_k_num", [2, 4])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("num_local_experts", [4, 8])
@pytest.mark.parametrize("ep_rank", [0, 1])
def test_all_gather_comm_impl(
    num_tokens,
    hidden_size,
    global_num_experts,
    top_k_num,
    dtype,
    num_local_experts,
    ep_rank,
):
    """Check AllGatherCommImpl against the pure-PyTorch reference.

    The NPU-optimized AllGatherCommImpl must agree with
    NativeAllGatherCommImpl on the pre-process outputs and on the final
    combined hidden states for every parameter combination.
    """
    if top_k_num > global_num_experts:
        pytest.skip("top_k_num cannot be greater than global_num_experts")
    if num_local_experts > global_num_experts:
        pytest.skip(
            "num_local_experts cannot be greater than global_num_experts")

    device = torch.device("npu")
    hf_config = PretrainedConfig(
        num_experts_per_tok=top_k_num,
        num_experts=global_num_experts,
    )

    # Reference (pure PyTorch) and NPU-optimized implementations.
    reference = NativeAllGatherCommImpl(device, dtype, hf_config)
    optimized = AllGatherCommImpl(device, dtype, hf_config)

    # TODO: Find out if this is the correct way to mock the forward context and ep group
    # The custom ops resolve their strategy through the forward context.
    forward_context._forward_context = SimpleNamespace(
        moe_comm_method=optimized)
    # Pretend to be rank `ep_rank` of the expert-parallel group.
    moe_comm_method.get_ep_group = lambda: SimpleNamespace(
        rank_in_group=ep_rank)

    # --- Input data ---
    hidden_states = torch.randn(num_tokens,
                                hidden_size,
                                device=device,
                                dtype=dtype)
    topk_ids = torch.randint(0,
                             global_num_experts, (num_tokens, top_k_num),
                             device=device,
                             dtype=torch.int32)
    topk_weights = torch.rand(num_tokens, top_k_num, device=device).to(dtype)
    topk_weights = torch.nn.functional.softmax(topk_weights, dim=1)

    num_experts = global_num_experts
    expert_map = None
    if num_local_experts < global_num_experts:
        # Mark this rank's contiguous slice of experts as local; every other
        # expert maps to -1 (not hosted on this device).
        expert_map = torch.full((global_num_experts, ), -1, device=device)
        local_slice = slice(ep_rank * num_local_experts,
                            (ep_rank + 1) * num_local_experts)
        expert_map[local_slice] = torch.arange(num_local_experts,
                                               device=device)
        num_experts = num_local_experts

    # --- Reference run (golden values) ---
    ref_out = hidden_states.clone()
    ref_permuted, ref_expert_tokens, _ = reference._pre_process(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
    # Simulate the expert MLP output.
    mlp_output = torch.randn_like(ref_permuted)
    reference._post_process(mlp_output, ref_out)

    # --- Optimized run (through the registered custom ops) ---
    opt_out = hidden_states.clone()
    opt_permuted, opt_expert_tokens, _ = torch.ops.vllm.moe_comm_pre_process(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
    # Reuse the same simulated MLP output for a fair comparison.
    torch.ops.vllm.moe_comm_post_process(mlp_output.clone(), opt_out)

    # --- Assertions ---
    tolerance = 1e-3 if dtype == torch.float16 else 1e-2

    # 1. Per-expert token counts from pre-process must match.
    assert torch.allclose(ref_expert_tokens.to(opt_expert_tokens.device),
                          opt_expert_tokens,
                          atol=tolerance,
                          rtol=tolerance), "Expert tokens do not match."

    # 2. Permuted hidden states must match on the valid prefix.
    num_valid_tokens = ref_expert_tokens.sum()
    assert torch.allclose(
        ref_permuted[:num_valid_tokens].to(opt_permuted.device),
        opt_permuted[:num_valid_tokens],
        atol=tolerance,
        rtol=tolerance), "Permuted hidden states do not match."

    # 3. Final combined hidden states must match.
    assert torch.allclose(ref_out.to(opt_out.device),
                          opt_out,
                          atol=tolerance,
                          rtol=tolerance), "Final hidden states do not match."
|
||||
@@ -5,11 +5,12 @@ from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.distributed import (get_dp_group, get_ep_group,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
|
||||
|
||||
|
||||
class FusedMoEState(Enum):
|
||||
@@ -54,6 +55,8 @@ def set_ascend_forward_context(
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
with_prefill: bool = True,
|
||||
in_profile_run: bool = False,
|
||||
reserved_mc2_mask: Optional[torch.Tensor] = None,
|
||||
moe_comm_method: Optional[MoECommMethod] = None,
|
||||
num_actual_tokens: Optional[int] = None,
|
||||
):
|
||||
"""A context manager that stores the current forward context,
|
||||
@@ -66,6 +69,7 @@ def set_ascend_forward_context(
|
||||
num_tokens=num_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp):
|
||||
forward_context = get_forward_context()
|
||||
forward_context.moe_comm_method = moe_comm_method
|
||||
forward_context.with_prefill = with_prefill
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
@@ -97,16 +101,17 @@ def set_ascend_forward_context(
|
||||
if num_tokens is not None:
|
||||
if num_actual_tokens is None:
|
||||
num_actual_tokens = num_tokens
|
||||
tp_world_size = get_tp_group().world_size
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
# NOTE: token num which need to pad to when mc2
|
||||
forward_context.padded_num_tokens = math.ceil(
|
||||
max_tokens_across_dp / tp_world_size) * tp_world_size
|
||||
|
||||
mc2_mask = torch.zeros(forward_context.padded_num_tokens,
|
||||
dtype=torch.bool,
|
||||
device=NPUPlatform.device_type)
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
if reserved_mc2_mask is not None:
|
||||
mc2_mask = reserved_mc2_mask[:forward_context.
|
||||
padded_num_tokens]
|
||||
mc2_mask[:num_actual_tokens] = True
|
||||
mc2_mask[num_actual_tokens:] = False
|
||||
forward_context.mc2_mask = mc2_mask
|
||||
|
||||
try:
|
||||
yield
|
||||
|
||||
449
vllm_ascend/distributed/moe_comm_method.py
Normal file
449
vllm_ascend/distributed/moe_comm_method.py
Normal file
@@ -0,0 +1,449 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from vllm.distributed.parallel_state import get_ep_group, get_tp_group
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoECommMethod(ABC):
    """Abstract strategy for MoE token routing around the expert MLP.

    Concrete subclasses implement a pair of hooks: `_pre_process` arranges
    tokens so each expert receives a contiguous group of rows, and
    `_post_process` maps the expert MLP output back to the original token
    order, writing the combined result into the caller's buffer.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        hf_config: PretrainedConfig,
    ):
        self.device = device
        self.dtype = dtype
        # Number of experts each token is routed to.
        self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
        # The global expert count is exposed under different attribute names
        # depending on the model family (e.g. "num_experts" vs
        # "n_routed_experts"); fall back to 0 when neither is present.
        self.global_num_experts = next(
            (getattr(hf_config, name)
             for name in ("num_experts", "n_routed_experts")
             if hasattr(hf_config, name)),
            0,
        )

    @abstractmethod
    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Permute tokens into expert-major order before the MLP.

        Args:
            hidden_states (torch.Tensor): shape (num_tokens, hidden_size).
            topk_ids (torch.Tensor): shape (num_tokens, top_k_num).
            topk_weights (torch.Tensor): shape (num_tokens, top_k_num).
            expert_map (torch.Tensor): shape (global_num_experts, );
                maps global expert IDs to local expert IDs (-1 = not local).
            num_experts (int): number of experts hosted on this device.

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]:
                - permuted_hidden_states: shape
                  (num_tokens * top_k_num, hidden_size), rows grouped by
                  expert according to topk_ids.
                - expert_tokens: shape (num_experts, ), tokens per expert.
                - group_list_type: 0 for `cumsum`, 1 for `count`; tells
                  `npu_grouped_matmul` how to interpret `expert_tokens`.

        Raises:
            NotImplementedError: if the subclass does not implement it.
        """

    @abstractmethod
    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Combine expert outputs back into token order after the MLP.

        Args:
            mlp_output (torch.Tensor): shape
                (num_tokens * top_k_num, hidden_size), the expert MLP result.
            hidden_states (torch.Tensor): shape (num_tokens, hidden_size);
                updated in place with the final output.
        """
|
||||
|
||||
|
||||
class DummyCommImpl(MoECommMethod):
    """Placeholder strategy: fake pre-process, no-op post-process."""

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Delegate to the fake op; see moe_comm_pre_process_fake for details."""
        return moe_comm_pre_process_fake(hidden_states, topk_ids,
                                         topk_weights, expert_map,
                                         num_experts)

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Do nothing; `hidden_states` is left untouched."""
|
||||
|
||||
|
||||
class NativeAllGatherCommImpl(MoECommMethod):
    """Pure-PyTorch all-gather reference implementation.

    Compatible with every scenario because it uses no NPU-specific
    operators; performance is therefore not optimal, but it is a good
    fallback (and a golden reference in tests) when the optimized ops are
    unavailable.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        num_tokens = hidden_states.shape[0]

        # One entry per (token, expert-choice) pair: the source row index.
        row_ids = torch.arange(num_tokens,
                               device=self.device,
                               dtype=torch.int64)
        row_ids = row_ids.unsqueeze(1).expand(-1, self.top_k_num).reshape(-1)

        # Flatten routing info and translate global IDs to local expert IDs.
        flat_weights = topk_weights.view(-1)
        flat_experts = topk_ids.view(-1)
        if expert_map is not None:
            flat_local_experts = expert_map[flat_experts]
        else:
            flat_local_experts = flat_experts

        # Valid pairs are those routed to an expert hosted on this device.
        valid = flat_local_experts != -1
        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
        # So we need to filter out invalid tokens by zeroing their weights.
        # This is a workaround and should be removed after the issue is fixed
        kept_weights = torch.where(valid, flat_weights,
                                   torch.zeros_like(flat_weights)).to(
                                       self.dtype)
        # Map invalid pairs to sentinel id == num_experts so they sort last.
        kept_experts = torch.where(
            valid,
            flat_local_experts,
            torch.full_like(flat_local_experts, num_experts),
        ).to(topk_ids.dtype)

        # Sort pairs by local expert id. Reinterpreting the integer bits as
        # float32 preserves ordering for non-negative values — presumably a
        # workaround for integer argsort support on this backend; TODO confirm.
        order = torch.argsort(kept_experts.view(torch.float32))
        self.sorted_token_indices = row_ids[order]
        self.sorted_weights = kept_weights[order]

        # Per-expert token counts; the extra bucket absorbs the sentinel id.
        # Equivalent to (but faster than)
        # torch.bincount(kept_experts, minlength=num_experts + 1)[:-1]
        counts = torch.zeros(num_experts + 1,
                             device=self.device,
                             dtype=torch.int64)
        ones = torch.ones_like(kept_experts, dtype=torch.int64)
        counts.scatter_add_(0, kept_experts.to(torch.int64), ones)
        expert_tokens = counts[:num_experts]

        # Gather tokens into expert-major order for the grouped MLP.
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        # Scale each expert output by its routing weight, then scatter-add
        # every row back to the token it came from.
        weighted = mlp_output * self.sorted_weights.unsqueeze(1)

        combined = torch.zeros_like(hidden_states)
        combined.index_add_(0, self.sorted_token_indices, weighted)

        hidden_states[:] = combined
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
    """NPU-optimized equivalent of NativeAllGatherCommImpl.

    This implementation should be compatible with all scenarios, and thus it
    is the default implementation for MoE communication methods. It uses
    `torch_npu.npu_moe_init_routing_v2` for pre-processing and
    `torch_npu.npu_moe_token_unpermute` for post-processing to handle the
    token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
    use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Route tokens into expert-major order via npu_moe_init_routing_v2.

        Saves `topk_weights`/`topk_ids`/`expanded_row_idx` on `self` for use
        by `_post_process`.
        """
        num_tokens = hidden_states.shape[0]

        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        first_expert_idx = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
            # So we need to filter out invalid tokens by zeroing their weights.
            # This is a workaround and should be removed after the issue is fixed
            mask = expert_map[topk_ids] != -1
            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
            self.topk_weights = torch.where(mask, topk_weights, 0.0)

            first_expert_idx = get_ep_group().rank_in_group * num_experts
        last_expert_idx = first_expert_idx + num_experts

        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.top_k_num,
                expert_num=self.global_num_experts,
                expert_tokens_num_type=1,  # Only support `count` mode now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=-1,
            ))
        # Needed by _post_process to map permuted rows back to token order.
        self.expanded_row_idx = expanded_row_idx

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Unpermute expert outputs and write the result into hidden_states."""
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)
|
||||
|
||||
|
||||
class MC2CommImpl(MoECommMethod):
    """MoE strategy built on the MC2 dispatch/combine operators.

    Intended for the scenarios listed below:
    1. `enable_expert_parallel=True` (required;
       `enable_expert_parallel=False` is not supported).
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine`
       are available.

    The MC2 communication method is optimized for communication and
    computation parallelism on Ascend devices.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        hf_config: PretrainedConfig,
    ):
        super().__init__(device, dtype, hf_config)

        # Shared communication configuration.
        ep_group = get_mc2_group()
        self.ep_rank_id = ep_group.rank_in_group
        self.ep_world_size = ep_group.world_size
        self.tp_world_size = get_tp_group().world_size

        # Resolve the HCCL communicator name used by dispatch/combine.
        device_group = ep_group.device_group
        local_rank = torch.distributed.get_rank(group=device_group)
        backend = device_group._get_backend(torch.device("npu"))
        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)

        # Feature flags.
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        self.need_extra_args = self.is_ascend_a3  # or is_torchair

        # Intermediate state handed from _pre_process to _post_process.
        self.topk_ids = None
        self.topk_weights = None
        self.mc2_mask = None
        self.assist_info_for_combine = None
        self.ep_recv_counts = None
        self.tp_recv_counts = None

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        # Keep everything _post_process will need later.
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)
        self.mc2_mask = get_forward_context().mc2_mask

        kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.global_num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 0,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }

        if self.need_extra_args:
            kwargs.update({
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        if self.enable_dispatch_v2:
            dispatch_op = torch_npu.npu_moe_distribute_dispatch_v2
        else:
            dispatch_op = torch_npu.npu_moe_distribute_dispatch

        outputs = dispatch_op(**kwargs)
        permuted_hidden_states = outputs[0]
        # outputs[1] is dynamic_scale, which is not used here.
        self.assist_info_for_combine = outputs[2]
        expert_tokens = outputs[3]
        self.ep_recv_counts = outputs[4]
        self.tp_recv_counts = outputs[5]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.global_num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }

        # v1 and v2 name the routing-metadata argument differently.
        if self.enable_dispatch_v2:
            kwargs["assist_info_for_combine"] = self.assist_info_for_combine
        else:
            kwargs["expand_idx"] = self.assist_info_for_combine

        if self.need_extra_args:
            kwargs.update({
                "tp_send_counts": self.tp_recv_counts,
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        if self.enable_dispatch_v2:
            combine_op = torch_npu.npu_moe_distribute_combine_v2
        else:
            combine_op = torch_npu.npu_moe_distribute_combine

        # Write the combined result back into the caller's buffer.
        hidden_states[:] = combine_op(**kwargs)
|
||||
|
||||
|
||||
def moe_comm_pre_process(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Custom-op wrapper around the active MoECommMethod's `_pre_process`.

    Looks up the strategy instance stored on the current ForwardContext and
    delegates to it, which lets the call be registered as a custom op in the
    vllm framework.
    """
    context: ForwardContext = get_forward_context()
    method = context.moe_comm_method
    return method._pre_process(hidden_states, topk_ids, topk_weights,
                               expert_map, num_experts)
|
||||
|
||||
|
||||
def moe_comm_pre_process_fake(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Fake (meta) implementation of the pre-process op.

    torch.compile uses this implementation to generate the FX graph, so it
    only needs to produce outputs with the right shapes and dtypes, not real
    routing results.
    """
    top_k_num = topk_ids.shape[1]
    # In the permuted layout every token appears `top_k_num` times.
    permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
    expert_tokens = hidden_states.new_zeros((num_experts, ),
                                            dtype=torch.int64)
    group_list_type = 0  # `cumsum` mode
    return permuted_hidden_states, expert_tokens, group_list_type
|
||||
|
||||
|
||||
def moe_comm_post_process(mlp_output: torch.Tensor,
                          hidden_states: torch.Tensor) -> None:
    """Custom-op wrapper around the active MoECommMethod's `_post_process`.

    Dispatches to the strategy instance stored on the current ForwardContext
    so the call can be registered as a custom op in the vllm framework.
    `hidden_states` is updated in place.
    """
    context: ForwardContext = get_forward_context()
    method = context.moe_comm_method
    method._post_process(mlp_output, hidden_states)
|
||||
|
||||
|
||||
# Register both wrappers as vLLM custom ops on the "PrivateUse1" dispatch
# key (the key custom backends register under) so they can appear in
# compiled/FX graphs, with the fake implementations used for tracing.
direct_register_custom_op(
    op_name="moe_comm_pre_process",
    op_func=moe_comm_pre_process,
    mutates_args=[],
    fake_impl=moe_comm_pre_process_fake,
    dispatch_key="PrivateUse1",
)

direct_register_custom_op(
    op_name="moe_comm_post_process",
    op_func=moe_comm_post_process,
    # `hidden_states` is written in place by the post-process step.
    mutates_args=["hidden_states"],
    fake_impl=lambda x, y: None,  # No-op for fake implementation
    dispatch_key="PrivateUse1",
)
|
||||
@@ -19,12 +19,13 @@ from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import \
|
||||
UnquantizedFusedMoEMethod
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ops.fused_moe import (fused_experts, fused_experts_moge,
|
||||
select_experts)
|
||||
from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
|
||||
unified_fused_experts)
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
@@ -95,20 +96,18 @@ def forward_oot(
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
# If use aclgraph, we need to set max_num_tokens to make
|
||||
# the input shape of `npu_moe_init_routing` fixed
|
||||
max_num_tokens = self.max_num_batched_tokens if self.use_aclgraph else None
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
|
||||
return fused_experts(
|
||||
return unified_fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
max_num_tokens=max_num_tokens)
|
||||
moe_comm_method=moe_comm_method,
|
||||
)
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||
|
||||
@@ -43,6 +43,7 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
@@ -57,6 +58,62 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
def unified_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    activation: str = "silu",
    apply_router_weight_on_input: bool = False,
    use_int8_w8a8: bool = False,
    use_int4_w4a8: bool = False,
    global_num_experts: Optional[int] = None,
    expert_map: Optional[torch.Tensor] = None,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    w1_scale_bias: Optional[torch.Tensor] = None,
    w2_scale_bias: Optional[torch.Tensor] = None,
    moe_comm_method: Optional[MoECommMethod] = None,
    # For TorchAir graph
    is_torchair: bool = False,
    # For Cube/Vector parallel
    shared_experts: Optional[Any] = None,
    quantized_x_for_share: Optional[Any] = None,
    dynamic_scale_for_share: Optional[Any] = None,
    # For load balance
    log2phy: Optional[torch.Tensor] = None,
    global_redundant_expert_num: int = 0,
) -> torch.Tensor:
    """Run the MoE expert MLP via the unified pre/post-process custom ops.

    Tokens are permuted into expert-major order by
    `torch.ops.vllm.moe_comm_pre_process`, pushed through the grouped expert
    MLP (`apply_mlp`), and combined back into token order by
    `torch.ops.vllm.moe_comm_post_process`, which writes into
    `hidden_states` in place.

    Args:
        hidden_states: (num_tokens, hidden_size) input; updated in place and
            also returned.
        w1, w2: stacked per-expert MLP weights; `w1.shape[0]` is the number
            of local experts.
        topk_weights, topk_ids: routing outputs, both (num_tokens, top_k).
        expert_map: optional global-to-local expert ID mapping.
        moe_comm_method: required marker that a communication strategy is
            active; the actual strategy is resolved from the forward context
            inside the custom ops.

    Note: several parameters (activation, quantization flags, scales,
    shared-expert and load-balance arguments) are accepted for interface
    compatibility but are not used by this implementation yet.

    Returns:
        The (mutated) `hidden_states` tensor.
    """
    # Check constraints
    assert hidden_states.shape[1] == w1.shape[2], (
        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")

    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    assert moe_comm_method is not None, "Missing communication context"

    num_experts = w1.shape[0]

    permuted_hidden_states, expert_tokens, group_list_type = torch.ops.vllm.moe_comm_pre_process(
        hidden_states, topk_ids, topk_weights, expert_map, num_experts)
    mlp_output = apply_mlp(
        permuted_hidden_states,
        w1,
        w2,
        expert_tokens,
        group_list_type=group_list_type,
    )
    # Combines expert outputs and writes them into hidden_states in place.
    torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)

    return hidden_states
|
||||
|
||||
|
||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
max_row_per_ep_rank: int, num_tokens: int,
|
||||
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
@@ -205,8 +205,15 @@ class NPUPlatform(Platform):
|
||||
register_ascend_customop()
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
|
||||
kv_cache_dtype, block_size, use_v1, use_mla):
|
||||
def get_attn_backend_cls(cls,
|
||||
selected_backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
has_sink=False):
|
||||
if not use_v1:
|
||||
raise ValueError("vLLM Ascend does not support V0 engine.")
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ import types
|
||||
import weakref
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Type, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
@@ -43,7 +43,7 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
|
||||
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
|
||||
get_tp_group)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.forward_context import DPMetadata, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
@@ -79,6 +79,9 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
|
||||
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
DummyCommImpl,
|
||||
MoECommMethod)
|
||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
@@ -335,7 +338,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
|
||||
self.use_aclgraph = (self.vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not self.model_config.enforce_eager)
|
||||
and not self.model_config.enforce_eager and
|
||||
not ascend_config.torchair_graph_config.enabled)
|
||||
self.aclgraph_batch_sizes = list(
|
||||
reversed(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
@@ -375,6 +379,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
||||
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
|
||||
|
||||
self.reserved_mc2_mask = torch.zeros(
|
||||
512,
|
||||
dtype=torch.bool,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.moe_comm_method = AllGatherCommImpl
|
||||
|
||||
def check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
@@ -1003,6 +1015,32 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
mm_embeds.append(mm_embeds_item)
|
||||
return mm_embeds
|
||||
|
||||
def get_dp_padding(self,
|
||||
num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
|
||||
"""This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
|
||||
Please note that vLLM may refactor or modify this function over time,
|
||||
at present, we are using the version introduced in PR #18935.
|
||||
"""
|
||||
dp_size = self.vllm_config.parallel_config.data_parallel_size
|
||||
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
|
||||
|
||||
# For DP: Don't pad when setting enforce_eager.
|
||||
# This lets us set enforce_eager on the prefiller in a P/D setup and
|
||||
# still use ACL graphs (enabled by this padding) on the decoder.
|
||||
|
||||
if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
|
||||
# Early exit.
|
||||
return 0, None
|
||||
|
||||
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
|
||||
num_tokens, dp_size, dp_rank)
|
||||
max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
|
||||
num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
|
||||
dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
|
||||
|
||||
def _process_reqs(
|
||||
self,
|
||||
scheduler_output: "SchedulerOutput",
|
||||
@@ -1025,6 +1063,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Eager mode.
|
||||
num_input_tokens = total_num_scheduled_tokens
|
||||
|
||||
# Padding for DP
|
||||
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
|
||||
num_input_tokens)
|
||||
num_input_tokens += num_pad
|
||||
|
||||
modified_batch = self.attn_metadata_builder.reorder_batch(
|
||||
self.input_batch, scheduler_output)
|
||||
if modified_batch:
|
||||
@@ -1250,13 +1293,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for k, v in self.intermediate_tensors.items()
|
||||
})
|
||||
|
||||
moe_comm_method = self.moe_comm_method
|
||||
|
||||
# NOTE: Currently this padding logic is really messy,
|
||||
# MC2 may not be available in eager mode
|
||||
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
|
||||
if self.use_aclgraph:
|
||||
num_tokens_across_dp = num_tokens_across_dp_native
|
||||
else:
|
||||
num_input_tokens = padded_num_tokens_across_dp
|
||||
|
||||
# Run forward pass
|
||||
with set_ascend_forward_context(
|
||||
attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=padded_num_tokens_across_dp,
|
||||
num_tokens=num_input_tokens,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method(self.device, self.dtype,
|
||||
self.model_config.hf_config),
|
||||
num_actual_tokens=total_num_scheduled_tokens):
|
||||
with ProfileExecuteDuration().capture_async("forward"):
|
||||
self.maybe_setup_kv_connector(scheduler_output)
|
||||
@@ -1865,6 +1921,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
skip_attn: bool = True,
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
moe_comm_method: Type[MoECommMethod] = DummyCommImpl,
|
||||
) -> torch.Tensor:
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -1932,6 +1989,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
with_prefill=with_prefill,
|
||||
in_profile_run=self.in_profile_run,
|
||||
reserved_mc2_mask=self.reserved_mc2_mask,
|
||||
moe_comm_method=moe_comm_method(
|
||||
self.device, self.dtype, self.model_config.hf_config),
|
||||
num_actual_tokens=0,
|
||||
):
|
||||
hidden_states = self._generate_dummy_run_hidden_states(
|
||||
@@ -2328,13 +2388,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Trigger ACL graph capture for specific shapes.
|
||||
# Capture the large shapes first so that the smaller shapes
|
||||
# can reuse the memory pool allocated for the large shapes.
|
||||
# TODO(zzzzwwjj): Check dummy_run with ACL Graph and full graph mode
|
||||
with graph_capture(device=self.device):
|
||||
skip_attn = not self.vllm_config.compilation_config.full_cuda_graph
|
||||
for num_tokens in reversed(self.aclgraph_batch_sizes):
|
||||
for _ in range(self.vllm_config.compilation_config.
|
||||
cudagraph_num_of_warmups):
|
||||
self._dummy_run(num_tokens)
|
||||
self._dummy_run(num_tokens)
|
||||
self._dummy_run(
|
||||
num_tokens,
|
||||
skip_attn=skip_attn,
|
||||
moe_comm_method=self.moe_comm_method,
|
||||
)
|
||||
self._dummy_run(
|
||||
num_tokens,
|
||||
skip_attn=skip_attn,
|
||||
moe_comm_method=self.moe_comm_method,
|
||||
)
|
||||
|
||||
def capture_model(self) -> None:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user