[Refactor][MoE] remove redundant code after refactoring fused_moe (#2612)
### What this PR does / why we need it?
There are a lot of redundant codes related to moe here, and the
structure is not very clear.
We did the following things:
we have placed the relatively independent code related to apply_mlp into
a separate file;
removed the environment variables of alltoall_buffer and alltoall_seq.
Remove the code related to alltoall_buffer and alltoall_seq, and retain
the sole TokenDispatcher inheritance class.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e&ut
- vLLM version: v0.10.1.1
- vLLM main:
4071c76cf3
---------
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -22,6 +22,8 @@ import torch_npu
|
||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod)
|
||||
|
||||
@@ -30,7 +32,6 @@ from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl,
|
||||
MC2CommImpl)
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe import fused_experts_moge
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||
setup_token_dispatchers
|
||||
@@ -139,6 +140,95 @@ def fused_experts(
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
ep_size = moe_parallel_config.ep_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||
|
||||
|
||||
@@ -38,7 +38,6 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
@@ -46,397 +45,12 @@ from vllm_ascend.distributed.communication_op import \
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
|
||||
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
|
||||
from vllm_ascend.ops.layers.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
|
||||
max_row_per_ep_rank: int, num_tokens: int,
|
||||
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
original_total_elements = num_tokens * top_k
|
||||
device = topk_ids.device
|
||||
original_dtype = topk_ids.dtype
|
||||
|
||||
if original_total_elements == 0:
|
||||
output_len = ep_size * max_row_per_ep_rank
|
||||
topk_ids_pad = torch.full((output_len, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
unpad_indices = torch.full((original_total_elements, ),
|
||||
-1,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
experts_per_ep_rank_val = expert_num // ep_size
|
||||
if experts_per_ep_rank_val == 0:
|
||||
raise ValueError(
|
||||
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
|
||||
"Ensure expert_num >= ep_size.")
|
||||
|
||||
assigned_ep_rank = (topk_ids.float() /
|
||||
experts_per_ep_rank_val).to(original_dtype)
|
||||
indices_arange = torch.arange(topk_ids.shape[0], device=device)
|
||||
|
||||
is_new_segment = torch.cat(
|
||||
(torch.tensor([True], device=device), assigned_ep_rank[1:]
|
||||
!= assigned_ep_rank[:-1]))
|
||||
temp_start_markers = torch.full_like(indices_arange,
|
||||
-1,
|
||||
dtype=indices_arange.dtype)
|
||||
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
|
||||
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
|
||||
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
|
||||
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
|
||||
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
|
||||
indices_in_rec_cond_list_for_all = cumsum_kept - 1
|
||||
unpad_indices = torch.where(
|
||||
is_kept_mask, indices_in_rec_cond_list_for_all,
|
||||
torch.tensor(-1, device=device, dtype=torch.long))
|
||||
output_len = ep_size * max_row_per_ep_rank
|
||||
topk_ids_pad = torch.full((output_len, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
if topk_ids.shape[0] > 0:
|
||||
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
|
||||
temp_pad_buffer = torch.full((output_len + 1, ),
|
||||
expert_num,
|
||||
dtype=original_dtype,
|
||||
device=device)
|
||||
output_len_tensor = torch.tensor(output_len,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
|
||||
output_len_tensor)
|
||||
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
|
||||
topk_ids_pad = temp_pad_buffer[:output_len]
|
||||
return topk_ids_pad, unpad_indices
|
||||
|
||||
|
||||
def apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
Args:
|
||||
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
|
||||
w1: expert weights1 with shape
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: expert weights2 with shape
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
group_list: number of tokens for each expert, follow cumsum mode, and
|
||||
with shape (num_experts).
|
||||
transpose_weight:
|
||||
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
||||
(num_experts, hidden_size, intermediate_size * 2)
|
||||
w2: (num_experts, hidden_size, intermediate_size) ->
|
||||
(num_experts, intermediate_size, hidden_size)
|
||||
|
||||
Returns:
|
||||
hidden_states: output hidden states after MLP.
|
||||
"""
|
||||
|
||||
w1 = w1.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
ep_size = moe_parallel_config.ep_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
if apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False) -> torch.Tensor:
|
||||
if with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales)
|
||||
|
||||
|
||||
def unified_fused_experts_eager(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -742,24 +356,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
self.token_dispatcher = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
|
||||
self.quant_method, AscendUnquantizedFusedMoEMethod):
|
||||
self.reduce_results = False
|
||||
moe_dispatcher_config = (
|
||||
MoEDispatcherConfig().set_num_moe_experts(
|
||||
self.global_num_experts).set_num_local_experts(
|
||||
self.local_num_experts).set_moe_router_topk(
|
||||
top_k).set_group_topk(topk_group).
|
||||
set_num_groups(num_expert_group).set_expert_bias(
|
||||
e_score_correction_bias).set_scaling_factor(1.0).build())
|
||||
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
|
||||
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
|
||||
moe_dispatcher_config)
|
||||
self.token_dispatchers = [
|
||||
self.token_dispatcher, token_dispatcher1
|
||||
]
|
||||
|
||||
ep_size = (get_ep_group().world_size if
|
||||
vllm_config.parallel_config.enable_expert_parallel else 1)
|
||||
|
||||
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False) -> torch.Tensor:
|
||||
if with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales)
|
||||
@@ -29,434 +29,11 @@ import torch_npu
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.distributed.tensor_parallel import (
|
||||
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
|
||||
all_to_all_sp2hp, gather_from_sequence_parallel_region,
|
||||
reduce_scatter_last_dim_to_tensor_parallel_region)
|
||||
from vllm_ascend.distributed.tensor_parallel import \
|
||||
gather_from_sequence_parallel_region
|
||||
from vllm_ascend.ops.comm_utils import async_all_to_all
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
|
||||
|
||||
class MoEDispatcherConfig:
|
||||
|
||||
def __init__(self):
|
||||
self.num_local_experts: int = 0
|
||||
self.num_moe_experts: int = 0
|
||||
self.moe_pad_expert_input_to_capacity: bool = False
|
||||
self.moe_expert_capacity_factor: Optional[float] = None
|
||||
self.moe_router_topk: int = 2
|
||||
self.moe_grouped_gemm: bool = False
|
||||
self.group_topk: int = 0
|
||||
self.num_groups: int = 1
|
||||
self.expert_bias: torch.Tensor = None
|
||||
self.scaling_factor: Optional[float] = None
|
||||
self.is_fused: bool = True
|
||||
|
||||
def set_num_local_experts(self, num_local_experts):
|
||||
self.num_local_experts = num_local_experts
|
||||
return self
|
||||
|
||||
def set_num_moe_experts(self, num_moe_experts):
|
||||
self.num_moe_experts = num_moe_experts
|
||||
return self
|
||||
|
||||
def set_moe_pad_expert_input_to_capacity(self,
|
||||
moe_pad_expert_input_to_capacity):
|
||||
self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity
|
||||
return self
|
||||
|
||||
def set_moe_expert_capacity_factor(self, moe_expert_capacity_factor):
|
||||
self.moe_expert_capacity_factor = moe_expert_capacity_factor
|
||||
return self
|
||||
|
||||
def set_moe_router_topk(self, moe_router_topk):
|
||||
self.moe_router_topk = moe_router_topk
|
||||
return self
|
||||
|
||||
def set_moe_grouped_gemm(self, moe_grouped_gemm):
|
||||
self.moe_grouped_gemm = moe_grouped_gemm
|
||||
return self
|
||||
|
||||
def set_group_topk(self, group_topk):
|
||||
self.group_topk = group_topk
|
||||
return self
|
||||
|
||||
def set_num_groups(self, num_groups):
|
||||
self.num_groups = num_groups
|
||||
return self
|
||||
|
||||
def set_expert_bias(self, expert_bias):
|
||||
self.expert_bias = expert_bias
|
||||
return self
|
||||
|
||||
def set_scaling_factor(self, scaling_factor):
|
||||
self.scaling_factor = scaling_factor
|
||||
return self
|
||||
|
||||
def set_is_fused(self, is_fused):
|
||||
self.is_fused = is_fused
|
||||
return self
|
||||
|
||||
def build(self):
|
||||
return self
|
||||
|
||||
|
||||
class MoEDispatcher:
|
||||
|
||||
def __init__(self, config: MoEDispatcherConfig) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
"""
|
||||
self.config = config
|
||||
self.shared_experts = None
|
||||
|
||||
def set_shared_experts(self, shared_experts):
|
||||
self.shared_experts = shared_experts
|
||||
|
||||
@property
|
||||
def ep_group(self):
|
||||
"""Get expert model parallel group."""
|
||||
return get_ep_group().device_group
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return get_ep_group().rank_in_group
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return get_ep_group().world_size
|
||||
|
||||
@property
|
||||
def tp_ep_group(self):
|
||||
"""Get expert tensor and model parallel group."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def tp_ep_size(self):
|
||||
return 1
|
||||
|
||||
|
||||
class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
|
||||
overlap_stream = None
|
||||
"""
|
||||
The implementation of the AlltoAll-based token dispatcher, which handles token
|
||||
dispatching on the sequence level instead of token level. The core of this implementation
|
||||
lies in each device dispatching on the entire sequence, with the hidden state being partitioned.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config: MoEDispatcherConfig):
|
||||
"""
|
||||
Initialize the AlltoAllSeq token dispatcher.
|
||||
|
||||
Args:
|
||||
config (MoEDispatcherConfig): Configuration for the transformer model.
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.num_local_experts = config.num_local_experts
|
||||
self.config = config
|
||||
# use MOEAlltoAllSEQTokenDispatcher to init
|
||||
|
||||
self.hidden_shape = None
|
||||
self.num_input_tokens = None
|
||||
self.num_experts = config.num_moe_experts
|
||||
assert self.num_local_experts > 0, "Expected at least one expert"
|
||||
if self.num_local_experts > 1:
|
||||
self.expert_ids_per_ep_rank = torch.tensor(
|
||||
[i % self.num_local_experts for i in range(self.num_experts)],
|
||||
dtype=torch.int32,
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
|
||||
|
||||
self.local_expert_indices = [
|
||||
local_expert_indices_offset + i
|
||||
for i in range(self.num_local_experts)
|
||||
]
|
||||
assert (len(self.local_expert_indices) == self.num_local_experts
|
||||
), "Invalid local expert indices"
|
||||
for i in range(len(self.local_expert_indices) - 1):
|
||||
assert (self.local_expert_indices[i] ==
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
self.probs = None
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.routing_map = None
|
||||
self.hidden_shape_before_permute = None
|
||||
|
||||
# [tp_ep_size * ep_size, num_local_experts]. Represents the number of tokens sent
|
||||
# to each local expert by all ranks.
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
|
||||
# A cuda stream synchronization is needed in self.token_permutation()
|
||||
# in some cases, because there are several non-blocking DtoH data
|
||||
# transfers called in self.preprocess(). The synchronization happens
|
||||
# at different points based on MoE settings as late as possible.
|
||||
# Valid sync points are "before_permutation_1", "before_ep_alltoall",
|
||||
# "before_finish", and "no_sync".
|
||||
self.device_sync_point = "no_sync"
|
||||
|
||||
# cached intermediate tensors.
|
||||
self.cached_permutated_local_input_tokens = None
|
||||
self.cached_global_input_tokens = None
|
||||
self.cached_shared_expert_output = None
|
||||
self.tokens_per_expert = None
|
||||
self.perm1_finish_event = None
|
||||
self.global_input_tokens_local_experts_indices = None
|
||||
|
||||
if MoEAlltoAllSeqOverLapDispatcher.overlap_stream is None:
|
||||
MoEAlltoAllSeqOverLapDispatcher.overlap_stream = torch.npu.Stream()
|
||||
|
||||
self.overlap_stream = MoEAlltoAllSeqOverLapDispatcher.overlap_stream
|
||||
|
||||
def preprocess(self,
|
||||
indices: torch.Tensor,
|
||||
with_sync=True) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess routing map for AlltoAll communication and token permutation.
|
||||
This method computes the number of tokens assigned to each expert based on
|
||||
the routing map. It also initializes the necessary data structures for
|
||||
AlltoAll communication, such as input and output splits, and the mapping
|
||||
between global tokens and local experts.
|
||||
|
||||
Args:
|
||||
routing_map (torch.Tensor): The mapping of tokens to experts, with shape
|
||||
[num_tokens, num_experts].
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor containing the number of tokens assigned to local expert.
|
||||
"""
|
||||
num_local_tokens_per_expert = torch.histc(indices,
|
||||
bins=self.num_experts,
|
||||
min=0,
|
||||
max=self.num_experts)
|
||||
|
||||
# num_local_tokens_per_expert: [num_experts]
|
||||
|
||||
ep_size = self.ep_size
|
||||
|
||||
# Dropless
|
||||
self.num_out_tokens = indices.numel()
|
||||
if self.ep_size > 1 or self.num_local_experts > 1:
|
||||
# Token dropless and enable ep. A synchronization is needed before expert parallel
|
||||
# AlltoAll communication to get the `input_splits` and `output_splits` CPU values.
|
||||
self.device_sync_point = "before_ep_alltoall"
|
||||
else:
|
||||
# Token dropless and no ep. A synchronization is needed to get the
|
||||
# `tokens_per_expert` CPU value.
|
||||
self.device_sync_point = "before_finish"
|
||||
|
||||
if ep_size > 1:
|
||||
# ===================================================
|
||||
# Calculate input_splits, output_splits for alltoall-v.
|
||||
# ===================================================
|
||||
self.input_splits = (num_local_tokens_per_expert.reshape(
|
||||
ep_size, self.num_local_experts).sum(axis=1).to(
|
||||
torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
|
||||
num_local_tokens_per_expert,
|
||||
group=self.ep_group).reshape(ep_size, self.num_experts)
|
||||
self.num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
0]:self.local_expert_indices[-1] + 1]
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before sum."
|
||||
)
|
||||
self.output_splits = (self.num_global_tokens_per_local_expert.sum(
|
||||
axis=-1).to(torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = self.num_global_tokens_per_local_expert.sum(
|
||||
axis=0)
|
||||
# ===================================================
|
||||
# num_global_tokens_per_expert: [ep_size, num_experts]
|
||||
# num_global_tokens_per_local_expert: [ep_size, num_local_experts]
|
||||
# num_tokens_per_local_expert: [num_local_experts]
|
||||
# ===================================================
|
||||
else:
|
||||
self.num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(
|
||||
-1, self.num_experts)
|
||||
num_tokens_per_local_expert = num_local_tokens_per_expert
|
||||
|
||||
if self.num_local_experts > 1 and with_sync:
|
||||
if self.num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before operations."
|
||||
)
|
||||
self.device_sync_point = "no_sync"
|
||||
self.global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
self.num_global_tokens_per_local_expert.ravel())
|
||||
|
||||
return num_tokens_per_local_expert
|
||||
|
||||
def token_permutation(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
probs: torch.Tensor,
|
||||
routing_map: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Dispatch tokens to local experts using AlltoAllSeq communication.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input token embeddings.
|
||||
probs (torch.Tensor): Probs of tokens assigned to experts.
|
||||
Shape: [num_tokens, num_experts].
|
||||
routing_map (torch.Tensor): Mapping of tokens assigned to experts.
|
||||
Shape: [num_tokens, num_experts].
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]:
|
||||
- Permuted token embeddings for local experts.
|
||||
- Number of tokens per expert.
|
||||
"""
|
||||
self.hidden_shape = hidden_states.shape
|
||||
self.probs = probs
|
||||
self.top_indices = routing_map
|
||||
assert probs.dim() == 2, "Expected 2D tensor for probs"
|
||||
assert routing_map.dim() == 2, "Expected 2D tensor for routing map"
|
||||
|
||||
# Permutation 1: input to AlltoAll input
|
||||
def alltoall_token_permutation1(hidden_states, routing_map):
|
||||
assert self.hidden_shape is not None
|
||||
hidden_states = hidden_states.view(-1, self.hidden_shape[-1])
|
||||
tokens_per_expert = self.preprocess(routing_map)
|
||||
if self.tp_ep_size > 1:
|
||||
hidden_states = all_to_all_sp2hp(hidden_states,
|
||||
group=self.tp_ep_group)
|
||||
self.hidden_shape_before_permute = hidden_states.shape
|
||||
|
||||
if self.device_sync_point == "before_permutation_1":
|
||||
torch.npu.current_stream().synchronize()
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
tokens=hidden_states,
|
||||
indices=self.top_indices,
|
||||
num_out_tokens=self.num_out_tokens,
|
||||
)
|
||||
return permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert
|
||||
|
||||
permutated_local_input_tokens, reversed_local_input_permutation_mapping, tokens_per_expert = alltoall_token_permutation1(
|
||||
hidden_states, routing_map)
|
||||
self.reversed_local_input_permutation_mapping = reversed_local_input_permutation_mapping
|
||||
# permute 1
|
||||
|
||||
ep_group = self.ep_group
|
||||
|
||||
# Perform expert parallel AlltoAll communication
|
||||
if self.device_sync_point == "before_ep_alltoall":
|
||||
torch.npu.current_stream().synchronize()
|
||||
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
|
||||
permutated_local_input_tokens,
|
||||
self.output_splits,
|
||||
self.input_splits,
|
||||
ep_group,
|
||||
)
|
||||
|
||||
# shared experts compute
|
||||
if self.shared_experts is not None:
|
||||
(share_experts_output), *_ = self.shared_experts(hidden_states)
|
||||
else:
|
||||
share_experts_output = None
|
||||
|
||||
permute1_ep_all_to_all_handle.wait()
|
||||
permutated_local_input_tokens.untyped_storage().resize_(0)
|
||||
|
||||
def alltoall_token_permutation2(global_input_tokens):
|
||||
# Permutation 2: Sort tokens by local expert.
|
||||
if self.num_local_experts > 1:
|
||||
global_input_tokens, self.reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens,
|
||||
self.global_input_tokens_local_experts_indices)
|
||||
|
||||
# Perform tensor parallel AllGather on the hidden dimension to obtain the input tokens.
|
||||
# global_input_tokens: [SEQL, H/TP] -> [SEQL, H]
|
||||
if self.tp_ep_size > 1 and self.config.moe_grouped_gemm:
|
||||
global_input_tokens = all_gather_last_dim_from_tensor_parallel_region(
|
||||
global_input_tokens, self.tp_ep_group)
|
||||
if self.device_sync_point == "before_finish":
|
||||
torch.npu.current_stream().synchronize()
|
||||
|
||||
return global_input_tokens
|
||||
|
||||
# token premute2 input
|
||||
global_input_tokens = alltoall_token_permutation2(global_input_tokens)
|
||||
|
||||
return share_experts_output, global_input_tokens, tokens_per_expert
|
||||
|
||||
def token_unpermutation(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
"""
|
||||
Reverse the token permutation to restore the original order.
|
||||
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Output from local experts.
|
||||
bias (torch.Tensor, optional): Bias tensor (not supported).
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
- Unpermuted token embeddings in the original order.
|
||||
- None (bias is not supported).
|
||||
"""
|
||||
|
||||
def alltoall_token_unpermutation1(hidden_states):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
|
||||
# Perform tensor parallel Reduce-Scatter
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
if self.tp_ep_size > 1:
|
||||
hidden_states = reduce_scatter_last_dim_to_tensor_parallel_region(
|
||||
hidden_states, group=self.tp_ep_group)
|
||||
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
hidden_states,
|
||||
self.reversed_global_input_permutation_mapping)
|
||||
|
||||
return hidden_states
|
||||
|
||||
hidden_states = alltoall_token_unpermutation1(hidden_states)
|
||||
|
||||
ep_group = self.ep_group
|
||||
# Perform expert parallel AlltoAll communication
|
||||
# hidden_states: [SEQL, H] -> [SEQL, H/TP]
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
hidden_states, self.input_splits, self.output_splits, ep_group)
|
||||
handle.wait()
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
def alltoall_token_unpermutation2(permutated_local_input_tokens):
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=self.reversed_local_input_permutation_mapping.
|
||||
to(torch.int32),
|
||||
probs=self.probs,
|
||||
restore_shape=self.hidden_shape_before_permute)
|
||||
|
||||
# Perform tensor parallel AlltoAll communication
|
||||
# output: [S*B, H/TP] -> [S*B/TP, H]
|
||||
if self.tp_ep_size > 1:
|
||||
output = all_to_all_hp2sp(output, self.tp_ep_group)
|
||||
|
||||
# Reshape the output tensor
|
||||
output = output.view(self.hidden_shape)
|
||||
return output
|
||||
|
||||
output = alltoall_token_unpermutation2(permutated_local_input_tokens)
|
||||
|
||||
self.input_splits = None
|
||||
self.output_splits = None
|
||||
self.num_global_tokens_per_local_expert = None
|
||||
self.num_global_tokens_per_local_expert_cpu = None
|
||||
|
||||
return output, None
|
||||
|
||||
|
||||
_Dispatchers: Dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -1090,7 +667,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
bias: torch.Tensor = None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllSeqTokenDispatcher"
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
hidden_states = self._combine_preprocess(hidden_states)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user