forked from EngineX-Ascend/enginex-ascend-910-vllm
v0.10.1rc1
This commit is contained in:
0
vllm_ascend/ops/layers/__init__.py
Normal file
0
vllm_ascend/ops/layers/__init__.py
Normal file
283
vllm_ascend/ops/layers/experts_selector.py
Normal file
283
vllm_ascend/ops/layers/experts_selector.py
Normal file
@@ -0,0 +1,283 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
|
||||
def return_row_idx(hidden_states, top_k):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous())
|
||||
return row_idx
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
global_num_experts: int = -1):
|
||||
"""
|
||||
Fused experts with select experts.
|
||||
|
||||
Args:
|
||||
router_logits: router logits of shape (num_tokens, hidden_size).
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
top_k: number of top k experts.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
indices_type: dtype of indices
|
||||
global_num_experts: Global number of experts.
|
||||
|
||||
Returns:
|
||||
topk_weights: router weights of shape (num_tokens, top_k).
|
||||
topk_ids: selected expert IDs of shape (num_tokens, top_k).
|
||||
"""
|
||||
|
||||
topk_weights, topk_ids, row_idx = _select_experts_with_fusion_ops(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
topk_group=topk_group,
|
||||
renormalize=renormalize,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_weights is None:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
if row_idx is None:
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
return topk_weights, topk_ids, row_idx
|
||||
|
||||
|
||||
def _native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
):
|
||||
topk_group = 0 if topk_group is None else topk_group
|
||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||
|
||||
num_token = topk_weights.shape[0]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
||||
k=topk_group,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
topk_group_mask = torch.zeros_like(grouped_weights)
|
||||
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
||||
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
||||
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
||||
|
||||
return topk_weights
|
||||
|
||||
|
||||
def _renormalize_topk_weights(
|
||||
topk_weights: torch.Tensor,
|
||||
renormalize: bool,
|
||||
):
|
||||
if renormalize:
|
||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||
return topk_weights
|
||||
|
||||
|
||||
def _select_expert_use_group_topk(
|
||||
topk_weights: torch.Tensor, topk_group: Optional[int],
|
||||
renormalize: bool, top_k: int, num_expert_group: Optional[int],
|
||||
e_score_correction_bias: Optional[torch.Tensor]):
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
|
||||
if e_score_correction_bias is not None:
|
||||
# Store original scores before applying correction bias. We use biased
|
||||
# scores for expert selection but original scores for routing weights
|
||||
original_weights = topk_weights
|
||||
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)
|
||||
|
||||
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def _select_experts_with_fusion_ops(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: Optional[torch.Tensor],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: Optional[int],
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1):
|
||||
|
||||
topk_weights, topk_ids, row_idx = None, None, None
|
||||
# NOTE: now npu_moe_gating_top_k can only support 'group_count=256' pattern
|
||||
is_deepseek_v3_r1 = global_num_experts == 256
|
||||
if is_deepseek_v3_r1:
|
||||
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k, # topk currently 8
|
||||
bias=e_score_correction_bias,
|
||||
k_group=topk_group, # fix: 4
|
||||
group_count=num_expert_group, # fix 8
|
||||
group_select_mode=
|
||||
1, # 0: the maximum in the group; 1: topk2.sum(fix)
|
||||
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
||||
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
||||
# out_flag=False, # todo new api; should the third output be output
|
||||
# y2_flag=False, # old api; should the third output be output
|
||||
routed_scaling_factor=1,
|
||||
eps=float(1e-20))
|
||||
row_idx = return_row_idx(hidden_states, top_k)
|
||||
if not use_grouped_topk and custom_routing_function is None and scoring_func == "softmax":
|
||||
topk_weights, topk_ids, row_idx = torch_npu.npu_moe_gating_top_k_softmax(
|
||||
x=router_logits, finished=None, k=top_k)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids, row_idx
|
||||
|
||||
|
||||
def _native_select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: Optional[torch.Tensor] = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
router_logits: Router logits of shape (num_tokens, num_experts).
|
||||
top_k: Number of experts to select.
|
||||
use_grouped_topk: Whether to group experts before selecting top-k.
|
||||
renormalize: Whether to renormalize the routing weights.
|
||||
topk_group: Number of expert groups to select from.
|
||||
num_expert_group: Number of experts in each group.
|
||||
custom_routing_function: Custom routing function.
|
||||
scoring_func: Scoring function to use.
|
||||
e_score_correction_bias: Correction bias to apply to expert scores.
|
||||
|
||||
Returns:
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported scoring function is provided.
|
||||
"""
|
||||
|
||||
if scoring_func == "softmax":
|
||||
topk_weights = router_logits.softmax(dim=-1)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights = router_logits.sigmoid()
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
if use_grouped_topk:
|
||||
return _select_expert_use_group_topk(
|
||||
topk_weights=topk_weights,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
if custom_routing_function is not None:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
hidden_states=hidden_states,
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
|
||||
topk_weights = topk_weights.to(hidden_states.dtype)
|
||||
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
|
||||
return topk_weights, topk_ids
|
||||
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
199
vllm_ascend/ops/layers/moe_mlp.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.utils import dispose_tensor, is_310p
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
|
||||
if dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
bias1, bias2 = None, None
|
||||
_output_dtype = w2_scale.dtype
|
||||
|
||||
is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2
|
||||
if w1_scale_bias is None and is_mc2:
|
||||
w1_scale = w1_scale.to(torch.float32)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias]
|
||||
bias2 = [w2_scale_bias]
|
||||
# TODO w4a8 scene: dynamic acquisition of dtype in the future
|
||||
_output_dtype = torch.bfloat16
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
scale=[w1_scale],
|
||||
bias=bias1,
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w2],
|
||||
scale=[w2_scale],
|
||||
bias=bias2,
|
||||
per_token_scale=[swiglu_out_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
w1 = w1.transpose(1, 2)
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
|
||||
if topk_scales is not None:
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
w2 = w2.transpose(1, 2)
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False) -> torch.Tensor:
|
||||
if with_quant:
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales)
|
||||
Reference in New Issue
Block a user