#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional

import torch
import torch_npu
from vllm.model_executor.layers.fused_moe.layer import \
    UnquantizedFusedMoEMethod


def group_topk(hidden_states: torch.Tensor,
               gating_output: torch.Tensor,
               topk: int,
               renormalize: bool,
               num_expert_group: Optional[int] = 0,
               topk_group: Optional[int] = 0,
               scoring_func: str = "softmax",
               e_score_correction_bias: Optional[torch.Tensor] = None):
    assert hidden_states.shape[0] == gating_output.shape[0], (
        "Number of tokens mismatch")

    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if e_score_correction_bias is not None:
        # Store original scores before applying correction bias. We use biased
        # scores for expert selection but original scores for routing weights.
        original_scores = scores
        scores = scores + e_score_correction_bias.unsqueeze(0)

    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group
    # TODO: Replace this block with npu_group_topk once the CANN and NNAL
    # versions are updated.
    # Pick the top `topk_group` expert groups by each group's best expert
    # score, then mask out the experts of all non-selected groups.
    num_token = scores.shape[0]
    group_scores = scores.view(num_token, num_expert_group,
                               -1).max(dim=-1).values
    group_idx = torch.topk(group_scores.to(torch.float32),
                           k=topk_group,
                           dim=-1,
                           sorted=False)[1]
    group_mask = torch.zeros_like(group_scores)
    group_mask.scatter_(1, group_idx, 1)
    score_mask = group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        scores.shape[-1] // num_expert_group).reshape(num_token, -1)
    scores = scores.masked_fill(~score_mask.bool(), 0.0)
    if e_score_correction_bias is not None:
        topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)[1]
        # Use original unbiased scores for the routing weights.
        topk_weights = original_scores.gather(1, topk_ids)
    else:
        topk_weights, topk_ids = torch.topk(scores,
                                            k=topk,
                                            dim=-1,
                                            sorted=False)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids.to(torch.int32)
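
# Illustrative only: a hypothetical call showing the shapes group_topk expects
# for a DeepSeek-style grouped router; the sizes below are made up, not taken
# from any model config.
#
#   hidden_states: [num_tokens, hidden_size]
#   gating_output: [num_tokens, 64]   # 64 experts split into 8 groups
#   topk_weights, topk_ids = group_topk(hidden_states, gating_output,
#                                       topk=6, renormalize=True,
#                                       num_expert_group=8, topk_group=3)
#   # topk_weights: [num_tokens, 6] float routing weights
#   # topk_ids:     [num_tokens, 6] int32 expert indices
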
def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor,
                  w2: torch.Tensor, topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor, top_k: int):
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [
        torch.float32, torch.float16, torch.bfloat16
    ]
    # Flatten an optional [batch, seq, hidden] input to 2-D for routing.
    ori_shape = hidden_states.shape
    if len(ori_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape

    # Each token is dispatched to top_k experts, so the expanded layout has
    # num_tokens * top_k rows; row_idx maps each (token, slot) pair to a row.
    row_idx_len = num_tokens * top_k
    row_idx = torch.arange(0,
                           row_idx_len,
                           dtype=torch.int32,
                           device=topk_weights.device).view(top_k, -1).permute(
                               1, 0).contiguous()
    expanded_x, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
        hidden_states,
        row_idx=row_idx,
        expert_idx=topk_ids,
        active_num=num_tokens)

    # Per-expert token counts, consumed as group_list by the grouped matmuls.
    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        expanded_expert_idx, E)
    expert_tokens = expert_tokens.to(torch.int64)
    # Fused gate/up projection over the expanded tokens, grouped by expert,
    # followed by the SwiGLU activation.
    w1 = w1.transpose(1, 2)
    gate_up_out_list = torch_npu.npu_grouped_matmul(x=[expanded_x],
                                                    weight=[w1],
                                                    split_item=2,
                                                    group_list_type=0,
                                                    group_type=0,
                                                    group_list=expert_tokens)
    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)
    # Down projection back to the hidden size, again grouped by expert.
    w2 = w2.transpose(1, 2)
    down_out_list = torch_npu.npu_grouped_matmul(x=[gate_up_out],
                                                 weight=[w2],
                                                 split_item=2,
                                                 group_list_type=0,
                                                 group_type=0,
                                                 group_list=expert_tokens)
    down_out_list = torch.cat(down_out_list, dim=0)
    # TODO: This reorders device memory twice; replace the current
    # implementation when suitable operators become available.
    # Scatter the expert outputs back to the original token order and combine
    # them with the routing weights.
    hidden_states = torch_npu.npu_moe_finalize_routing(
        down_out_list,
        skip1=None,
        skip2=None,
        bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids)

    if len(ori_shape) == 3:
        hidden_states = hidden_states.view(ori_shape)
    return hidden_states
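
# Illustrative only: hypothetical shapes for a fused_experts call, assuming
# E=16 experts, hidden_size=1024, intermediate_size=2048 and top_k=6; none of
# these sizes come from this file.
#
#   hidden_states: [num_tokens, 1024]
#   w1:            [16, 4096, 1024]   # fused gate/up weights, N = 2 * 2048
#   w2:            [16, 1024, 2048]   # down-projection weights
#   out = fused_experts(hidden_states, w1, w2,
#                       topk_weights, topk_ids, top_k=6)
#   # out: [num_tokens, 1024], same shape as hidden_states
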
def forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
    topk_weights, topk_ids = group_topk(
        hidden_states=x,
        gating_output=router_logits,
        topk=top_k,
        renormalize=renormalize,
        num_expert_group=num_expert_group,
        topk_group=topk_group,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias)
    return fused_experts(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk_weights=topk_weights,
                         topk_ids=topk_ids,
                         top_k=top_k)


# Override the out-of-tree forward of vLLM's UnquantizedFusedMoEMethod so that
# MoE layers run the NPU kernels above.
UnquantizedFusedMoEMethod.forward_oot = forward_oot
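
# Illustrative only: the patch takes effect as a side effect of importing this
# module (in vllm-ascend this presumably happens during the plugin's ops
# registration); the snippet below is a hypothetical manual activation, not
# code from this project.
#
#   import vllm_ascend.ops.fused_moe  # noqa: F401  (installs forward_oot)
#   # Any FusedMoE layer backed by UnquantizedFusedMoEMethod now dispatches
#   # to group_topk + fused_experts on Ascend NPUs.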