### What this PR does / why we need it?
[Feature] Adapt the DispatchGmmCombineDecode operator to align with the weight
scale dtype of small operators.
- **Before**: weight scale must be float32
- **After**: weight scale can be float32/float16 when x is float16, and
float32/bfloat16 when x is float32/bfloat16. In addition, the w1 scale may
use a different dtype from the w2 scale.
For more information about this operator, please refer to the RFC issue:
https://github.com/vllm-project/vllm-ascend/issues/5476
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
#### Perf
> When the scale is of type fp16 or bf16, it is cast to fp32 internally
within the operator, while the subsequent computations remain unchanged.
Therefore, this PR introduces an additional cast operation but halves the
amount of scale data that has to be copied. Furthermore, since the scale
data is only a few KB in size and participates in relatively few
computations, its impact is almost negligible compared to major
operations such as matrix multiplication. Thus, the theoretical
performance change should be minimal.
Tested single-operator cases taken from Qwen3-235B:
- single A3 node (EP16), 64 MoE experts, 4 experts per die (similar to
Qwen3-235B with EP32)
- batch = 18/32, token_hidden_size 4096, moe_intermediate_size 1536
The test was conducted for 100 rounds, and the average of the last 95
rounds was taken.
| | bs18(us)| bs32(us)|
| -----| -----| -----|
|Without this PR|96.28|108.83|
|With this PR|96.06|107.90|
Note: Single-operator benchmarks represent an ideal scenario. They are
usually only useful for referencing relative changes and may not fully
align with performance data observed within the full model.
#### Acc
Tested Qwen3-235B with EPLB on a single A3 node (EP16),
using dispatch_gmm_combine_decode:
| dataset | version | metric | mode | vllm-api-stream-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
335 lines
14 KiB
Python
335 lines
14 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
import torch
|
|
import torch_npu
|
|
from vllm.config import CompilationMode, get_current_vllm_config
|
|
from vllm.distributed import get_ep_group
|
|
from vllm.forward_context import get_forward_context
|
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.ascend_forward_context import MoECommType
|
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
|
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
|
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
|
|
zero_experts_compute)
|
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
|
|
|
|
|
|
class AscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC.

    The weight is stored as int8 with a per-output-channel scale and offset.
    At runtime the activation is quantized with ``npu_dynamic_quant``, which
    also yields the per-token scale fed to ``npu_quant_matmul``.
    """

    def __init__(self):
        # No per-instance state: everything lives on the layer.
        pass

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate the int8 weight with shape ``(output_size, input_size)``.

        ``params_dtype`` is part of the shared interface and is not used: the
        weight is always int8.
        """
        return {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8),
        }

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        """Return no per-tensor parameters (none are needed by this scheme)."""
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Allocate ``weight_scale`` and ``weight_offset``.

        Both have shape ``(output_size, 1)`` and dtype ``params_dtype``.
        """
        channel_shape = (output_size, 1)
        return {
            "weight_scale": torch.empty(channel_shape, dtype=params_dtype),
            "weight_offset": torch.empty(channel_shape, dtype=params_dtype),
        }

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None) -> Dict[str, Any]:
        """Return no per-group parameters (none are needed by this scheme)."""
        return {}

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        """Quantize ``x`` dynamically and run the int8 matmul.

        Args:
            layer: Module holding ``weight`` and ``weight_scale`` as prepared
                by ``process_weights_after_loading``.
            x: Input activations; the output is produced in ``x.dtype``.
            bias: Optional bias forwarded to ``npu_quant_matmul``.
            tp_rank: Accepted for interface compatibility; unused here.

        Returns:
            The result of ``torch_npu.npu_quant_matmul``.
        """
        x_int8, token_scale = torch_npu.npu_dynamic_quant(x)
        return torch_npu.npu_quant_matmul(
            x_int8,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=token_scale,
            bias=bias,
            output_dtype=x.dtype,
        )

    def process_weights_after_loading(self, layer):
        """Re-layout the loaded weight and flatten its quantization params.

        Side effects on ``layer``:
        - ``weight`` is transposed from ``(output, input)`` to
          ``(input, output)``, made contiguous, then passed through
          ``maybe_trans_nz``.
        - ``weight_scale`` and ``weight_offset`` are flattened to 1-D.
        - ``weight_scale_fp32`` is set to a float32 copy of the flat scale.
        """
        weight = layer.weight
        weight.data = weight.data.transpose(0, 1).contiguous()
        # Possibly cast the quantized weight to the NZ format for faster
        # inference; whether it does is presumably decided inside
        # maybe_trans_nz (per its name) - confirm there.
        weight.data = maybe_trans_nz(weight.data)

        scale = layer.weight_scale
        scale.data = scale.data.flatten()
        layer.weight_scale_fp32 = scale.data.to(torch.float32)

        offset = layer.weight_offset
        offset.data = offset.data.flatten()
|
|
|
|
|
|
class AscendW8A8DynamicFusedMoEMethod:
    """FusedMoe method for Ascend W8A8_DYNAMIC.

    Expert weights are int8. Each expert stack carries scale/offset tensors of
    shape ``(num_experts, rows, 1)`` allocated in ``params_dtype``.

    After loading:
    - a float32 copy of the w13 scale is kept for the w1 path;
    - the w2 scale keeps its loaded dtype;
    - both scales are also re-encoded as int64 bit patterns for the
      fused-MC2 path.
    """

    def __init__(self):
        self.ep_group = get_ep_group()

        vllm_config = get_current_vllm_config()
        ascend_config = get_ascend_config()
        compilation_config = vllm_config.compilation_config
        model_config = vllm_config.model_config

        # True only in VLLM_COMPILE mode when eager execution is not enforced.
        self.use_aclgraph = (
            compilation_config.mode == CompilationMode.VLLM_COMPILE
            and not model_config.enforce_eager)
        self.multistream_overlap_gate = ascend_config.multistream_overlap_gate

        self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb
        self.in_dtype = model_config.dtype
        self.supports_eplb = True

        # Look up the HCCL communicator name of the MC2 device group. If any
        # attribute on that chain is missing, fall back to an empty name.
        try:
            mc2_device_group = get_mc2_group().device_group
            # TODO: Try local_rank = ep_group.rank_in_group
            rank_in_mc2 = torch.distributed.get_rank(group=mc2_device_group)
            npu_backend = mc2_device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = npu_backend.get_hccl_comm_name(
                rank_in_mc2)
        except AttributeError:
            self.moe_all_to_all_group_name = ""

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate the stacked int8 expert weights.

        Returns a dict with two entries:
        - ``w13_weight``: ``(num_experts, 2 * intermediate_size_per_partition,
          hidden_sizes)``
        - ``w2_weight``: ``(num_experts, hidden_sizes,
          intermediate_size_per_partition)``

        ``params_dtype`` is part of the shared interface and is not used: the
        weights are always int8.
        """
        w13_rows = 2 * intermediate_size_per_partition
        return {
            "w13_weight":
            torch.empty(num_experts, w13_rows, hidden_sizes, dtype=torch.int8),
            "w2_weight":
            torch.empty(num_experts,
                        hidden_sizes,
                        intermediate_size_per_partition,
                        dtype=torch.int8),
        }

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        """Allocate per-row scale/offset tensors in ``params_dtype``.

        The w13 tensors have shape
        ``(num_experts, 2 * intermediate_size_per_partition, 1)``. The w2
        tensors have shape ``(num_experts, hidden_sizes, 1)``.
        """
        w13_shape = (num_experts, 2 * intermediate_size_per_partition, 1)
        w2_shape = (num_experts, hidden_sizes, 1)
        return {
            "w13_weight_scale": torch.empty(w13_shape, dtype=params_dtype),
            "w13_weight_offset": torch.empty(w13_shape, dtype=params_dtype),
            "w2_weight_scale": torch.empty(w2_shape, dtype=params_dtype),
            "w2_weight_offset": torch.empty(w2_shape, dtype=params_dtype),
        }

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        pertoken_scale: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Route tokens to experts and run the int8 fused-expert computation.

        Top-k expert ids and weights come from one of two sources:
        - the flash-common3 context, when ``multistream_overlap_gate`` is set;
        - otherwise, ``select_experts``.

        If ``layer`` configures zero experts, they are handled by
        ``zero_experts_compute`` and their result is added to the output. The
        expert computation itself is delegated to
        ``get_forward_context().moe_comm_method.fused_experts``.

        ``is_prefill`` is accepted for interface compatibility and is not
        used here. ``kwargs`` may carry ``mc2_mask``.

        Returns:
            The hidden states from ``fused_experts``, plus the zero-expert
            contribution when zero experts are configured.
        """
        zero_expert_num = getattr(layer, "zero_expert_num", 0)
        zero_expert_type = getattr(layer, "zero_expert_type", None)
        uses_zero_experts = zero_expert_num > 0 and zero_expert_type is not None
        # Experts excluding redundant (replicated) ones.
        num_logical_experts = global_num_experts - global_redundant_expert_num

        if zero_expert_num == 0 or zero_expert_type is None:
            assert router_logits.shape[1] == num_logical_experts, \
                "Number of global experts mismatch (excluding redundancy)"

        if not self.multistream_overlap_gate:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                top_k=top_k,
                use_grouped_topk=use_grouped_topk,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                routed_scaling_factor=routed_scaling_factor,
                e_score_correction_bias=e_score_correction_bias,
                global_num_experts=global_num_experts)
        else:
            # Routing results are read from the flash-common3 context.
            fc3_context = get_flash_common3_context()
            assert fc3_context is not None
            topk_weights = fc3_context.topk_weights
            topk_ids = fc3_context.topk_ids
        assert topk_ids is not None
        assert topk_weights is not None

        if uses_zero_experts:
            topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
                expert_indices=topk_ids,
                expert_scales=topk_weights,
                num_experts=global_num_experts,
                zero_expert_type=zero_expert_type,
                hidden_states=x,
            )

        # This is a naive implementation of expert load balancing, meant to
        # avoid accumulating too many tokens on a single rank.
        # Currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            # For each token, take the first top-k indices of an argsort of
            # uniform noise over the logical experts: distinct random ids.
            noise = torch.rand(topk_ids.size(0),
                               num_logical_experts,
                               device=topk_ids.device)
            topk_ids = torch.argsort(
                noise, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)

        assert topk_weights is not None
        topk_weights = topk_weights.to(self.in_dtype)

        forward_context = get_forward_context()
        moe_comm_method = forward_context.moe_comm_method
        if self.dynamic_eplb:
            # Per-expert lists built by process_weights_after_loading.
            w1 = layer.w13_weight_list
            w1_scale = layer.w13_weight_scale_fp32_list
            w2 = layer.w2_weight_list
            w2_scale = layer.w2_weight_scale_list
        else:
            w1 = [layer.w13_weight]
            w1_scale = [layer.w13_weight_scale_fp32]
            w2 = [layer.w2_weight]
            w2_scale = [layer.w2_weight_scale]

        # With the FUSED_MC2 comm type and VLLM_ASCEND_ENABLE_FUSED_MC2 == 1,
        # the int64-encoded scales replace the floating-point ones.
        if (forward_context.moe_comm_type == MoECommType.FUSED_MC2
                and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1):
            w1_scale = [layer.fused_w1_scale]
            w2_scale = [layer.fused_w2_scale]

        final_hidden_states = moe_comm_method.fused_experts(
            hidden_states=x,
            pertoken_scale=pertoken_scale,
            w1=w1,
            w1_scale=w1_scale,
            w2=w2,
            w2_scale=w2_scale,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            use_int8_w8a8=True,
            expert_map=expert_map,
            log2phy=log2phy,
            dynamic_eplb=self.dynamic_eplb,
            mc2_mask=kwargs.get("mc2_mask", None))
        if uses_zero_experts:
            final_hidden_states += zero_expert_result
        return final_hidden_states

    def process_weights_after_loading(self, layer):
        """Re-layout the loaded expert weights and derive the scale variants.

        Steps, in order:
        1. Transpose the last two dims of ``w13_weight`` and ``w2_weight``
           (made contiguous), then cast both with ``npu_format_cast`` to
           ``ACL_FORMAT_FRACTAL_NZ``.
        2. View every scale/offset as 2-D ``(num_experts, -1)``. Right after
           the w13 scale is reshaped, store its float32 copy as
           ``w13_weight_scale_fp32``.
        3. Store int64 bit-pattern encodings of both scales as
           ``fused_w1_scale`` / ``fused_w2_scale``.
        4. With dynamic EPLB, clone the weights and the scales used by
           ``apply`` into per-expert lists, delete the stacked tensors, and
           call ``torch.npu.empty_cache()``.
        """
        for weight_name in ("w13_weight", "w2_weight"):
            weight = getattr(layer, weight_name)
            weight.data = weight.data.transpose(1, 2).contiguous()
        # TODO(zzzzwwjj): Currently, `torch_npu.npu_grouped_matmul_swiglu_quant`
        # can only support weight nz.
        for weight_name in ("w13_weight", "w2_weight"):
            weight = getattr(layer, weight_name)
            weight.data = torch_npu.npu_format_cast(weight.data,
                                                    ACL_FORMAT_FRACTAL_NZ)

        def _collapse_to_2d(param):
            # Keep dim 0 and merge all remaining dims.
            param.data = param.data.view(param.data.shape[0], -1)

        _collapse_to_2d(layer.w13_weight_scale)
        layer.w13_weight_scale_fp32 = layer.w13_weight_scale.data.to(
            torch.float32)
        _collapse_to_2d(layer.w13_weight_offset)
        _collapse_to_2d(layer.w2_weight_scale)
        _collapse_to_2d(layer.w2_weight_offset)

        layer.fused_w1_scale = scale_from_float_to_int64(
            layer.w13_weight_scale.data)
        layer.fused_w2_scale = scale_from_float_to_int64(
            layer.w2_weight_scale.data)

        if not self.dynamic_eplb:
            return

        def _clone_per_expert(stacked):
            # Split along dim 0 and give each slice its own storage.
            return [expert.clone() for expert in stacked.unbind(dim=0)]

        layer.w13_weight_list = _clone_per_expert(layer.w13_weight.data)
        layer.w2_weight_list = _clone_per_expert(layer.w2_weight.data)
        layer.w13_weight_scale_fp32_list = _clone_per_expert(
            layer.w13_weight_scale_fp32.data)
        layer.w2_weight_scale_list = _clone_per_expert(
            layer.w2_weight_scale.data)
        for stacked_name in ("w13_weight", "w2_weight", "w13_weight_scale",
                             "w13_weight_scale_fp32", "w2_weight_scale"):
            delattr(layer, stacked_name)
        torch.npu.empty_cache()
|
|
|
|
|
|
def scale_from_float_to_int64(scale):
    """Re-encode a floating-point scale as int64 copies of its float32 bits.

    Steps:
    1. Move the tensor to host memory and cast it to float32.
    2. Read each element's 4 raw bytes as a native-endian signed int32
       (a bit cast, not a numeric conversion).
    3. Widen the int32 values to int64, so the sign bit is extended.

    The result is flattened in row-major order and moved back to
    ``scale.device``.

    Args:
        scale: Floating-point tensor of any shape (e.g. float32, float16,
            bfloat16).

    Returns:
        A 1-D int64 tensor with ``scale.numel()`` elements on
        ``scale.device``.
    """
    import numpy as np

    fp32_on_host = scale.cpu().to(torch.float32).numpy()
    # tobytes() serializes in C (row-major) order, which flattens the input.
    raw_bytes = fp32_on_host.tobytes()
    int32_bits = np.frombuffer(raw_bytes, dtype=np.int32)
    int64_bits = int32_bits.astype(np.int64)
    return torch.from_numpy(int64_bits).to(scale.device)
|