### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
End-to-end (e2e) tests and unit tests (ut).
- vLLM version: v0.10.1.1
- vLLM main: 704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
429 lines
17 KiB
Python
429 lines
17 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch_npu
|
|
from vllm.distributed import get_ep_group
|
|
from vllm.forward_context import get_forward_context
|
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
|
from vllm_ascend.ops.layers.experts_selector import select_experts
|
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
|
|
|
|
|
|
def apply_mlp_decode(hidden_states: torch.Tensor,
                     w1: torch.Tensor,
                     w1_scale: torch.Tensor,
                     w2: torch.Tensor,
                     w2_scale: torch.Tensor,
                     group_list: torch.Tensor,
                     dynamic_scale: torch.Tensor = None,
                     group_list_type: int = 1) -> torch.Tensor:
    """Run the quantized expert MLP: gate_up_proj -> swiglu -> down_proj.

    This variant produces int32 output from the first grouped matmul and
    hands it to ``npu_dequant_swiglu_quant``, which receives the weight
    scale and the per-token activation scale and returns the activated
    tensor together with a new scale for the second grouped matmul.

    Args:
        hidden_states: input hidden states with shape
            (num_tokens, hidden_size). When ``dynamic_scale`` is None the
            tensor is quantized here and the original tensor is disposed.
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2)
        w1_scale: weights1 scale with shape
            (num_experts, intermediate_size * 2)
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size)
        w2_scale: weights2 scale with shape (num_experts, hidden_size)
        group_list: per-expert token grouping with shape (num_experts);
            how it is interpreted is selected by ``group_list_type``.
        dynamic_scale: per-token scale of already-quantized
            ``hidden_states``; None means "quantize inside this function".
        group_list_type: forwarded unchanged to both grouped matmuls.
            NOTE(review): presumably 0 = cumsum and 1 = count, confirm
            against the torch_npu documentation.

    Returns:
        hidden_states: output hidden states after MLP, with dtype
        ``w2_scale.dtype``.
    """
    if dynamic_scale is not None:
        # The caller already quantized the activations.
        token_scale = dynamic_scale
    else:
        fp_input = hidden_states
        hidden_states, token_scale = torch_npu.npu_dynamic_quant(fp_input)
        # The unquantized activations are not needed any more; release
        # their storage right away to save npu memory.
        dispose_tensor(fp_input)

    # Grouping arguments shared by both grouped matmuls.
    grouping = dict(group_list_type=group_list_type,
                    group_type=0,
                    group_list=group_list)

    # gmm1: gate_up_proj (int32 output, dequantized by the next kernel)
    hidden_states = torch_npu.npu_grouped_matmul(x=[hidden_states],
                                                 weight=[w1],
                                                 split_item=3,
                                                 output_dtype=torch.int32,
                                                 **grouping)[0]

    # act_fn: swiglu, with dequantization before and quantization after
    hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=w1_scale,
        activation_scale=token_scale,
        bias=None,
        quant_scale=None,
        quant_offset=None,
        group_index=group_list,
        activate_left=True,
        quant_mode=1,
    )

    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[swiglu_out_scale],
        split_item=2,
        output_dtype=w2_scale.dtype,
        **grouping)[0]
    return hidden_states
|
|
|
|
|
|
def apply_mlp(hidden_states: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
              w2: torch.Tensor,
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: torch.Tensor = None,
              group_list_type: int = 1,
              w1_scale_bias: torch.Tensor = None,
              w2_scale_bias: torch.Tensor = None) -> torch.Tensor:
    """Run the quantized expert MLP: gate_up_proj -> swiglu -> down_proj.

    Both projections are grouped matmuls that receive the weight scale and
    a per-token activation scale. Between them the activations go through
    ``npu_swiglu`` and are re-quantized with ``npu_dynamic_quant``.

    Args:
        hidden_states: input hidden states with shape
            (num_tokens, hidden_size). When ``dynamic_scale`` is None the
            tensor is quantized here and the original tensor is disposed.
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2)
        w1_scale: weights1 scale with shape
            (num_experts, intermediate_size * 2)
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size)
        w2_scale: weights2 scale with shape (num_experts, hidden_size)
        group_list: per-expert token grouping with shape (num_experts);
            how it is interpreted is selected by ``group_list_type``.
        dynamic_scale: per-token scale of already-quantized
            ``hidden_states``; None means "quantize inside this function".
        group_list_type: forwarded to both grouped matmuls. When scale
            biases are given and this is 0, ``group_list`` is converted to
            per-expert differences and the type is switched to 1.
        w1_scale_bias: optional bias for the first grouped matmul. Its
            presence also switches the output dtype to bfloat16 (w4a8
            scene, see the TODO below).
        w2_scale_bias: optional bias for the second grouped matmul; only
            used when ``w1_scale_bias`` is given.

    Returns:
        hidden_states: output hidden states after MLP.
    """
    if dynamic_scale is not None:
        # The caller already quantized the activations.
        token_scale = dynamic_scale
    else:
        fp_input = hidden_states
        hidden_states, token_scale = torch_npu.npu_dynamic_quant(fp_input)
        # The unquantized activations are not needed any more; release
        # their storage right away to save npu memory.
        dispose_tensor(fp_input)

    _output_dtype = w2_scale.dtype
    bias1 = bias2 = None

    if w1_scale_bias is not None:
        if group_list_type == 0:
            # Turn the running totals into per-expert differences: keep
            # the first entry and diff the rest.
            per_expert = torch.diff(group_list, dim=0)
            group_list = torch.cat([group_list[:1], per_expert])
            group_list_type = 1
        bias1, bias2 = [w1_scale_bias], [w2_scale_bias]
        # TODO w4a8 scene: dynamic acquisition of dtype in the future
        _output_dtype = torch.bfloat16

    # Keyword arguments that are identical for both grouped matmuls.
    gmm_common = dict(split_item=2,
                      group_list_type=group_list_type,
                      group_type=0,
                      group_list=group_list,
                      output_dtype=_output_dtype)

    # gmm1: gate_up_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        scale=[w1_scale],
        bias=bias1,
        per_token_scale=[token_scale],
        **gmm_common)[0]

    # act_fn: swiglu, then re-quantize for the second matmul
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
        hidden_states)

    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale],
        bias=bias2,
        per_token_scale=[swiglu_out_scale],
        **gmm_common)[0]

    return hidden_states
|
|
|
|
|
|
class AscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC.

    Weights are created as int8 tensors with per-output-channel scale and
    offset parameters. Activations are quantized at call time with
    ``torch_npu.npu_dynamic_quant`` unless the caller passes them in
    already quantized.
    """

    def __init__(self):
        # When True, process_weights_after_loading transposes the loaded
        # weight before casting it to the NZ format.
        self.transpose_weight = True

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return the weight placeholder for one linear layer.

        Args:
            input_size: size of the input feature dimension.
            output_size: size of the output feature dimension.
            params_dtype: not used here; the weight is always int8.

        Returns:
            ``{"weight": tensor}`` where the tensor is an uninitialized
            int8 tensor of shape (output_size, input_size).
        """
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-tensor parameters; this method defines none."""
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Return the per-output-channel quantization parameters.

        Args:
            output_size: number of output channels.
            params_dtype: dtype of the created tensors.

        Returns:
            A dict with ``weight_scale`` and ``weight_offset``, each an
            uninitialized tensor of shape (output_size, 1).
        """
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    def get_pergroup_param(self, input_size: int, output_size: int,
                           params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return per-group parameters; this method defines none."""
        return {}

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Run the quantized matmul of ``layer`` on ``x``.

        Behaviour is tuned by the optional dict ``layer._ascend_quant_config``
        (treated as empty when the attribute is missing):

        * ``output_dtype``: dtype of the matmul output. Defaults to
          ``x.dtype`` for a plain tensor input; mandatory for a
          pre-quantized (tuple) input.
        * ``pertoken_scale`` (default True): whether the dynamic scale is
          passed to the matmul as per-token scale.
        * ``return_scale`` (default False): whether the dynamic scale is
          returned next to the output.

        Args:
            layer: module providing ``weight`` and ``weight_scale``.
            x: either a tensor, which is quantized here, or a tuple
                ``(quantized_x, dynamic_scale)``.
            bias: optional bias forwarded to the matmul.
            tp_rank: not used by this implementation.

        Returns:
            The matmul output, or ``(output, dynamic_scale)`` when
            ``return_scale`` is enabled in the layer config.

        Raises:
            AssertionError: if ``x`` is a tuple and the layer config has
                no ``output_dtype`` entry.
        """
        config = getattr(layer, "_ascend_quant_config", {})
        if not isinstance(x, tuple):
            output_dtype = config.get("output_dtype", x.dtype)
            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        else:
            # The original dtype cannot be recovered from an already
            # quantized input, so the config has to provide it.
            assert "output_dtype" in config, (
                "DynamicLinearMethod needs explicitly specified "
                "`output_dtype` "
                f"for pre-quantized input, got config [{config}]")
            output_dtype = config["output_dtype"]
            quantized_x, dynamic_scale = x
        pertoken_scale = (dynamic_scale
                          if config.get("pertoken_scale", True) else None)

        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=output_dtype,
        )
        return ((output, dynamic_scale)
                if config.get("return_scale", False) else output)

    def process_weights_after_loading(self, layer):
        """Bring the loaded parameters of ``layer`` into their runtime form.

        The weight is optionally transposed and then cast to the NZ
        format. Scale and offset are flattened, and a float32 copy of the
        scale is stored on the layer as ``weight_scale_fp32``.
        """
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        # cast quantized weight tensors in NZ format for higher inference
        # speed; the named constant replaces the former literal 29.
        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
                                                      ACL_FORMAT_FRACTAL_NZ)
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
|
|
|
|
|
|
class AscendW8A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W8A8_DYNAMIC.

    Provides the int8 expert weight placeholders, their scale/offset
    parameters, the forward pass through ``unified_fused_experts_eager``
    and the post-load weight preparation.
    """

    def __init__(self):
        # When True, process_weights_after_loading swaps the last two
        # dimensions of both expert weights.
        self.transpose_weight = True

        self.ep_group = get_ep_group()

        try:
            mc2_device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            rank = torch.distributed.get_rank(group=mc2_device_group)
            hccl_backend = mc2_device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = hccl_backend.get_hccl_comm_name(
                rank)
        except AttributeError:
            # Any missing attribute along the lookup chain above leaves
            # the communicator name empty.
            self.moe_all_to_all_group_name = ""

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return uninitialized int8 placeholders for the expert weights.

        ``w13_weight`` has shape
        (num_experts, 2 * intermediate_size_per_partition, hidden_sizes) and
        ``w2_weight`` has shape
        (num_experts, hidden_sizes, intermediate_size_per_partition).
        ``params_dtype`` is not used; both tensors are int8.
        """
        gate_up_rows = 2 * intermediate_size_per_partition
        return {
            "w13_weight":
            torch.empty(num_experts,
                        gate_up_rows,
                        hidden_sizes,
                        dtype=torch.int8),
            "w2_weight":
            torch.empty(num_experts,
                        hidden_sizes,
                        intermediate_size_per_partition,
                        dtype=torch.int8),
        }

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return uninitialized scale and offset tensors for both weights.

        The ``w13_*`` entries have shape
        (num_experts, 2 * intermediate_size_per_partition, 1) and the
        ``w2_*`` entries have shape (num_experts, hidden_sizes, 1); all of
        them use ``params_dtype``.
        """
        w13_shape = (num_experts, 2 * intermediate_size_per_partition, 1)
        w2_shape = (num_experts, hidden_sizes, 1)
        return {
            "w13_weight_scale": torch.empty(*w13_shape, dtype=params_dtype),
            "w13_weight_offset": torch.empty(*w13_shape, dtype=params_dtype),
            "w2_weight_scale": torch.empty(*w2_shape, dtype=params_dtype),
            "w2_weight_offset": torch.empty(*w2_shape, dtype=params_dtype),
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        quantized_x_for_share: Optional[Any] = None,
        dynamic_scale_for_share: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route the tokens in ``x`` and run the fused experts in eager mode.

        Steps: pick the experts with ``select_experts``; when shared
        experts are given and the forward context is in the MC2 state, run
        their ``gate_up_proj`` on the pre-quantized shared input; optionally
        replace the chosen expert ids with random ones; finally delegate to
        ``unified_fused_experts_eager``.

        The routing arguments (``top_k``, ``renormalize``,
        ``use_grouped_topk``, ``topk_group``, ``num_expert_group``,
        ``custom_routing_function``, ``scoring_func``,
        ``e_score_correction_bias``, ``global_num_experts``) are passed on
        to ``select_experts``. ``expert_map``, ``log2phy``,
        ``global_redundant_expert_num`` and ``shared_experts`` are passed
        on to ``unified_fused_experts_eager``. ``is_prefill`` is not used
        in this method. The only entry read from ``kwargs`` is
        ``mc2_mask``.

        Returns:
            Whatever ``unified_fused_experts_eager`` returns.

        Raises:
            AssertionError: if the second dimension of ``router_logits``
                differs from ``global_num_experts``.
        """
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        topk_weights, topk_ids, row_idx = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts)

        fused_moe_state = get_forward_context().fused_moe_state
        shared_gate_up = shared_dequant_scale = None
        if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
            share_up_out, _ = shared_experts.gate_up_proj(
                (quantized_x_for_share, dynamic_scale_for_share))
            shared_gate_up = share_up_out[0]
            shared_dequant_scale = share_up_out[1]

        # this is a naive implementation for experts load balance so as
        # to avoid accumulating too much tokens on a single rank.
        # currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        topk_weights = topk_weights.to(x.dtype)

        mc2_mask = kwargs.get("mc2_mask")
        return unified_fused_experts_eager(
            hidden_states=x,
            w1=layer.w13_weight,
            w1_scale=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_scale=layer.w2_weight_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            row_idx=row_idx,
            expert_map=expert_map,
            log2phy=log2phy,
            global_redundant_expert_num=global_redundant_expert_num,
            shared_experts=shared_experts,
            shared_gate_up=shared_gate_up,
            shared_dequant_scale=shared_dequant_scale,
            mc2_mask=mc2_mask)

    def process_weights_after_loading(self, layer):
        """Bring the loaded expert parameters into their runtime form.

        Both expert weights optionally get their last two dimensions
        swapped, ``w2_weight`` is cast in place to the NZ format when
        ``VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP`` is set, every scale and
        offset is reshaped to (size of dim 0, -1), and a float32 copy of
        the w13 scale is stored as ``w13_weight_scale_fp32``.
        """
        if self.transpose_weight:
            for weight_name in ("w13_weight", "w2_weight"):
                weight = getattr(layer, weight_name)
                weight.data = weight.data.transpose(1, 2).contiguous()

        if envs_ascend.VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP:
            torch_npu.npu_format_cast_(layer.w2_weight, ACL_FORMAT_FRACTAL_NZ)

        # Collapse everything after the first dimension of each scale and
        # offset parameter.
        for param_name in ("w13_weight_scale", "w13_weight_offset",
                           "w2_weight_scale", "w2_weight_offset"):
            param = getattr(layer, param_name)
            param.data = param.data.view(param.data.shape[0], -1)

        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
            torch.float32)
|