#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Any, Dict, Optional

import torch
import torch_npu
from vllm.config import get_current_vllm_config

from .base import AscendLinearScheme
from .registry import register_scheme


@register_scheme("W8A8_MXFP8", "linear")
class AscendW8A8MXFP8DynamicLinearMethod(AscendLinearScheme):
    """Linear method for Ascend W8A8_MXFP8 (microscaling FP8) quantization.

    Weights are stored offline in FP8 (E4M3FN) with one shared power-of-two
    (E8M0) scale per group of ``group_size`` elements along the reduction
    dimension. Activations are quantized dynamically at runtime to the same
    FP8 format with per-token, per-group microscales.
    """

    model_dtype = None

    def __init__(self):
        vllm_config = get_current_vllm_config()
        self.group_size = vllm_config.quant_config.quant_description.get(
            "group_size", 32)

    def get_weight(self, input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {
            "weight":
            torch.empty(output_size, input_size, dtype=torch.float8_e4m3fn)
        }
        return params_dict

    def get_pergroup_param(self,
                           input_size: int,
                           output_size: int,
                           params_dtype: torch.dtype,
                           layer_type: Optional[str] = None
                           ) -> Dict[str, Any]:
        # One E8M0 scale (stored as uint8) per `group_size` weights along K.
        params_dict = {
            "weight_scale":
            torch.empty(output_size,
                        input_size // self.group_size,
                        dtype=torch.uint8)
        }
        return params_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        # Dynamically quantize the activation to FP8 with per-group
        # microscales, then run the quantized matmul, which produces the
        # result directly in the activation dtype.
        quantized_x, pertoken_scale = torch_npu.npu_dynamic_mx_quant(
            x, dst_type=torch.float8_e4m3fn)
        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            scale_dtype=torch_npu.float8_e8m0fnu,
            pertoken_scale=pertoken_scale,
            pertoken_scale_dtype=torch_npu.float8_e8m0fnu,
            bias=bias,
            output_dtype=x.dtype,
            group_sizes=[1, 1, self.group_size])
        return output

    def process_weights_after_loading(self, layer):
        # Pack adjacent per-group scales into pairs: (N, K/g) -> (N, K/2g, 2).
        n_dim, k_dim = layer.weight_scale.data.shape
        layer.weight_scale.data = layer.weight_scale.data.reshape(
            n_dim, k_dim // 2, 2)
        # Transpose weight and scales to the K-leading layout consumed by
        # the quantized matmul above.
        layer.weight.data = layer.weight.data.transpose(0, 1)
        layer.weight_scale.data = layer.weight_scale.data.transpose(0, 1)
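

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the execution path above: a plain
# PyTorch emulation of MX (microscaling) quantization, added to make the
# per-group E8M0-scale scheme easier to follow. `_mx_quantize_ref` is a
# hypothetical helper, not a vllm-ascend or torch_npu API; on device the
# equivalent work is done by torch_npu.npu_dynamic_mx_quant.
# ---------------------------------------------------------------------------
def _mx_quantize_ref(x: torch.Tensor,
                     group_size: int = 32
                     ) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference (CPU) emulation of MX quantization for a 2-D tensor.

    Splits the last dimension into groups of ``group_size``, derives one
    power-of-two (E8M0-style) scale per group from the group max, and casts
    the scaled values to float8_e4m3fn. Requires the last dimension to be a
    multiple of ``group_size``.
    """
    rows, cols = x.shape
    groups = x.reshape(rows, cols // group_size, group_size)
    # float8_e4m3fn tops out at 448; choose a power-of-two scale that maps
    # each group's max magnitude into the representable range.
    amax = groups.abs().amax(dim=-1, keepdim=True).clamp(min=2.0**-126)
    scale = torch.exp2(torch.ceil(torch.log2(amax / 448.0)))
    q = (groups / scale).to(torch.float8_e4m3fn)
    return q.reshape(rows, cols), scale.squeeze(-1)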
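

# Companion sketch: dequantize the reference output and sanity-check the
# round trip. `_mx_dequantize_ref` is likewise a hypothetical helper for
# illustration only.
def _mx_dequantize_ref(q: torch.Tensor, scale: torch.Tensor,
                       group_size: int = 32) -> torch.Tensor:
    """Undo the FP8 cast and expand per-group scales back out."""
    rows, cols = q.shape
    groups = q.to(torch.float32).reshape(rows, cols // group_size, group_size)
    return (groups * scale.unsqueeze(-1)).reshape(rows, cols)


if __name__ == "__main__":
    # Round-trip on random data: the relative error should sit at the
    # few-percent level expected of E4M3 values with per-32 power-of-two
    # scales.
    x = torch.randn(4, 128)
    q, s = _mx_quantize_ref(x)
    err = ((x - _mx_dequantize_ref(q, s)).abs().max() / x.abs().max()).item()
    print(f"max relative round-trip error: {err:.4f}")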