Files
xc-llm-ascend/vllm_ascend/quantization/methods/w8a8_mxfp8.py
linfeng-yuan 88d03a783f [refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
2026-03-20 23:23:57 +08:00

239 lines
9.7 KiB
Python

#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from typing import Any
import torch
import torch_npu
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.distributed import get_ep_group
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.device.mxfp_compat import (
FLOAT8_E8M0FNU_DTYPE,
ensure_mxfp8_linear_available,
ensure_mxfp8_moe_available,
)
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
@register_scheme("W8A8_MXFP8", "linear")
class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
    """Ascend linear scheme for W8A8_MXFP8 (microscaling FP8) quantization.

    Activations are quantized at call time with ``npu_dynamic_mx_quant`` to
    ``torch.float8_e4m3fn``. Weights are allocated as ``torch.float8_e4m3fn``
    and carry one ``uint8`` scale per ``group_size`` input elements.
    """

    model_dtype = None

    def __init__(self):
        """Check MXFP8 linear support and read ``group_size`` (default 32)."""
        ensure_mxfp8_linear_available("W8A8_MXFP8 linear quantization")
        quant_description = get_current_vllm_config().quant_config.quant_description
        self.group_size = quant_description.get("group_size", 32)

    def get_weight(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> dict[str, Any]:
        """Return the uninitialized FP8 weight buffer.

        Args:
            input_size: Size of the input (K) dimension.
            output_size: Size of the output (N) dimension.
            params_dtype: Not used here; the weight is always allocated as
                ``torch.float8_e4m3fn``.

        Returns:
            ``{"weight": tensor}`` with shape ``(output_size, input_size)``.
        """
        weight = torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
        return {"weight": weight}

    def get_pergroup_param(
        self, input_size: int, output_size: int, params_dtype: torch.dtype, layer_type: str | None = None
    ) -> dict[str, Any]:
        """Return the uninitialized per-group weight-scale buffer.

        Args:
            input_size: Size of the input (K) dimension.
            output_size: Size of the output (N) dimension.
            params_dtype: Not used here; scales are always ``torch.uint8``.
            layer_type: Not used here.

        Returns:
            ``{"weight_scale": tensor}`` with shape
            ``(output_size, input_size // group_size)``.
        """
        groups_per_row = input_size // self.group_size
        return {"weight_scale": torch.empty(output_size, groups_per_row, dtype=torch.uint8)}

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
        tp_rank: int | None = 0,
    ) -> torch.Tensor:
        """Quantize ``x`` dynamically and run the MXFP8 quantized matmul.

        Args:
            layer: Module providing ``weight`` and ``weight_scale``.
            x: Activation tensor. Inputs with more than two dimensions are
                flattened to 2-D for the kernel and restored afterwards.
            bias: Optional bias; converted to ``torch.float32`` if it has
                another dtype.
            tp_rank: Not used here.

        Returns:
            The matmul result in ``x``'s dtype, with the leading dimensions
            of ``x`` restored when ``x`` had more than two dimensions.
        """
        # Inputs with extra leading dims (e.g. Qwen VL models) are flattened.
        input_shape = x.shape
        if x.dim() > 2:
            x = x.view(-1, x.shape[-1])

        x_fp8, x_scale = torch_npu.npu_dynamic_mx_quant(x, dst_type=torch.float8_e4m3fn)

        if bias is not None and bias.dtype != torch.float32:
            bias = bias.to(torch.float32)

        result = torch_npu.npu_quant_matmul(
            x_fp8,
            layer.weight,
            layer.weight_scale,
            scale_dtype=FLOAT8_E8M0FNU_DTYPE,
            pertoken_scale=x_scale,
            pertoken_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
            bias=bias,
            output_dtype=x.dtype,
            group_sizes=[1, 1, self.group_size],
        )

        if len(input_shape) <= 2:
            return result
        # Put the original leading dims back (Qwen VL models).
        return result.view(*input_shape[:-1], -1)

    def process_weights_after_loading(self, layer):
        """Rearrange the loaded weight and scale tensors in place on ``layer``.

        ``weight_scale`` goes from ``(n, k)`` to ``(k // 2, n, 2)`` (adjacent
        scale pairs are grouped, then the first two dims are swapped), and
        ``weight`` has its first two dims swapped.
        """
        scale = layer.weight_scale.data
        n_dim, k_dim = scale.shape
        paired_scale = scale.reshape(n_dim, k_dim // 2, 2)
        layer.weight.data = layer.weight.data.transpose(0, 1)
        layer.weight_scale.data = paired_scale.transpose(0, 1)
@register_scheme("W8A8_MXFP8", "moe")
class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
    """Fused MoE method for Ascend W8A8_MXFP8 (microscaling FP8) quantization.

    Expert weights are allocated as ``torch.float8_e4m3fn`` with one ``uint8``
    scale per ``group_size`` elements along the reduction dimension. Expert
    selection is done by ``select_experts`` and the expert computation is
    delegated to the MoE communication method found in ``_EXTRA_CTX``.
    """

    model_dtype = None
    quant_type: QuantType = QuantType.MXFP8

    def __init__(self):
        """Check MXFP8 MoE support and cache configuration-derived settings."""
        ensure_mxfp8_moe_available("W8A8_MXFP8 MoE quantization")
        self.ep_group = get_ep_group()
        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get("group_size", 32)
        ascend_config = get_ascend_config()
        # True when the compilation mode is VLLM_COMPILE and eager execution
        # is not enforced.
        self.use_aclgraph = (
            vllm_config.compilation_config.mode == CompilationMode.VLLM_COMPILE
            and not vllm_config.model_config.enforce_eager
        )
        self.dynamic_eplb = ascend_config.eplb_config.dynamic_eplb

    @staticmethod
    def get_weight(
        num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        """Return uninitialized FP8 buffers for the expert weights.

        Args:
            num_experts: Number of experts to allocate.
            intermediate_size_per_partition: Intermediate size of this partition.
            hidden_sizes: Hidden size.
            params_dtype: Not used here; weights are always
                ``torch.float8_e4m3fn``.

        Returns:
            A dict with ``w13_weight`` of shape
            ``(num_experts, 2 * intermediate_size_per_partition, hidden_sizes)``
            and ``w2_weight`` of shape
            ``(num_experts, hidden_sizes, intermediate_size_per_partition)``.
        """
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, hidden_sizes, dtype=torch.float8_e4m3fn
        )
        param_dict["w2_weight"] = torch.empty(
            num_experts, hidden_sizes, intermediate_size_per_partition, dtype=torch.float8_e4m3fn
        )
        return param_dict

    def get_dynamic_quant_param(
        self, num_experts: int, intermediate_size_per_partition: int, hidden_sizes: int, params_dtype: torch.dtype
    ) -> dict[str, Any]:
        """Return uninitialized ``uint8`` buffers for the per-group weight scales.

        Args:
            num_experts: Number of experts to allocate.
            intermediate_size_per_partition: Intermediate size of this partition.
            hidden_sizes: Hidden size.
            params_dtype: Not used here; scales are always ``torch.uint8``.

        Returns:
            A dict with ``w13_weight_scale`` and ``w2_weight_scale``. Each has
            the shape of its weight with the last dimension divided by
            ``group_size``.
        """
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts, 2 * intermediate_size_per_partition, hidden_sizes // self.group_size, dtype=torch.uint8
        )
        param_dict["w2_weight_scale"] = torch.empty(
            num_experts, hidden_sizes, intermediate_size_per_partition // self.group_size, dtype=torch.uint8
        )
        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: torch.Tensor | None = None,
        topk_group: int | None = None,
        num_expert_group: int | None = None,
        custom_routing_function: Callable | None = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: torch.Tensor | None = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: torch.Tensor | None = None,
        global_redundant_expert_num: int = 0,
        pertoken_scale: Any | None = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        mc2_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Select experts for each token and run the MXFP8 fused-experts path.

        Args:
            layer: Module providing ``w13_weight``, ``w2_weight`` and their
                ``*_weight_scale`` tensors.
            x: Input hidden states.
            router_logits: Router scores; its second dimension must equal
                ``global_num_experts - global_redundant_expert_num``.
            top_k, renormalize, use_grouped_topk, topk_group,
            num_expert_group, custom_routing_function, scoring_func,
            e_score_correction_bias: Forwarded to ``select_experts``.
            global_num_experts: Total expert count including redundant experts.
            expert_map, log2phy, mc2_mask, pertoken_scale, activation,
            apply_router_weight_on_input: Forwarded to the fused-experts input.
            global_redundant_expert_num: Number of redundant experts counted
                in ``global_num_experts``.
            enable_force_load_balance: When true, replace the selected expert
                ids with uniformly random ones.
            routed_scaling_factor, is_prefill: Accepted for interface
                compatibility; not used by this method.

        Returns:
            The result of ``moe_comm_method.fused_experts``.

        Raises:
            AssertionError: If ``router_logits`` does not have the expected
                number of expert columns.
        """
        expected = global_num_experts - global_redundant_expert_num
        assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )
        # This is a naive implementation of expert load balancing, used to
        # avoid accumulating too many tokens on a single rank.
        # Currently it is only activated when doing profile runs.
        if enable_force_load_balance:
            # router_logits has exactly `expected` expert columns (asserted
            # above), so the random ids are drawn from [0, expected) to stay
            # in the same id range as real routing; drawing up to
            # global_num_experts could produce ids that real routing never
            # emits when redundant experts are configured.
            topk_ids = torch.randint_like(topk_ids, 0, expected)
        topk_weights = topk_weights.to(x.dtype)
        moe_comm_method = _EXTRA_CTX.moe_comm_method
        return moe_comm_method.fused_experts(
            fused_experts_input=build_fused_experts_input(
                hidden_states=x,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                quant_type=self.quant_type,
                dynamic_eplb=self.dynamic_eplb,
                expert_map=expert_map,
                global_redundant_expert_num=global_redundant_expert_num,
                mc2_mask=mc2_mask,
                apply_router_weight_on_input=apply_router_weight_on_input,
                log2phy=log2phy,
                pertoken_scale=pertoken_scale,
                activation=activation,
                mxfp_act_quant_type=torch.float8_e4m3fn,
                mxfp_weight_quant_type=torch.float8_e4m3fn,
                mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
                mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
                mxfp_use_bf16=(x.dtype == torch.bfloat16),
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
            )
        )

    def process_weights_after_loading(self, layer):
        """Rearrange the loaded expert weights and scales in place on ``layer``.

        Each scale tensor goes from ``(g, n, k)`` to ``(g, k // 2, n, 2)``
        (adjacent scale pairs are grouped, then dims 1 and 2 are swapped).
        Each weight tensor has dims 1 and 2 swapped.
        """
        g_num, n_size, k_size = layer.w13_weight_scale.shape
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
        g_num, n_size, k_size = layer.w2_weight_scale.shape
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.reshape(g_num, n_size, k_size // 2, 2)
        layer.w13_weight.data = layer.w13_weight.data.transpose(1, 2)
        layer.w2_weight.data = layer.w2_weight.data.transpose(1, 2)
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.transpose(1, 2)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.transpose(1, 2)