450 lines
17 KiB
Python
450 lines
17 KiB
Python
|
|
from abc import ABC, abstractmethod
|
||
|
|
|
||
|
|
import torch
|
||
|
|
import torch_npu
|
||
|
|
from transformers.configuration_utils import PretrainedConfig
|
||
|
|
from vllm.distributed.parallel_state import get_ep_group, get_tp_group
|
||
|
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||
|
|
from vllm.utils import direct_register_custom_op
|
||
|
|
|
||
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||
|
|
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||
|
|
|
||
|
|
|
||
|
|
class MoECommMethod(ABC):
    """Base class for MoE communication methods.

    A subclass implements two hooks around the expert MLP:
    ``_pre_process`` permutes/dispatches tokens to experts before the MLP, and
    ``_post_process`` writes the combined MLP output back into
    ``hidden_states``.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        hf_config: "PretrainedConfig",
    ):
        """Store device/dtype and read MoE sizes from the HF model config.

        Args:
            device: Device on which intermediate tensors are allocated.
            dtype: Dtype used for routing weights by some subclasses.
            hf_config: HuggingFace model config. ``num_experts_per_tok`` is
                read into ``top_k_num`` (0 if absent); the global expert count
                is read into ``global_num_experts`` (0 if absent).
        """
        self.device = device
        self.dtype = dtype
        self.top_k_num = getattr(hf_config, "num_experts_per_tok", 0)

        # global_num_experts may be called num_experts or n_routed_experts in
        # different models. Use the first key that is present with a non-None
        # value: if a config defines e.g. ``num_experts = None`` while the real
        # count lives under another key, a plain ``hasattr`` check would store
        # None here, and that value is later passed as the expert count to the
        # NPU routing/dispatch ops.
        self.global_num_experts = 0
        for key in ("num_experts", "n_routed_experts"):
            value = getattr(hf_config, key, None)
            if value is not None:
                self.global_num_experts = value
                break

    @abstractmethod
    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Pre-process before MLP.

        Args:
            hidden_states (torch.Tensor): Tensor of shape (num_tokens, hidden_size)
            topk_ids (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            topk_weights (torch.Tensor): Tensor of shape (num_tokens, top_k_num)
            expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                Mapping from global expert IDs to local expert IDs.
            num_experts (int): Number of local experts (experts on this device).

        Returns:
            tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
                - permuted_hidden_states (torch.Tensor): Tensor of shape
                    (num_tokens * top_k_num, hidden_size) after permuting
                    hidden_states based on topk_ids.
                - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                    Number of tokens assigned to each expert.
                - group_list_type (int): Type of group list, 0 for `cumsum`
                    and 1 for `count`. This is mainly for `npu_grouped_matmul`
                    to determine how to handle the output.

        Subclasses must override this method; because it is an
        ``abstractmethod``, a subclass that does not cannot be instantiated.
        The base body itself does nothing.
        """
        pass

    @abstractmethod
    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Post-process after MLP.

        Args:
            mlp_output (torch.Tensor): Tensor of shape
                (num_tokens * top_k_num, hidden_size) after MLP.
            hidden_states (torch.Tensor): Tensor of shape
                (num_tokens, hidden_size) to be updated in place with the
                final output.

        Subclasses must override this method (enforced by ``abstractmethod``).
        """
        pass
|
||
|
|
|
||
|
|
|
||
|
|
class DummyCommImpl(MoECommMethod):
    """Communication method that performs no real routing or communication.

    Pre-processing is delegated to ``moe_comm_pre_process_fake`` and
    post-processing leaves ``hidden_states`` untouched.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Return exactly what ``moe_comm_pre_process_fake`` returns for these inputs."""
        fake_args = (hidden_states, topk_ids, topk_weights, expert_map,
                     num_experts)
        return moe_comm_pre_process_fake(*fake_args)

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """No-op: neither argument is read or modified."""
        return None
|
||
|
|
|
||
|
|
|
||
|
|
class NativeAllGatherCommImpl(MoECommMethod):
    """This implementation should be compatible with all scenarios.

    Note that this implementation purely consists of native PyTorch ops
    and does not use any NPU-specific ops. So the performance may not be optimal.
    But it is a good fallback for scenarios where NPU-specific ops are not available.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Group every (token, top-k slot) pair by local expert id.

        Pairs whose expert maps to -1 (not on this device) are sent to a
        sentinel expert ``num_experts``, given weight 0, and sorted to the end.
        Sets ``self.sorted_token_indices`` and ``self.sorted_weights`` for
        ``_post_process``.

        Returns:
            (permuted_hidden_states, expert_tokens, group_list_type) where
            ``permuted_hidden_states`` has ``topk_ids.numel()`` rows,
            ``expert_tokens`` holds per-local-expert counts of shape
            (num_experts, ), and ``group_list_type`` is 1 (`count`).
        """
        num_tokens = hidden_states.shape[0]
        # Take top-k from the routing tensor itself: ``self.top_k_num`` comes
        # from the model config and falls back to 0, which would make the token
        # indices below disagree in length with the flattened expert ids.
        top_k = topk_ids.shape[1]

        # Generate token indices and flatten (row-major, matching
        # topk_ids.reshape(-1)).
        token_indices = torch.arange(num_tokens,
                                     device=self.device,
                                     dtype=torch.int64)
        token_indices = token_indices.unsqueeze(1).expand(-1,
                                                          top_k).reshape(-1)

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.reshape(-1)
        experts_flat = topk_ids.reshape(-1)
        local_experts_flat = (expert_map[experts_flat]
                              if expert_map is not None else experts_flat)

        # Filter valid token-expert pairs
        mask = local_experts_flat != -1
        # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
        # So we need to filter out invalid tokens by zeroing their weights.
        # This is a workaround and should be removed after the issue is fixed
        filtered_weights = torch.where(mask, weights_flat,
                                       torch.zeros_like(weights_flat)).to(
                                           self.dtype)
        filtered_experts = torch.where(
            mask,
            local_experts_flat,
            torch.full_like(local_experts_flat, num_experts),
        ).to(topk_ids.dtype)

        # Sort by local expert IDs. The ids are small non-negative integers, so
        # a by-value float32 cast is exact and order-preserving. (Reinterpreting
        # the bits with ``.view(torch.float32)`` instead would double the
        # element count for int64 ids and turns small ints into denormals,
        # which compare equal if a backend flushes denormals to zero.)
        sort_indices = torch.argsort(filtered_experts.to(torch.float32))
        self.sorted_token_indices = token_indices[sort_indices]
        self.sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts + 1)[:-1]
        # The extra slot at index num_experts collects the invalid pairs.
        token_counts = torch.zeros(num_experts + 1,
                                   device=self.device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        expert_tokens = token_counts[:num_experts]

        # Rearrange hidden_states
        permuted_hidden_states = hidden_states[self.sorted_token_indices]

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Weight the MLP rows and scatter-add them back into ``hidden_states``.

        Uses ``self.sorted_weights`` / ``self.sorted_token_indices`` produced by
        ``_pre_process``; ``hidden_states`` is overwritten in place.
        """
        mlp_output = mlp_output * self.sorted_weights.unsqueeze(1)

        final_hidden_states = torch.zeros_like(hidden_states)
        final_hidden_states.index_add_(0, self.sorted_token_indices,
                                       mlp_output)

        hidden_states[:] = final_hidden_states
|
||
|
|
|
||
|
|
|
||
|
|
class AllGatherCommImpl(MoECommMethod):
    """This implementation is the same as NativeAllGatherCommImpl,
    but uses NPU-specific ops for better performance.

    This implementation should be compatible with all scenarios, and
    thus it is the default implementation for MoE communication methods.
    It uses `torch_npu.npu_moe_init_routing_v2` for pre-processing
    and `torch_npu.npu_moe_token_unpermute` for post-processing
    to handle the token-to-expert mapping and communication efficiently.

    NOTE(Yizhou): TBH, it is really weird that we were supposed to use
    `torch_npu.npu_moe_init_routing_v2` and `torch_npu.npu_moe_finalize_routing`
    or `torch_npu.npu_moe_token_permute` and `torch_npu.npu_moe_token_unpermute`
    for pre-processing and post-processing, respectively.
    But `npu_moe_finalize_routing` will lead to accuracy issues so we have to
    use `torch_npu.npu_moe_token_unpermute` instead.
    This is a workaround and should be removed after the issue is fixed.
    """

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Permute tokens to experts with ``npu_moe_init_routing_v2``.

        When ``expert_map`` is given, weights of pairs routed to experts that
        map to -1 are zeroed, and the active expert range is restricted to
        ``[ep_rank * num_experts, (ep_rank + 1) * num_experts)``.
        Stores ``topk_weights``, ``topk_ids`` and ``expanded_row_idx`` on
        ``self`` for ``_post_process``.
        """
        self.topk_weights = topk_weights
        self.topk_ids = topk_ids

        first_expert_idx = 0
        if expert_map is not None:
            # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...]
            # So we need to filter out invalid tokens by zeroing their weights.
            # This is a workaround and should be removed after the issue is fixed
            mask = expert_map[topk_ids] != -1
            # NOTE: This is equivalent to self.topk_weights[~mask] = 0.0,
            # but ~mask will dispatch to aclnnNonzeroV2, which is not supported in ACL Graph
            self.topk_weights = torch.where(mask, topk_weights, 0.0)

            first_expert_idx = get_ep_group().rank_in_group * num_experts
        last_expert_idx = first_expert_idx + num_experts

        permuted_hidden_states, expanded_row_idx, expert_tokens, _ = (
            torch_npu.npu_moe_init_routing_v2(
                hidden_states,
                topk_ids,
                # Every (token, top-k slot) pair is active. Use the routing
                # tensor's own size instead of the config-derived top_k_num,
                # which falls back to 0 when the config lacks
                # ``num_experts_per_tok``.
                active_num=topk_ids.numel(),
                expert_num=self.global_num_experts,
                expert_tokens_num_type=1,  # Only support `count` mode now
                expert_tokens_num_flag=True,  # Output `expert_tokens`
                active_expert_range=[first_expert_idx, last_expert_idx],
                quant_mode=-1,
            ))
        self.expanded_row_idx = expanded_row_idx

        group_list_type = 1  # `count` mode

        return permuted_hidden_states, expert_tokens, group_list_type

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Unpermute and weight the MLP output back into ``hidden_states`` in place."""
        hidden_states[:] = torch_npu.npu_moe_token_unpermute(
            permuted_tokens=mlp_output,
            sorted_indices=self.expanded_row_idx,
            probs=self.topk_weights)
|
||
|
|
|
||
|
|
|
||
|
|
class MC2CommImpl(MoECommMethod):
    """This implementation is for the scenarios listed below:
    1. `enable_expert_parallel=True`.
    2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
    3. `enable_expert_parallel=False` is not supported.

    This implementation uses the MC2 communication method, which is optimized for
    Communication and Computation parallelism on Ascend devices.
    """

    def __init__(
        self,
        device: torch.device,
        dtype: torch.dtype,
        hf_config: PretrainedConfig,
    ):
        super().__init__(device, dtype, hf_config)

        # Expert-parallel topology comes from the MC2 process group.
        mc2_group = get_mc2_group()
        self.ep_rank_id = mc2_group.rank_in_group
        self.ep_world_size = mc2_group.world_size
        self.tp_world_size = get_tp_group().world_size

        # Resolve the HCCL communicator name passed as `group_ep` (and
        # `group_tp`) to the dispatch/combine ops.
        group = mc2_group.device_group
        rank_in_device_group = torch.distributed.get_rank(group=group)
        hccl_backend = group._get_backend(torch.device("npu"))
        self.moe_all_to_all_group_name = hccl_backend.get_hccl_comm_name(
            rank_in_device_group)

        # Which op variants / extra arguments this torch_npu + SoC supports.
        self.enable_dispatch_v2 = hasattr(torch_npu,
                                          "npu_moe_distribute_dispatch_v2")
        self.is_ascend_a3 = get_ascend_soc_version() == AscendSocVersion.A3
        self.need_extra_args = self.is_ascend_a3  # or is_torchair

        # State written by _pre_process and consumed by _post_process.
        self.topk_ids = None
        self.topk_weights = None
        self.mc2_mask = None
        self.assist_info_for_combine = None
        self.ep_recv_counts = None
        self.tp_recv_counts = None

    def _shared_comm_kwargs(self) -> dict:
        """Build the keyword arguments common to MC2 dispatch and combine."""
        kwargs = {
            "expert_ids": self.topk_ids,
            "expert_shard_type": 0,
            "shared_expert_rank_num": 0,
            "moe_expert_num": self.global_num_experts,
            "global_bs": 0,
            "group_ep": self.moe_all_to_all_group_name,
            "ep_world_size": self.ep_world_size,
            "ep_rank_id": self.ep_rank_id,
        }
        if self.need_extra_args:
            kwargs["group_tp"] = self.moe_all_to_all_group_name
            kwargs["tp_world_size"] = 1
            kwargs["tp_rank_id"] = 0
        if self.enable_dispatch_v2 and self.is_ascend_a3:
            kwargs["x_active_mask"] = self.mc2_mask
        return kwargs

    def _pre_process(
        self,
        hidden_states: torch.Tensor,
        topk_ids: torch.Tensor,
        topk_weights: torch.Tensor,
        expert_map: torch.Tensor,
        num_experts: int,
    ) -> tuple[torch.Tensor, torch.Tensor, int]:
        """Dispatch tokens to experts with the MC2 dispatch op.

        Records ids, float32 weights, the forward-context ``mc2_mask`` and the
        dispatch op's combine metadata on ``self`` for ``_post_process``.
        Returns the dispatched hidden states, per-expert token counts and
        group_list_type 1 (`count`).
        """
        self.topk_ids = topk_ids
        self.topk_weights = topk_weights.to(torch.float32)
        self.mc2_mask = get_forward_context().mc2_mask

        dispatch_kwargs = self._shared_comm_kwargs()
        dispatch_kwargs["x"] = hidden_states
        dispatch_kwargs["scales"] = None
        dispatch_kwargs["quant_mode"] = 0

        if self.enable_dispatch_v2:
            dispatch_op = torch_npu.npu_moe_distribute_dispatch_v2
        else:
            dispatch_op = torch_npu.npu_moe_distribute_dispatch

        # The second output (dynamic scale) is unused.
        (dispatched, _, combine_info, token_counts, ep_counts,
         tp_counts) = dispatch_op(**dispatch_kwargs)[:6]
        self.assist_info_for_combine = combine_info
        self.ep_recv_counts = ep_counts
        self.tp_recv_counts = tp_counts

        return dispatched, token_counts, 1

    def _post_process(self, mlp_output: torch.Tensor,
                      hidden_states: torch.Tensor) -> None:
        """Combine expert outputs with the MC2 combine op into ``hidden_states`` in place."""
        combine_kwargs = self._shared_comm_kwargs()
        combine_kwargs["expand_x"] = mlp_output
        combine_kwargs["expert_scales"] = self.topk_weights
        combine_kwargs["ep_send_counts"] = self.ep_recv_counts

        # v2 names the dispatch metadata differently from v1.
        index_key = ("assist_info_for_combine"
                     if self.enable_dispatch_v2 else "expand_idx")
        combine_kwargs[index_key] = self.assist_info_for_combine

        if self.need_extra_args:
            combine_kwargs["tp_send_counts"] = self.tp_recv_counts

        if self.enable_dispatch_v2:
            combine_op = torch_npu.npu_moe_distribute_combine_v2
        else:
            combine_op = torch_npu.npu_moe_distribute_combine

        hidden_states[:] = combine_op(**combine_kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
def moe_comm_pre_process(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """Forward to ``_pre_process`` of the current ForwardContext's comm method.

    A free function (rather than a bound method) so it can be registered as
    a vLLM custom op; its signature defines the op schema.
    """
    comm_method = get_forward_context().moe_comm_method
    return comm_method._pre_process(hidden_states, topk_ids, topk_weights,
                                    expert_map, num_experts)
|
||
|
|
|
||
|
|
|
||
|
|
def moe_comm_pre_process_fake(
    hidden_states: torch.Tensor,
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    expert_map: torch.Tensor,
    num_experts: int,
) -> tuple[torch.Tensor, torch.Tensor, int]:
    """This is a fake implementation of the pre_process method.

    torch.compile will use this implementation to generate FX graph; it is
    also returned directly by ``DummyCommImpl._pre_process``.

    Returns:
        - permuted_hidden_states: each row of ``hidden_states`` repeated
          ``topk_ids.shape[1]`` times.
        - expert_tokens: all-zero int64 tensor of shape (num_experts, ).
        - group_list_type: 1 (`count`), the same value every real
          implementation in this module returns. A plain int produced by a
          fake impl may be specialized into the traced graph, so it has to
          agree with what the real op returns at runtime. With all-zero
          ``expert_tokens`` the `count` and `cumsum` readings are identical,
          so the dummy path is unaffected.
    """
    top_k_num = topk_ids.shape[1]
    permuted_hidden_states = hidden_states.repeat_interleave(top_k_num, dim=0)
    expert_tokens = torch.zeros((num_experts, ),
                                dtype=torch.int64,
                                device=hidden_states.device)
    group_list_type = 1  # `count` mode, matching the real implementations
    return permuted_hidden_states, expert_tokens, group_list_type
|
||
|
|
|
||
|
|
|
||
|
|
def moe_comm_post_process(mlp_output: torch.Tensor,
                          hidden_states: torch.Tensor) -> None:
    """Forward to ``_post_process`` of the current ForwardContext's comm method.

    A free function so it can be registered as a vLLM custom op (its
    signature defines the op schema); results land in ``hidden_states``
    in place and nothing is returned.
    """
    comm_method = get_forward_context().moe_comm_method
    comm_method._post_process(mlp_output, hidden_states)
|
||
|
|
|
||
|
|
|
||
|
|
def moe_comm_post_process_fake(mlp_output: torch.Tensor,
                               hidden_states: torch.Tensor) -> None:
    """Fake implementation of ``moe_comm_post_process`` used during tracing.

    The real op only writes into ``hidden_states`` in place and returns
    nothing, so there is no output to fabricate. The parameter names mirror
    the real op's signature.
    """
    return None


# Register the pre/post-process wrappers as vLLM custom ops on the
# PrivateUse1 dispatch key, each with a fake implementation for tracing.
direct_register_custom_op(
    op_name="moe_comm_pre_process",
    op_func=moe_comm_pre_process,
    mutates_args=[],
    fake_impl=moe_comm_pre_process_fake,
    dispatch_key="PrivateUse1",
)

direct_register_custom_op(
    op_name="moe_comm_post_process",
    op_func=moe_comm_post_process,
    mutates_args=["hidden_states"],
    fake_impl=moe_comm_post_process_fake,
    dispatch_key="PrivateUse1",
)
|