[Model] Support DeepSeek-V4

chenxb002
2026-04-24 09:50:34 +08:00
commit b9925203b8
172 changed files with 44780 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from vllm.model_executor.layers.quantization import (
QUANTIZATION_METHODS, register_quantization_config
)
MLU_QUANTIZATION_METHODS = [
"smoothquant",
"weightonly",
"awq_mlu",
"gptq_mlu",
]
def register_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method not in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.append(quant_method)
def remove_fake_mlu_quantization_methods():
for quant_method in MLU_QUANTIZATION_METHODS:
if quant_method in QUANTIZATION_METHODS:
QUANTIZATION_METHODS.remove(quant_method)
def register_real_mlu_quantization_methods():
remove_fake_mlu_quantization_methods()
from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
register_quantization_config("weightonly")(WeightOnlyConfig)
register_quantization_config("smoothquant")(SmoothQuantConfig)
register_quantization_config("awq_mlu")(AWQMluConfig)
register_quantization_config("gptq_mlu")(GPTQMluConfig)

View File

@@ -0,0 +1,412 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# GPTQ and AWQ are supported only on MLU 300-series and later devices, and only at int4 and int8 precision.
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True
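check_mlu_supported returns an (ok, error_message) tuple. For illustration, a hypothetical call site; the argument values here are examples only:

ok, err = check_mlu_supported(scalar_types.uint4, group_size=128, has_zp=True)
if not ok:
    raise ValueError(err)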
# @register_quantization_config("awq_mlu")
class AWQMluConfig(QuantizationConfig):
"""Config class for AWQMlu.
Reference: https://arxiv.org/abs/2306.00978
"""
# num_bits -> type
TYPE_MAP = {
4: {
False: scalar_types.uint4b8,
True: scalar_types.uint4,
},
8: {
False: scalar_types.uint8b128,
True: scalar_types.uint8,
}
}
VERSION = ["gemm"]
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
lm_head_quantized: bool,
version: str = "gemm",
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
self.lm_head_quantized = lm_head_quantized
self.pack_factor = 32 // self.weight_bits
self.version = version
self.support_scale_zeros = False
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is supported for "
f"AWQMlu, but got {self.weight_bits} bits.")
if self.version not in self.VERSION:
raise ValueError(
"Currently, only gemm, gemv version is supported for "
f"AWQMlu, but got verion:{self.version}.")
if self.version in ["gemm"]:
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
else:
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
def __repr__(self) -> str:
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point}), "
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
return "awq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
version = cls.get_from_keys_or(config, ["version"],
default="gemm")
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["AWQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return AWQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None
or user_quant == "awq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} at runtime."
" Using the {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "awq":
logger.info("Detected that the model can run with awq_mlu"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_mlu for"
" faster inference")
return None
@classmethod
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
has_zp = quant_config.get("zero_point", None)
version = quant_config.get("version", "gemm")
if quant_method != "awq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or has_zp is None):
return False
if num_bits not in cls.TYPE_MAP:
return False
if version not in cls.VERSION:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[num_bits][has_zp],
group_size=group_size,
has_zp=has_zp)
class AWQMluLinearMethod(LinearMethodBase):
"""Linear method for AWQMlu.
Args:
quant_config: The AWQMlu quantization config.
"""
def __init__(self, quant_config: AWQMluConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
qzeros = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
scales = GroupQuantScaleParameter(data=torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
input_dim=0,
output_dim=1,
weight_loader=weight_loader)
layer.register_parameter("qweight", qweight)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
packed_qweight, scale_zeros = self.extract_autoawq(layer)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autoawq(self, layer: torch.nn.Module):
qweight = layer.qweight.data
qzeros = layer.qzeros.data
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
if izeros is not None:
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device)
dtype = torch.int16 if bits == 8 else torch.int8
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
iweights = iweights.view(iweights.shape[0], -1)
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
# unpacking columnwise
if qzeros is not None:
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
izeros = izeros.view(izeros.shape[0], -1)
if not self.quant_config.zero_point:
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
else:
izeros = None
return iweights, izeros
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
reverse_order_tensor = reverse_order_tensor.view(-1)
rweights = iweights[:, reverse_order_tensor]
if izeros is not None:
rzeros = izeros[:, reverse_order_tensor]
else:
rzeros = None
return rweights, rzeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# Ensure both inputs are int8
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# Extract the low 4 bits of each tensor
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the low 4 bits of tensor_a
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the low 4 bits of tensor_b
# Shift tensor_a's low 4 bits into the high nibble
shifted_low_bits_a = low_bits_a << 4
# Combine the two nibbles into one byte
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined
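A small worked example of the nibble packing performed by combine_low_bits, assuming int8 inputs that already hold 4-bit values:

a = torch.tensor([3], dtype=torch.int8)   # low nibble 0b0011
b = torch.tensor([5], dtype=torch.int8)   # low nibble 0b0101
packed = torch.bitwise_or(torch.bitwise_and(a, 0x0F) << 4, torch.bitwise_and(b, 0x0F))
assert packed.item() == 0x35  # 53: a lands in the high nibble, b in the low nibble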

View File

@@ -0,0 +1,753 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
ACTIVATION_SCHEMES,
Fp8Config,
Fp8LinearMethod,
Fp8MoeBackend,
Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
FlashinferMoeBackend,
flashinfer_cutlass_moe_fp8,
get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
validate_fp8_block_shape
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
maybe_create_device_identity, Fp8LinearOp)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter, PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.fused_moe import FusedMoE  # used by the eplb assert in the MoE apply hijack
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops
logger = init_logger(__name__)
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
"""
Select the primary FP8 MoE backend
Note: Shape-specific fallbacks may still occur at runtime.
"""
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
if (
current_platform.is_cuda()
and (
current_platform.is_device_capability(100)
or current_platform.is_device_capability(90)
)
and envs.VLLM_USE_FLASHINFER_MOE_FP8
and has_flashinfer_moe()
):
backend = get_flashinfer_moe_backend()
if backend == FlashinferMoeBackend.TENSORRT_LLM:
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
return Fp8MoeBackend.FLASHINFER_TRTLLM
else:
if block_quant and current_platform.is_device_capability(100):
raise ValueError(
"FlashInfer FP8 MoE throughput backend does not "
"support block quantization. Please use "
"VLLM_FLASHINFER_MOE_BACKEND=latency "
"instead."
)
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
return Fp8MoeBackend.FLASHINFER_CUTLASS
# weight-only path for older GPUs without native FP8
use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: disable marlin for MLU backend.
'''
if current_platform.is_rocm() or current_platform.is_out_of_tree():
use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if use_marlin:
logger.info_once("Using Marlin backend for FP8 MoE")
return Fp8MoeBackend.MARLIN
# deepGEMM on supported platforms with block-quantized weights
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
if not has_deep_gemm():
logger.warning_once("DeepGEMM backend requested but not available.")
elif is_deep_gemm_supported():
logger.info_once("Using DeepGEMM backend for FP8 MoE")
return Fp8MoeBackend.DEEPGEMM
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
if (
current_platform.is_cuda()
and current_platform.is_device_capability(100)
and block_quant
):
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
# default to Triton
logger.info_once("Using Triton backend for FP8 MoE")
return Fp8MoeBackend.TRITON
Fp8Config____init____org = Fp8Config.__init__
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: list[str] | None = None,
weight_block_size: list[int] | None = None,
activation_quant_method: Optional[str] = None,
weight_quant_method: Optional[str] = None,
) -> None:
super(Fp8Config, self).__init__()
Fp8Config____init____org(
self,
is_checkpoint_fp8_serialized,
activation_scheme,
ignored_layers,
weight_block_size
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add class members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.activation_quant_method = activation_quant_method
self.weight_quant_method = weight_quant_method
assert (self.weight_block_size
or (self.activation_quant_method == "per_token"
and self.weight_quant_method == "per_channel"
and self.activation_scheme == "dynamic")), (
"Only block-wise quantization, or dynamic per-token activation "
"with per-channel weight quantization, is supported.")
'''
==================
End of MLU Hijack
==================
'''
@classmethod
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
cls, config: Dict[str, Any]
) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = "fp8" in quant_method
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
if not ignored_layers:
ignored_layers = cls.get_from_keys_or(
config, ["modules_to_not_convert"], None
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
activation_quant_method = cls.get_from_keys_or(config,
["activation_quant_method"],
'per_token')
weight_quant_method = cls.get_from_keys_or(config,
["weight_quant_method"],
None)
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme,
ignored_layers=ignored_layers,
weight_block_size=weight_block_size,
activation_quant_method=activation_quant_method,
weight_quant_method=weight_quant_method)
'''
==================
End of MLU Hijack
==================
'''
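For illustration, a quantization config fragment that the hijacked from_config would accept; the values are hypothetical and only show how the two new keys select the per-token/per-channel FP8 path when weight_block_size is absent:

hf_quant_cfg = {
    "quant_method": "fp8",
    "activation_scheme": "dynamic",
    "activation_quant_method": "per_token",
    "weight_quant_method": "per_channel",
}
fp8_cfg = Fp8Config.from_config(hf_quant_cfg)  # weight_block_size defaults to None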
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
maybe_create_device_identity()
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
layer.weight_block_size = None
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
tp_group = extra_weight_attrs.get("tp_group", None)
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert self.weight_block_size is not None
layer.weight_block_size = self.weight_block_size
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition,
output_partition_sizes,
self.weight_block_size,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: add tp_group.
'''
# WEIGHT
if self.quant_config.is_checkpoint_fp8_serialized:
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
)
else:
# For non-serialized checkpoints, use original dtype
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
'''
==================
End of MLU Hijack
==================
'''
layer.register_parameter("weight", weight)
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
if not self.block_quant:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Support weight per channel quantization.
@brief: Add tp_group to enable custom split.
'''
if self.weight_per_channel:
scale = ChannelQuantScaleParameter(
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader,
tp_group=tp_group,
)
else:
scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes),
dtype=torch.float32),
weight_loader=weight_loader,
)
scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale)
'''
==================
End of MLU Hijack
==================
'''
else:
assert not self.act_q_static
assert self.weight_block_size is not None
scale = create_fp8_scale_parameter(
BlockQuantScaleParameter,
output_partition_sizes,
input_size_per_partition,
self.weight_block_size,
weight_loader,
)
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale)
# INPUT ACTIVATION SCALE
if self.act_q_static:
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
set_weight_attrs(scale, {"scale_type": "input_scale"})
layer.register_parameter("input_scale", scale)
else:
layer.register_parameter("input_scale", None)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
self,
quant_config: Fp8Config
):
self.quant_config = quant_config
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
self.out_dtype = torch.get_default_dtype()
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
self.use_marlin = (
not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if current_platform.is_rocm():
self.use_marlin = False
if vllm_is_batch_invariant():
self.use_marlin = False
# AITER is only supported on ROCm and only for FP8_FNUZ
# and at the moment are MI300 series
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
self.use_deep_gemm = is_deep_gemm_supported()
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant = self.weight_block_size is not None
if self.block_quant:
# Marlin doesn't support block-wise fp8
self.use_marlin = False
self.act_q_static = self.quant_config.activation_scheme == "static"
if self.weight_block_size:
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
else:
# Use per-token quantization for better perf if dynamic and cutlass
if not self.act_q_static and cutlass_fp8_supported():
self.act_q_group_shape = GroupShape.PER_TOKEN
else:
self.act_q_group_shape = GroupShape.PER_TENSOR
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add config members activation_quant_method and weight_quant_method to
indicate the granularity of quantization.
'''
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
if self.weight_per_channel and self.activation_per_token:
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
if self.block_quant:
assert not self.act_q_static
assert self.weight_block_size is not None
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
weight_group_shape=GroupShape(*self.weight_block_size),
act_quant_group_shape=self.act_q_group_shape,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
)
else:
self.fp8_linear = Fp8LinearOp(
act_quant_static=self.act_q_static,
act_quant_group_shape=self.act_q_group_shape,
)
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
self,
layer: Module,
) -> None:
'''
=============================
Modify by vllm_mlu
=============================
@brief: For dynamic activation and channel-wise weight quantization,
additional processing is not needed.
'''
if (self.quant_config.is_checkpoint_fp8_serialized
and self.weight_per_channel
and self.quant_config.activation_scheme == "dynamic"):
return
'''
==================
End of MLU Hijack
==================
'''
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert residual is None, "Fp8Linear residual is not supported yet."
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if vllm_is_batch_invariant():
if self.block_quant:
assert self.weight_block_size is not None
return self.w8a8_block_fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
else:
# per-tensor/channel: dequant to BF16 and run GEMM
weight_fp8 = layer.weight.to(torch.bfloat16)
weight_scale = layer.weight_scale.to(torch.bfloat16)
if weight_scale.numel() == 1:
# Per-tensor: simple scalar multiplication
weight_bf16 = weight_fp8 * weight_scale
else:
# Multiple scales (fused modules like QKV)
# Try to infer correct broadcasting
# weight is [K, N], scale could be [num_logical_weights]
# Need to figure out how to broadcast - for now just try
# direct multiplication
if (
weight_scale.dim() == 1
and weight_scale.shape[0] == weight_fp8.shape[0]
):
# Per-row scaling
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
else:
# Fallback
weight_bf16 = weight_fp8 * weight_scale
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
if self.use_marlin:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
if self.block_quant:
assert self.weight_block_size is not None
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
apply_w8a8_block_fp8_linear)
return apply_w8a8_block_fp8_linear(
input=x,
weight=layer.weight,
block_size=self.quant_config.weight_block_size,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use activation per token quantization based on quantization config.
'''
return self.fp8_linear.apply(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
weight_per_channel=self.weight_per_channel,
activation_per_token=self.activation_per_token)
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
self,
quant_config: Fp8Config,
layer: torch.nn.Module
):
super(Fp8MoEMethod, self).__init__(layer.moe_config)
self.layer = layer
self.quant_config = quant_config
self.weight_block_size = self.quant_config.weight_block_size
self.block_quant: bool = self.weight_block_size is not None
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
if self.block_quant:
assert self.weight_block_size == [128, 128], (
f"Only support weight_block_size == [128, 128], "
f"got {self.weight_block_size}"
)
self.flashinfer_moe_fn = partial(
flashinfer_cutlass_moe_fp8,
moe=self.moe,
use_deepseek_fp8_block_scale=self.block_quant,
)
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
self.allow_cutlass_block_scaled_grouped_gemm = (
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
)
'''
=============================
Modify by vllm_mlu
=============================
@brief: In mlu, always set self.use_marlin as False.
'''
self.use_marlin = False
'''
==================
End of MLU Hijack
==================
'''
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: int | None = None,
num_expert_group: int | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)
'''
=============================
Modify by vllm_mlu
=============================
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
'''
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
if scoring_func == "softmax":
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
route_scale=routed_scaling_factor,
)
elif scoring_func == "sigmoid":
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
routed_scaling_factor,
e_score_correction_bias,
)
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")
# gen_idx
ori_input_shape = x.shape
x = x.reshape(-1, x.size(-1))
router_logits = router_logits.reshape(-1, router_logits.size(-1))
expert_num = router_logits.size(-1)
tokens_num = x.size(0)
expert_size = layer.w13_weight.size(0)
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
topk_ids, expert_num
)
expand_hidden_states = mlu_ops.moe_expand_input(
x, expand_idx, cumsum_token_count, 0, expert_size
)
quant_input, input_scale = _fp8_quantize(
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm1_out = mlu_ops.smooth_quant_group_gemm(
quant_input,
layer.w13_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=input_scale.T.contiguous(),
b_scale=layer.w13_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
act_out_quantize, act_out_scale = _fp8_quantize(
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
)
gemm2_out = mlu_ops.smooth_quant_group_gemm(
act_out_quantize,
layer.w2_weight,
token_count,
expand_idx=None,
c=None,
alpha=None,
beta=None,
a_scale=act_out_scale.T.contiguous(),
b_scale=layer.w2_weight_scale_inv,
dtype=x.dtype,
max_m=tokens_num,
)
output = mlu_ops.moe_combine_result(
gemm2_out,
topk_weights,
combine_idx,
residual=None,
cusum_token_count=cumsum_token_count,
start_expert_id=0,
expert_size=expert_size,
bias=None,
)
return output.view(ori_input_shape)
"""
==================
End of MLU Hijack
==================
"""
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
)
MluHijackObject.apply_hijack(
Fp8Config,
Fp8Config.from_config,
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.create_weights,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
)
MluHijackObject.apply_hijack(
Fp8LinearMethod,
Fp8LinearMethod.process_weights_after_loading,
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.__init__,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
)
MluHijackObject.apply_hijack(
Fp8MoEMethod,
Fp8MoEMethod.apply,
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
)
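MluHijackObject is imported from vllm_mlu.mlu_hijack_utils and its implementation is not part of this diff. A minimal sketch of what apply_hijack plausibly does, assuming it simply rebinds the attribute (the real helper presumably also records the original so hijacks can be reverted):

class _HijackSketch:  # hypothetical stand-in for MluHijackObject
    @staticmethod
    def apply_hijack(target_cls, org_fn, new_fn):
        # Rebind the method on the target class so existing call sites,
        # e.g. Fp8LinearMethod.apply, resolve to the MLU replacement.
        setattr(target_cls, org_fn.__name__, new_fn)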

View File

@@ -0,0 +1,440 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from fractions import Fraction
from typing import Any, Dict, List, Optional, Tuple
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
# GPTQ and AWQ are supported only on MLU 300-series and later devices, and only at int4 and int8 precision.
def query_mlu_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
if has_zp:
# AWQ style, unsigned + zero-point
return [scalar_types.uint4, scalar_types.uint8]
else:
# GPTQ style, unsigned + symmetric bias
return [scalar_types.uint4b8, scalar_types.uint8b128]
def check_mlu_supported(
quant_type: ScalarType,
group_size: Optional[int],
has_zp: bool,
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
supported_types = query_mlu_supported_quant_types(
has_zp, device_capability)
if quant_type not in supported_types:
return (False, f"Mlu does not support weight_bits = {quant_type}. "
f"Only types = {supported_types} "
f"are supported (for group_size = {group_size}, "
f"device_capability = {device_capability}, zp = {has_zp}).")
if (group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES):
return (False, f"Mlu does not support group_size = {group_size}. "
f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
"are supported.")
return True
# @register_quantization_config("gptq_mlu")
class GPTQMluConfig(QuantizationConfig):
"""Config class for GPTQMlu.
Reference: https://arxiv.org/abs/2210.17323
"""
# (num_bits, is_sym) -> quant_type
TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
(4, False): scalar_types.uint4b8,
(8, False): scalar_types.uint8b128,
}
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
is_sym: bool,
lm_head_quantized: bool,
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
self.lm_head_quantized = lm_head_quantized
self.pack_factor = Fraction(32, self.weight_bits)
self.support_scale_zeros = False
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
if self.weight_bits not in [4, 8]:
raise ValueError(
"Currently, only 4/8-bit weight quantization is "
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}),"
f"lm_head_quantized={self.lm_head_quantized}")
@classmethod
def get_name(cls) -> str:
return "gptq_mlu"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16, torch.float32]
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quant_config.json", "quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMluLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMluLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
@classmethod
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
if (num_bits, sym) not in cls.TYPE_MAP:
return False
return check_mlu_supported(quant_type=cls.TYPE_MAP[(num_bits, sym)],
group_size=group_size, has_zp=False)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_mlu")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
class GPTQMluLinearMethod(LinearMethodBase):
"""Linear method for GPTQMlu.
Args:
quant_config: The GPTQMlu quantization config.
"""
def __init__(self, quant_config: GPTQMluConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs.get("weight_loader")
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
-1) and (not self.quant_config.desc_act):
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
qzeros_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
}
weight_scale_args = {
"data":
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.device = layer.qweight.data.device
packed_qweight, scale_zeros = self.extract_autogptq(layer)
if self.quant_config.use_native:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
layer.qzeros = None
layer.scales = None
else:
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
if scale_zeros is not None:
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
else:
layer.qzeros = None
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.quant_config.use_native:
output = mlu_ops.matmul(x, layer.qweight, bias)
if residual is not None:
output = output + residual
else:
output = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
layer.qzeros,
bias,
residual,
"none",
self.quant_config.weight_bits)
return output
def extract_autogptq(self, layer: torch.nn.Module):
scales = layer.scales.data
bits = self.quant_config.weight_bits
group_size = self.quant_config.group_size
# Unpack the qweight and qzeros tensors
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
if self.quant_config.use_native:
if self.quant_config.desc_act:
scales = torch.index_select(scales, 0, layer.g_idx)
if izeros is not None:
izeros = torch.index_select(izeros, 0, layer.g_idx)
else:
scales = scales.repeat_interleave(group_size, dim=0)
if izeros is not None:
izeros = izeros.repeat_interleave(group_size, dim=0)
if izeros is not None:
fweight = (iweight - izeros) * scales
else:
fweight = iweight * scales
# transpose [ci, co] -> [co, ci]
fweight = fweight.transpose(0, 1)
return fweight, None
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
scale_zeros = izeros.to(scales.dtype) * -1 * scales
# transpose [ci, co] -> [co, ci]
scale_zeros = scale_zeros.transpose(0, 1)
else:
# is_sym holds in this branch, so shift iweight to its signed representation and ignore qzeros
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
scale_zeros = None
# transpose [ci, co] -> [co, ci]
iweight = iweight.to(torch.int8).transpose(0, 1)
if bits == 4:
higher_bit_tensor = iweight[:, 1::2]
lower_bit_tensor = iweight[:, 0::2]
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
else:
packed_qweight = iweight
return packed_qweight, scale_zeros
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
weight = torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
shifts.unsqueeze(-1),
).to(dtype)
weight = torch.bitwise_and(weight, (2**bits) - 1)
weight = weight.reshape(-1, weight.shape[-1])
return weight
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
dtype = torch.int16 if bits == 8 else torch.int8
zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
shifts.unsqueeze(0),
).to(dtype)
zeros = zeros + 1
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
zeros = zeros.reshape(qzeros.shape[0], -1)
return zeros
def combine_low_bits(self, tensor_a, tensor_b):
"""
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
Args:
tensor_a (torch.Tensor): First tensor of type int8.
tensor_b (torch.Tensor): Second tensor of type int8.
Returns:
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
"""
# Ensure both inputs are int8
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
raise ValueError("Both tensors must be of int8 type.")
# Extract the low 4 bits of each tensor
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the low 4 bits of tensor_a
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the low 4 bits of tensor_b
# Shift tensor_a's low 4 bits into the high nibble
shifted_low_bits_a = low_bits_a << 4
# Combine the two nibbles into one byte
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
return combined
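A worked example of the qzeros unpacking above with bits=4; the +1 matches the AutoGPTQ convention this unpack assumes, where checkpoints store zero-points minus one:

qzeros = torch.tensor([[0x76543210]], dtype=torch.int32)  # eight 4-bit lanes holding 0..7
shifts = torch.arange(0, 32, 4)
zeros = (qzeros.unsqueeze(2) >> shifts).to(torch.int8)
zeros = torch.bitwise_and(zeros + 1, 0xF).reshape(1, -1)
# -> tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.int8)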

View File

@@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
RowvLLMParameter)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
str_dtype_to_bits,
is_fp8_str_dtype)
# @register_quantization_config("smoothquant")
class SmoothQuantConfig(QuantizationConfig):
"""Config class for SmoothQuant.
"""
def __init__(
self,
quant_mode: str, # smoothquant
input_quant_method: str, # per token/per tensor
group_size: int,
weight_precision: str,
activation_precision: str,
only_expert_per_group: bool,
expert_weight_precision: str,
expert_activation_precision: str,
force_use_weightonly_except_expert: bool,
) -> None:
super().__init__()
self.quant_mode = quant_mode
self.input_quant_method = input_quant_method
self.group_size = group_size
self.weight_precision = weight_precision
self.activation_precision = activation_precision
self.only_expert_per_group = only_expert_per_group
self.expert_weight_precision = expert_weight_precision
self.expert_activation_precision = expert_activation_precision
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
self.weight_bits = str_dtype_to_bits(self.weight_precision)
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
if self.weight_precision == 'int4':
self.weight_dtype = torch.int8
else:
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
if self.expert_weight_precision == 'int4':
self.expert_weight_dtype = torch.int8
else:
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
self.pack_factor = 8 // self.weight_bits
self.expert_pack_factor = 8 // self.expert_weight_bits
def __repr__(self) -> str:
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
f"quant_mode={self.quant_mode}, "
f"group_size={self.group_size}, "
f"weight_precision={self.weight_precision}, "
f"activation_precision={self.activation_precision}, "
f"only_expert_per_group={self.only_expert_per_group}, "
f"expert_weight_precision={self.expert_weight_precision}, "
f"expert_activation_precision={self.expert_activation_precision}, "
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
@classmethod
def get_name(cls) -> str:
return "SmoothQuant"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
quant_mode = cls.get_from_keys(config, ["quant_mode"])
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
if expert_weight_precision is None:
expert_weight_precision = weight_precision
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
weight_precision = 'int8'
if expert_activation_precision is None:
expert_activation_precision = activation_precision
return cls(quant_mode=quant_mode,
input_quant_method=input_quant_method,
group_size=group_size,
weight_precision=weight_precision,
activation_precision=activation_precision,
only_expert_per_group=only_expert_per_group,
expert_weight_precision=expert_weight_precision,
expert_activation_precision=expert_activation_precision,
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
if isinstance(layer, LinearBase):
return SmoothQuantLinearMethod(self, prefix)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class SmoothQuantLinearMethod(LinearMethodBase):
"""Linear method for SmoothQuant.
Args:
quant_config: The SmoothQuant quantization config.
"""
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
self.quant_config = quant_config
# For the per-tensor case, we can skip quantizing the input of the first attn/ffn linear
# and fuse that step into the preceding layernorm for better performance.
self.skip_quant_input = False
self.compute_dtype = torch.get_default_dtype()
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
self.is_group_quant = True
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
self.is_group_quant = True
else:
self.is_group_quant = False
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor != 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
group_num = 1
if self.is_group_quant:
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"The input size {input_size_per_partition} is not aligned with the quantized "
f"weight shape. This can be caused by too large "
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
if input_size_per_partition != input_size:
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
qweight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.pack_factor,
device="mlu",
dtype=self.weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.is_group_quant:
per_channel_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
group_num,
device="mlu",
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
per_channel_scale = ChannelQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
device="mlu",
dtype=torch.float32,
),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("per_channel_scale", per_channel_scale)
if self.has_smooth:
smooth = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(smooth, {
"ignore_warning": True,
})
layer.register_parameter("smooth", smooth)
if self.quant_config.input_quant_method == "per_tensor":
scale_to_int = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(scale_to_int, {
"ignore_warning": True,
})
layer.register_parameter("scale_to_int", scale_to_int)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.has_smooth and layer.smooth.dtype != torch.float:
layer.smooth = layer.smooth.to(torch.float)
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
layer.scale_to_int = layer.scale_to_int.to(torch.float)
if layer.per_channel_scale.dtype != torch.float:
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
if self.has_smooth:
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
if self.quant_config.input_quant_method == "per_tensor":
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
              use_tp_weight: bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
layer_smooth = layer.smooth if self.has_smooth else None
layer_qweight = layer.qweight
layer_per_channel_scale = layer.per_channel_scale
if use_tp_weight:
if hasattr(layer, 'tp_smooth'):
layer_smooth = layer.tp_smooth
if hasattr(layer, 'tp_qweight'):
layer_qweight = layer.tp_qweight
if hasattr(layer, 'tp_per_channel_scale'):
layer_per_channel_scale = layer.tp_per_channel_scale
quant_input = None
if self.skip_quant_input:
quant_input = x
elif self.quant_config.input_quant_method == "per_token":
if self.is_fp8:
quant_input, input_scale = mlu_ops.scaled_quantize(x,
layer_smooth,
quant_type=self.weight_dtype,
quant_mode='dynamic_per_token')
else:
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
elif self.quant_config.input_quant_method == "per_tensor":
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
        else:
            raise ValueError(
                "Currently, only per_token or per_tensor input quantization "
                "is supported for SmoothQuant, but got "
                f"{self.quant_config.input_quant_method}.")
quant_input_shape = quant_input.shape
if len(quant_input_shape) > 2:
quant_input = quant_input.view(-1, quant_input_shape[-1])
input_scale = input_scale.view(-1)
if residual is not None and len(residual.shape) > 2:
residual = residual.view(-1, residual.shape[-1])
if self.is_fp8:
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
layer_per_channel_scale,
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
bias,
                                        c=residual, act_mode="none", quant_bit_size=8,
alpha=1.0, beta=1.0, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
if output is not None:
out = out.view(output.shape)
output.copy_(out)
out = output
else:
if output is not None:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual)
if len(quant_input_shape) > 2:
out = out.view(*quant_input_shape[:-1], out.shape[-1])
return out
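
# Illustrative reference (not part of the MLU kernel API): a minimal unfused
# sketch of the per-token SmoothQuant path above. The fused
# `mlu_ops.per_token_smooth_quantize` + `mlu_ops.smooth_quant_matmul` pair
# computes the equivalent arithmetic on device; the names below are
# assumptions for illustration only.
def _smooth_quant_reference(x: torch.Tensor,
                            smooth: torch.Tensor,
                            qweight: torch.Tensor,
                            per_channel_scale: torch.Tensor,
                            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    x_s = x.float() * smooth                               # apply smoothing vector
    input_scale = x_s.abs().amax(dim=-1) / 127.0           # dynamic per-token scale
    q = torch.clamp(torch.round(x_s / input_scale[:, None]),
                    -128, 127).to(torch.int8)              # per-token int8 quant
    acc = q.to(torch.int32) @ qweight.to(torch.int32).t()  # int8 GEMM in int32
    out = acc.float() * input_scale[:, None] * per_channel_scale
    return (out + bias if bias is not None else out).to(x.dtype)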

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

View File

@@ -0,0 +1,111 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import torch
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
INTEGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
                  torch.int32, torch.int, torch.int64, torch.long]
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
GEMM_GROUP_SIZE = [64, 128, 256, 512]
_STR_TO_TORCH_DTYPE_DICT = dict(
bfloat16=torch.bfloat16,
float16=torch.float16,
float32=torch.float32,
int64=torch.int64,
int32=torch.int32,
int8=torch.int8,
bool=torch.bool,
e4m3fn=torch.float8_e4m3fn,
e4m3fnuz=torch.float8_e4m3fnuz,
e5m2=torch.float8_e5m2,
e5m2fnuz=torch.float8_e5m2fnuz,
)
TORCH_DTYPE_TO_STR_DICT = {
torch.bfloat16: "bfloat16",
torch.float16: "float16",
torch.float32: "float32",
torch.int64: "int64",
torch.int32: "int32",
torch.int8: "int8",
torch.bool: "bool",
torch.float8_e4m3fn: "e4m3fn",
torch.float8_e4m3fnuz: "e4m3fnuz",
torch.float8_e5m2: "e5m2",
torch.float8_e5m2fnuz: "e5m2fnuz",
}
STR_DTYPE_TO_BITS_DICT = {
"bfloat16": 16,
"float16": 16,
"float32": 32,
"int64": 64,
"int32": 32,
"int8": 8,
'int4': 4,
"bool": 1,
"e4m3fn": 8,
"e4m3fnuz": 8,
"e5m2": 8,
"e5m2fnuz": 8,
}
def str_dtype_to_torch(str_dtype: str):
    '''
    Convert a str dtype to a torch dtype (defaults to torch.float16).
    '''
ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype)
dtype = ret if ret is not None else torch.float16
return dtype
def torch_dtype_to_str(dtype: torch.dtype):
    '''
    Convert a torch dtype to a str dtype (defaults to "float16").
    '''
ret = TORCH_DTYPE_TO_STR_DICT.get(dtype)
str_dtype = ret if ret is not None else "float16"
return str_dtype
def str_dtype_to_bits(str_dtype):
    '''
    Convert a str dtype to its bit width (defaults to 8).
    '''
ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype)
bits = ret if ret is not None else 8
return bits
def is_integer_dtype(dtype: torch.dtype):
    '''
    Check whether the dtype is an integer torch dtype.
    '''
    return dtype in INTEGER_DTYPES
def is_float_dtype(dtype: torch.dtype):
    '''
    Check whether the dtype is a floating-point torch dtype.
    '''
return dtype in FLOAT_DTYPES
def is_fp8_dtype(dtype: torch.dtype):
    '''
    Check whether the dtype is an fp8 torch dtype.
    '''
return dtype in FP8_DTYPE
def is_fp8_str_dtype(str_dtype: str):
    '''
    Check whether the str dtype is an fp8 str dtype.
    '''
return str_dtype in FP8_STR_DTYPE
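
# Quick usage sketch of the helpers above (illustrative; the fallback
# defaults shown are exactly the ones encoded in the dicts above).
def _dtype_utils_examples():
    assert str_dtype_to_torch("e4m3fn") is torch.float8_e4m3fn
    assert torch_dtype_to_str(torch.bfloat16) == "bfloat16"
    # Bit widths drive pack factors: 8 // str_dtype_to_bits("int4") == 2
    # int4 values per int8 byte.
    assert str_dtype_to_bits("int4") == 4
    # Unknown strings fall back to the defaults instead of raising.
    assert str_dtype_to_torch("no_such_dtype") is torch.float16
    assert is_fp8_dtype(torch.float8_e5m2) and not is_fp8_dtype(torch.int8)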

View File

@@ -0,0 +1,424 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/2575
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.logger import init_logger
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_BLOCK_FP8_SUPPORTED)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_per_token_group_quant_fp8_colmajor)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
logger = init_logger(__name__)
'''
=============================
Modify by vllm_mlu
=============================
@brief: get the total core count used to split triton kernels
'''
import triton.backends.mlu.driver as driver
_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")
'''
==================
End of MLU Hijack
==================
'''
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
weight: torch.Tensor,
block_size: List[int],
weight_scale: torch.Tensor,
input_scale: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
use_aiter_and_is_supported: bool = False,
) -> torch.Tensor:
assert input_scale is None
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[0]]
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
and weight.shape[1] % 128 == 0)
if current_platform.is_rocm():
# TODO this is never used, as cutlass_block_fp8_supported is False
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
input_2d.shape[:-1])[::-1]
scale_b_shape = (weight_scale.view(-1, 1)
if weight_scale.dim() <= 1 else weight_scale.T).shape
ar, ac = scale_a_shape
br, bc = scale_b_shape
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
or br not in (1, weight.shape[0])):
shape_supported_by_cutlass = False
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=True)
output = ops.cutlass_scaled_mm(q_input,
weight.T,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale.T)
else:
q_input, x_scale = per_token_group_quant_fp8(input_2d,
block_size[1],
column_major_scales=False)
output = w8a8_block_fp8_matmul(q_input,
weight,
x_scale,
weight_scale,
block_size,
output_dtype=input.dtype)
if bias is not None:
output = output + bias
return output.to(dtype=input.dtype).view(*output_shape)
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
eps: float = 1e-10,
dtype: Optional[torch.dtype] = None,
column_major_scales: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
        eps: The minimum value used to avoid division by zero.
        dtype: The dtype of the output tensor. Note that only
            `torch.float8_e4m3fn` is supported for now.
Returns:
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype = current_platform.fp8_dtype() if dtype is None else dtype
assert (x.shape[-1] % group_size == 0), (
f"the last dimension of `x` {x.shape[-1]} must be divisible "
f"by `group_size` {group_size}")
assert x.stride(-1) == 1, "`x` groups must be contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
'''
=============================
Modify by vllm_mlu
=============================
@brief: split so the launch dimension M stays under the 65536 limit
'''
group_per_block = 1
while M >= 65536:
group_per_block *= 2
M = x.numel() // (group_size * group_per_block)
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
x_s = torch.empty(shape, device=x.device,
dtype=torch.float32).permute(-1, -2)
else:
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
'''
=============================
Modify by vllm_mlu
=============================
@brief: set num_warps to 1 for triton-mlu
'''
num_warps = 1
num_stages = 1
'''
==================
End of MLU Hijack
==================
'''
if column_major_scales:
_per_token_group_quant_fp8_colmajor[(M, )](
x,
x_q,
x_s,
group_size,
x.shape[1],
x.stride(0),
x_s.stride(1),
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
'''
=============================
Modify by vllm_mlu
=============================
@brief: replaced the '_per_token_group_quant_fp8' triton kernel with the
        'scaled_quantize' kernel from the 'tmo' library
'''
        # Ensure x is contiguous; make a contiguous copy if necessary.
if not x.is_contiguous():
x = x.contiguous()
x_origin_shape = x.shape
x = x.reshape(*x.shape[:-1], -1, group_size)
x_q, x_s = mlu_ops.scaled_quantize(x,
None,
quant_type=dtype,
quant_mode='dynamic_per_token')
x_q = x_q.reshape(x_origin_shape)
'''
==================
End of MLU Hijack
==================
'''
return x_q, x_s
@triton.jit
def _w8a8_block_fp8_matmul(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_As_m,
stride_As_k,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
'''
=============================
Modify by vllm_mlu
=============================
@brief: distribute output blocks across a fixed grid of cores to stay
        under the 65536 launch limit
'''
num_block_size_all = num_pid_m * num_pid_n
num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)
core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)
for pid_i in range(0, core_deal_num_block_size):
pid_in_core_deal_block = core_deal_num_block_start + pid_i
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_in_core_deal_block // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m
'''
==================
End of MLU Hijack
==================
'''
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
As: torch.Tensor,
Bs: torch.Tensor,
block_size: List[int],
output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
        output_dtype: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
'''
=============================
Modify by vllm_mlu
=============================
@brief: replaced the '_w8a8_block_fp8_matmul' triton kernel with the
        'scaled_matmul' kernel from the 'tmo' library when shapes allow
'''
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert B.ndim == 2 and Bs.ndim == 2
if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
else:
        # NOTE(wulingchao): the underlying scaled_matmul kernel only supports
        # n and k that are multiples of 128.
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype)
# Default config
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 1,
}
def grid(META):
return (TOTAL_CORE_NUM, )
_w8a8_block_fp8_matmul[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
'''
==================
End of MLU Hijack
==================
'''
return C
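
# Illustrative unfused reference for w8a8_block_fp8_matmul above. Layout
# assumptions (matching the asserts in the function): A is [M, K] with
# per-token-group scales As of shape [M, K // block_k]; B is [N, K] with
# per-block scales Bs of shape [cdiv(N, block_n), cdiv(K, block_k)].
def _block_fp8_matmul_reference(A, B, As, Bs, block_size,
                                output_dtype=torch.float16):
    block_n, block_k = block_size
    N, K = B.shape
    # Expand per-group activation scales back to element granularity.
    a_deq = A.float() * As.float().repeat_interleave(block_k, dim=-1)[:, :K]
    # Expand per-(block_n x block_k) weight scales likewise.
    b_scale = Bs.float().repeat_interleave(block_n, dim=0)[:N]
    b_scale = b_scale.repeat_interleave(block_k, dim=1)[:, :K]
    # Dequantize, then matmul in fp32 and cast to the requested dtype.
    return (a_deq @ (B.float() * b_scale).t()).to(output_dtype)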

View File

@@ -0,0 +1,178 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
torch_channelwise_w8a8_scaled_mm)
from vllm.platforms import current_platform
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
def mlu_w8a8_scaled_mm(
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
output_shape: list, **kwargs
) -> torch.Tensor:
output = mlu_ops.scaled_matmul(
qinput, # a
weight, # b
scale_a, # a_scale
scale_b, # b_scale
out_dtype, # output_dtype
bias, # bias
        c=None, act_mode="none", quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None
)
return output.view(*output_shape)
def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
weight_per_channel: bool, activation_per_token: bool
) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm
if preferred_backend == "flashinfer":
return flashinfer_w8a8_scaled_mm
if preferred_backend == "cutlass":
return cutlass_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if (
not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
):
return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
'''
=============================
Modify by vllm_mlu
=============================
@brief: dispatch to mlu_w8a8_scaled_mm
'''
if weight_per_channel and activation_per_token:
return mlu_w8a8_scaled_mm
'''
==================
End of MLU Hijack
==================
'''
return torch_channelwise_w8a8_scaled_mm
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = None,
input_scale: torch.Tensor | None = None,
input_scale_ub: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
weight_per_channel: bool = True,
activation_per_token: bool = True,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
'''
=============================
Modify by vllm_mlu
=============================
@brief: add mlu_fp8_supported
'''
self.mlu_fp8_supported = False
if weight_per_channel and activation_per_token:
self.mlu_fp8_supported = True
'''
==================
End of MLU Hijack
==================
'''
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
if out_dtype is None:
out_dtype = input.dtype
if self.mlu_fp8_supported:
'''
=============================
Modify by vllm_mlu
=============================
@brief: Add support for activation-per-token weight-per-channel quantization.
'''
qinput, x_scale = mlu_ops.scaled_quantize(
input_2d,# x
None, # scale
None, # zero
None, # scale_ub
quant_type=torch.float8_e4m3fn,
quant_mode='dynamic_per_token'
)
output_shape = [*input.shape[:-1], weight.shape[0]]
'''
==================
End of MLU Hijack
==================
'''
else:
# If input not quantized
# TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8(
input_2d,
input_scale,
input_scale_ub,
)
else:
qinput, x_scale = input_2d, input_scale
    # Must also check dim():
    # in the per-token quant scenario, when there is only one token the
    # scale has a single element, so without checking dim() we cannot
    # distinguish between per-tensor and per-token quant.
    # Example:
    #   with one token, the per-token scale is [[1]],
    #   while a per-tensor scale is [1] or ().
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.preferred_backend, per_tensor_weights, per_tensor_activations,
weight_per_channel, activation_per_token)
return w8a8_scaled_mm_func(
qinput=qinput,
weight=weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
output_shape=output_shape,
)
MluHijackObject.apply_hijack(
Fp8LinearOp,
Fp8LinearOp.apply,
vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
)
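
# Illustrative stand-in for what apply_hijack does above (the real
# MluHijackObject implementation may differ): it swaps the upstream bound
# method for the MLU-aware override, so every Fp8LinearOp.apply(...) call
# inside vLLM subsequently runs the function defined in this file.
def _apply_hijack_sketch(cls, old_fn, new_fn):
    assert getattr(cls, old_fn.__name__) is old_fn, "unexpected upstream method"
    setattr(cls, old_fn.__name__, new_fn)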

View File

@@ -0,0 +1,150 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.logger import init_logger
logger = init_logger(__name__)
# @register_quantization_config("weightonly")
class WeightOnlyConfig(QuantizationConfig):
"""Config class for WeightOnly.
"""
def __init__(
self,
weight_bits: int,
quant_mode: str, # weight_only
) -> None:
super().__init__()
self.weight_bits = weight_bits
self.quant_mode = quant_mode
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
raise ValueError(
"Currently, only 8/4-bit weight quantization is supported for "
f"weight_only, but got {self.weight_bits} bits.")
self.pack_factor = 8 // self.weight_bits
def __repr__(self) -> str:
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
f"quant_mode={self.quant_mode})")
def get_name(self) -> str:
return "WeightOnly"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
try:
quant_mode = cls.get_from_keys(config, ["quant_mode"])
except Exception:
quant_mode = "WeightOnly"
return cls(weight_bits, quant_mode)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
if isinstance(layer, LinearBase):
return WeightOnlyLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class WeightOnlyLinearMethod(LinearMethodBase):
"""Linear method for WeightOnly.
Args:
quant_config: The WeightOnly quantization config.
"""
def __init__(self, quant_config: WeightOnlyConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
    ):
output_size_per_partition = sum(output_partition_sizes)
if self.quant_config.quant_mode == "WeightOnly":
scale_and_zero_input_dim = None
if output_size != output_size_per_partition:
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
device="mlu",
dtype=torch.int8,
),
requires_grad=False,
)
set_weight_attrs(qweight, {
"input_dim": 1,
"output_dim": 0,
})
scales = Parameter(
torch.empty(
output_size_per_partition,
device="mlu",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if layer.scales.dtype != torch.float:
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
x_shape = x.shape
if len(x_shape) > 2:
x = x.view(-1, x_shape[-1])
out = mlu_ops.weight_only_quant_matmul(x,
layer.qweight,
layer.scales,
None,
bias,
residual,
"none",
self.quant_config.weight_bits)
if len(x_shape) > 2:
out = out.view(*x_shape[:-1], out.shape[-1])
return out
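
# Illustrative sketch of the 4-bit packing this layout implies: with
# pack_factor = 8 // 4 = 2, qweight stores two int4 values per int8 byte,
# giving the [out, in // pack_factor] shape created above. The nibble order
# here is an assumption; the authoritative layout is whatever
# mlu_ops.weight_only_quant_matmul expects.
def _pack_int4_to_int8(w_int4: torch.Tensor) -> torch.Tensor:
    assert w_int4.shape[-1] % 2 == 0
    w = w_int4.to(torch.int16) & 0xF           # two's-complement nibbles
    packed = (w[..., 1::2] << 4) | w[..., 0::2]
    return packed.to(torch.uint8).view(torch.int8)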