### What this PR does / why we need it?
This PR refactors the MoE (Mixture of Experts) communication logic by
introducing a strategy pattern. It defines an abstract base class,
`MoECommMethod`, which encapsulates different communication strategies
for MoE layers. By decoupling the MoE implementation from any single
communication method, this change makes it simpler to add, replace, or
optimize communication strategies in the future.
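A minimal sketch of how a strategy is consumed (illustrative only;
`experts_mlp` is a hypothetical stand-in for the grouped-matmul expert MLP):

```python
# Every strategy exposes the same pre/post hooks around the expert MLP;
# AllGatherCommImpl is the default introduced by this PR.
comm = AllGatherCommImpl(device, dtype, hf_config)
permuted, expert_tokens, group_list_type = comm._pre_process(
    hidden_states, topk_ids, topk_weights, expert_map, num_local_experts)
mlp_out = experts_mlp(permuted, expert_tokens, group_list_type)  # hypothetical
comm._post_process(mlp_out, hidden_states)  # updates hidden_states in place
```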
### Plan / Roadmap
1. Introduce `MoECommMethod`, implement `AllGatherCommImpl`, and adapt ACL
Graph handling to cover all scenarios (this PR).
2. Implement `MC2CommImpl` and `AllToAllCommImpl` to optimize
performance in specific scenarios.
3. Enable W8A8 / Int8 models to use `unified_fused_experts`.
### Other notes
- Data-parallel (DP) communication currently does not work with vLLM's
dispatch/combine mechanisms; an alternative approach is required to
resolve this incompatibility.
- vLLM version: v0.10.0
- vLLM main: f7ad6a1eb3
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
450 lines · 17 KiB · Python
from abc import ABC, abstractmethod

import torch
import torch_npu
from transformers.configuration_utils import PretrainedConfig
from vllm.distributed.parallel_state import get_ep_group, get_tp_group
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op

from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


class MoECommMethod(ABC):
    """Base class for MoE communication methods."""

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        hf_config: PretrainedConfig,
    ):
        self.device = device
        self.dtype = dtype
        self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)
        # global_num_experts may be named num_experts or n_routed_experts,
        # depending on the model.
        possible_keys = ["num_experts", "n_routed_experts"]
        for key in possible_keys:
            if hasattr(hf_config, key):
                self.global_num_experts = getattr(hf_config, key)
                break
        else:
            self.global_num_experts = 0

    @abstractmethod
    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Pre-process before MLP.

        Args:
            hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
            topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                Mapping from global expert IDs to local expert IDs.
            num_experts (int): Number of local experts (experts on this device).

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: A tuple containing:
                - permuted_hidden_states (torch.Tensor): Tensor of shape
                  (num_tokens * top_k_num, hidden_size) after permuting
                  hidden_states based on topk_ids.
                - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                  Number of tokens assigned to each expert.
                - group_list_type (int): Type of group list, 0 for `cumsum`
                  and 1 for `count`. This is mainly for `npu_grouped_matmul`
                  to determine how to handle the output.

        Raises:
            NotImplementedError: If the method is not implemented in the subclass.
        """
        pass
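
    # Illustrative example (not executed): with three local experts that
    # receive 3, 1, and 2 tokens respectively, `count` mode
    # (group_list_type == 1) yields expert_tokens == [3, 1, 2], whereas the
    # `cumsum` form (group_list_type == 0) would be [3, 4, 6].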

    @abstractmethod
    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Post-process after MLP.

        Args:
            mlp_output (torch.Tensor): Tensor of shape
                (num_tokens * top_k_num, hidden_size) after MLP.
            hidden_states (torch.Tensor): Tensor of shape
                (num_tokens, hidden_size) to be updated with the final output.
        """
        pass


class DummyCommImpl(MoECommMethod):

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Dummy implementation, see moe_comm_pre_process_fake for details."""
        return moe_comm_pre_process_fake(hidden_states, topk_ids, topk_weights,
                                         expert_map, num_experts)

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Dummy implementation that does nothing."""
        pass


class NativeAllGatherCommImpl(MoECommMethod):
    """This implementation should be compatible with all scenarios.

    Note that this implementation consists purely of native PyTorch ops and
    does not use any NPU-specific ops, so its performance may not be optimal.
    It is, however, a good fallback for scenarios where NPU-specific ops are
    not available.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        num_tokens = hidden_states.shape[0]

        # Generate token indices and flatten
        token_indices = torch.arange(num_tokens,
                                     device=self.device,
                                     dtype=torch.int64)
        token_indices = (token_indices.unsqueeze(1).expand(
            -1, self.top_k_num).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = (expert_map[experts_flat]
                              if expert_map is not None else experts_flat)

        # Filter valid token-expert pairs
        mask = local_experts_flat != -1
        # FIXME: npu_grouped_matmul outputs random values at
        # [num_valid_tokens:, ...], so we need to filter out invalid tokens
        # by zeroing their weights. This is a workaround and should be
        # removed once the issue is fixed.
        filtered_weights = torch.where(mask, weights_flat,
                                       torch.zeros_like(weights_flat)).to(
                                           self.dtype)
        filtered_experts = torch.where(
            mask,
            local_experts_flat,
            torch.full_like(local_experts_flat, num_experts),
        ).to(topk_ids.dtype)

        # Sort by local expert IDs. Viewing the (4-byte) integer IDs as
        # float32 is a bit-level reinterpretation that preserves ordering
        # for the non-negative values used here.
        sort_indices = torch.argsort(filtered_experts.view(torch.float32))
        self.sorted_token_indices = token_indices[sort_indices]
        self.sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        token_counts = torch.zeros(num_experts + 1,
                                   device=self.device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        expert_tokens = token_counts[:num_experts]

        # Rearrange hidden_states
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros_like(hidden_states)
        final_hidden_states.index_add_(0, self.sorted_token_indices,
                                       mlp_output)

        hidden_states[:] = final_hidden_states
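

# Illustrative note on NativeAllGatherCommImpl._post_process: each token's
# top_k_num weighted expert outputs are accumulated back into its original
# row via index_add_; e.g. with top_k_num == 2 and
# sorted_token_indices == [1, 0, 1, 0], rows 0 and 1 of hidden_states each
# receive the sum of their two weighted expert outputs.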


class AllGatherCommImpl(MoECommMethod):
    """This implementation is the same as NativeAllGatherCommImpl,
    but uses NPU-specific ops for better performance.

    This implementation should be compatible with all scenarios, and
    thus it is the default implementation for MoE communication methods.
    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
    and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): Ideally, we would pair `torch_npu.npu_moe_init_routing_v2`
    with `torch_npu.npu_moe_finalize_routing`, or `torch_npu.npu_moe_token_permute`
    with `torch_npu.npu_moe_token_unpermute`, for pre- and post-processing,
    respectively. However, `npu_moe_finalize_routing` leads to accuracy issues,
    so we have to use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed once the issue is fixed.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,  # noqa: F841
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        num_tokens = hidden_states.shape[0]

        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        first_expert_idx = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul outputs random values at
            # [num_valid_tokens:, ...], so we need to filter out invalid
            # tokens by zeroing their weights. This is a workaround and
            # should be removed once the issue is fixed.
            mask = expert_map[topk_ids] != -1
            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
            # but ~mask would dispatch to aclnnNonzeroV2, which is not
            # supported in ACL Graph mode.
            self.topk_weights = torch.where(mask, topk_weights, 0.0)

            first_expert_idx = get_ep_group().rank_in_group * num_experts
        last_expert_idx = first_expert_idx + num_experts
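
        # Illustrative example: with an EP world size of 4 and 16 local
        # experts per rank, rank 1 owns global experts [16, 32); this range
        # is passed to npu_moe_init_routing_v2 as active_expert_range below.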
        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                active_num=num_tokens * self.top_k_num,
                expert_num=self.global_num_experts,
                expert_tokens_num_type=1,  # Only `count` mode is supported now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=-1,
            ))
        self.expanded_row_idx = expanded_row_idx

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)


class MC2CommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True` is required
       (`enable_expert_parallel=False` is not supported).
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine`
       are available.

    This implementation uses the MC2 communication method, which is optimized
    for communication-computation parallelism on Ascend devices.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        hf_config: PretrainedConfig,
    ):
        super().__init__(device, dtype, hf_config)

        # Shared communication configurations
        ep_group = get_mc2_group()
        self.ep_rank_id = ep_group.rank_in_group
        self.ep_world_size = ep_group.world_size
        self.tp_world_size = get_tp_group().world_size

        device_group = ep_group.device_group
        local_rank = torch.distributed.get_rank(group=device_group)
        backend = device_group._get_backend(torch.device("npu"))
        self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)

        # Feature flags
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        self.need_extra_args = self.is_ascend_a3  # or is_torchair

        # Intermediate tensors to be passed from pre_process to post_process
        self.topk_ids = None
        self.topk_weights = None
        self.mc2_mask = None
        self.assist_info_for_combine = None
        self.ep_recv_counts = None
        self.tp_recv_counts = None

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        # Store tensors needed for post_process
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)
        self.mc2_mask = get_forward_context().mc2_mask

        dispatch_kwargs = {
            "x": hidden_states,
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.global_num_experts,
            "global_bs": 0,
            "scales": None,
            "quant_mode": 0,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }

        if self.need_extra_args:
            dispatch_kwargs.update({
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            dispatch_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        dispatch = (torch_npu.npu_moe_distribute_dispatch_v2
                    if self.enable_dispatch_v2 else
                    torch_npu.npu_moe_distribute_dispatch)

        (
            permuted_hidden_states,
            _,  # dynamic_scale is not used
            self.assist_info_for_combine,
            expert_tokens,
            self.ep_recv_counts,
            self.tp_recv_counts,
        ) = dispatch(**dispatch_kwargs)[:6]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        combine_kwargs = {
            "expand_x": mlp_output,
            "expert_ids": self.topk_ids,
            "expert_scales": self.topk_weights,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.global_num_experts,
            "global_bs": 0,
            "ep_send_counts": self.ep_recv_counts,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }

        if self.enable_dispatch_v2:
            combine_kwargs[
                "assist_info_for_combine"] = self.assist_info_for_combine
        else:
            combine_kwargs["expand_idx"] = self.assist_info_for_combine

        if self.need_extra_args:
            combine_kwargs.update({
                "tp_send_counts": self.tp_recv_counts,
                "group_tp": self.moe_all_to_all_group_name,
                "tp_world_size": 1,
                "tp_rank_id": 0,
            })
        if self.is_ascend_a3 and self.enable_dispatch_v2:
            combine_kwargs.update({
                "x_active_mask": self.mc2_mask,
            })

        combine = (torch_npu.npu_moe_distribute_combine_v2
                   if self.enable_dispatch_v2 else
                   torch_npu.npu_moe_distribute_combine)

        hidden_states[:] = combine(**combine_kwargs)


def moe_comm_pre_process(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Wrapper around the _pre_process method of the MoECommMethod instance
    stored in the ForwardContext, so that it can be registered as a custom
    op in the vLLM framework.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.moe_comm_method
    return self._pre_process(hidden_states, topk_ids, topk_weights, expert_map,
                             num_experts)


def moe_comm_pre_process_fake(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Fake implementation of the pre_process op.

    torch.compile uses this implementation to generate the FX graph.
    """
    top_k_num = topk_ids.shape[1]
    permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
    expert_tokens = torch.zeros((num_experts, ),
                                dtype=torch.int64,
                                device=hidden_states.device)
    group_list_type = 0
    return permuted_hidden_states, expert_tokens, group_list_type
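

# NOTE: A fake implementation only needs to produce outputs with shapes and
# dtypes that match the real op for tracing: (num_tokens * top_k_num,
# hidden_size) for the permuted tokens and (num_experts,) for the per-expert
# token counts.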


def moe_comm_post_process(mlp_output: torch.Tensor,
                          hidden_states: torch.Tensor) -> None:
    """Wrapper around the _post_process method of the MoECommMethod instance
    stored in the ForwardContext, so that it can be registered as a custom
    op in the vLLM framework.
    """
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.moe_comm_method
    self._post_process(mlp_output, hidden_states)


direct_register_custom_op(
    op_name="moe_comm_pre_process",
    op_func=moe_comm_pre_process,
    mutates_args=[],
    fake_impl=moe_comm_pre_process_fake,
    dispatch_key="PrivateUse1",
)

direct_register_custom_op(
    op_name="moe_comm_post_process",
    op_func=moe_comm_post_process,
    mutates_args=["hidden_states"],
    fake_impl=lambda x, y: None,  # No-op for the fake implementation
    dispatch_key="PrivateUse1",
)
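

# Illustrative call sites (assuming vLLM's default custom-op namespace and a
# forward context whose moe_comm_method has been set up by the caller):
#
#   out = torch.ops.vllm.moe_comm_pre_process(
#       hidden_states, topk_ids, topk_weights, expert_map, num_experts)
#   torch.ops.vllm.moe_comm_post_process(mlp_output, hidden_states)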