[Model] Support DeepSeek-V4
37  vllm_mlu/model_executor/layers/quantization/__init__.py  Normal file
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from vllm.model_executor.layers.quantization import (
    QUANTIZATION_METHODS, register_quantization_config
)

MLU_QUANTIZATION_METHODS = [
    "smoothquant",
    "weightonly",
    "awq_mlu",
    "gptq_mlu",
]


def register_fake_mlu_quantization_methods():
    for quant_method in MLU_QUANTIZATION_METHODS:
        if quant_method not in QUANTIZATION_METHODS:
            QUANTIZATION_METHODS.append(quant_method)


def remove_fake_mlu_quantization_methods():
    for quant_method in MLU_QUANTIZATION_METHODS:
        if quant_method in QUANTIZATION_METHODS:
            QUANTIZATION_METHODS.remove(quant_method)


def register_real_mlu_quantization_methods():
    remove_fake_mlu_quantization_methods()
    from vllm_mlu.model_executor.layers.quantization.weightonly import WeightOnlyConfig
    from vllm_mlu.model_executor.layers.quantization.smoothquant import SmoothQuantConfig
    from vllm_mlu.model_executor.layers.quantization.awq_mlu import AWQMluConfig
    from vllm_mlu.model_executor.layers.quantization.gptq_mlu import GPTQMluConfig
    register_quantization_config("weightonly")(WeightOnlyConfig)
    register_quantization_config("smoothquant")(SmoothQuantConfig)
    register_quantization_config("awq_mlu")(AWQMluConfig)
    register_quantization_config("gptq_mlu")(GPTQMluConfig)
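A minimal usage sketch (not part of this commit; the import path is taken from the file above, the call sites are assumptions): the "fake" registrations make the MLU method names pass vLLM's argument validation before the MLU kernels are importable, and the "real" registrations swap in the concrete config classes afterwards.

    from vllm_mlu.model_executor.layers.quantization import (
        register_fake_mlu_quantization_methods,
        register_real_mlu_quantization_methods,
    )

    register_fake_mlu_quantization_methods()   # names only, so --quantization=awq_mlu is accepted
    # ... later, once vllm_mlu and its _mlu_ops extension are loaded:
    register_real_mlu_quantization_methods()   # registers WeightOnlyConfig, SmoothQuantConfig, etc.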
412  vllm_mlu/model_executor/layers/quantization/awq_mlu.py  Normal file
@@ -0,0 +1,412 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
|
||||
PackedvLLMParameter)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# Only GPTQ and AWQ on MLU 300-series and later are supported, and only int4/int8 precision.
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
    if quant_type not in supported_types:
        return (False, f"MLU does not support quant_type = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"device_capability = {device_capability}, zp = {has_zp}).")
    if group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES:
        return (False, f"MLU does not support group_size = {group_size}. "
                f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    return (True, None)
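A quick illustration of the convention encoded by these scalar types (sketch only, not part of the commit): uint4b8 and uint8b128 store GPTQ-style symmetric weights as unsigned values with a bias of 8 and 128 respectively, while uint4/uint8 are the AWQ-style unsigned types that carry an explicit zero-point. Callers are expected to unpack the (ok, reason) pair returned above:

    ok, reason = check_mlu_supported(quant_type=scalar_types.uint4b8,
                                     group_size=128, has_zp=False)
    if not ok:
        raise ValueError(reason)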
|
||||
|
||||
|
||||
# @register_quantization_config("awq_mlu")
|
||||
class AWQMluConfig(QuantizationConfig):
|
||||
"""Config class for AWQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2306.00978
|
||||
"""
|
||||
|
||||
# num_bits -> type
|
||||
TYPE_MAP = {
|
||||
4: {
|
||||
False: scalar_types.uint4b8,
|
||||
True: scalar_types.uint4,
|
||||
},
|
||||
8: {
|
||||
False: scalar_types.uint8b128,
|
||||
True: scalar_types.uint8,
|
||||
}
|
||||
}
|
||||
|
||||
VERSION = ["gemm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
zero_point: bool,
|
||||
lm_head_quantized: bool,
|
||||
version: str = "gemm",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.zero_point = zero_point
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = 32 // self.weight_bits
|
||||
self.version = version
|
||||
self.support_scale_zeros = False
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is supported for "
|
||||
f"AWQMlu, but got {self.weight_bits} bits.")
|
||||
if self.version not in self.VERSION:
|
||||
raise ValueError(
|
||||
"Currently, only gemm, gemv version is supported for "
|
||||
f"AWQMlu, but got verion:{self.version}.")
|
||||
|
||||
if self.version in ["gemm"]:
|
||||
self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
|
||||
self.reverse_order_map = {4 : [0, 4, 1, 5, 2, 6, 3, 7], 8: [0, 2, 1, 3]}
|
||||
else:
|
||||
self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"zero_point={self.zero_point}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "awq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
|
||||
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
|
||||
zero_point = cls.get_from_keys(config, ["zero_point"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
version = cls.get_from_keys_or(config, ["version"],
|
||||
default="gemm")
|
||||
return cls(weight_bits, group_size, zero_point, lm_head_quantized, version)
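For reference, a config dict that from_config above would accept (field values are typical AutoAWQ output, shown here only as an illustration):

    awq_cfg = {
        "quant_method": "awq",
        "bits": 4,            # matched via ["w_bit", "bits"]
        "group_size": 128,    # matched via ["q_group_size", "group_size"]
        "zero_point": True,
        "version": "gemm",
    }
    config = AWQMluConfig.from_config(awq_cfg)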
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["AWQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return AWQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "awq"
|
||||
or user_quant == "awq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
if can_convert and user_quant == "awq":
|
||||
logger.info("Detected that the model can run with awq_mlu"
|
||||
", however you specified quantization=awq explicitly,"
|
||||
" so forcing awq. Use quantization=awq_mlu for"
|
||||
" faster inference")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
has_zp = quant_config.get("zero_point", None)
|
||||
version = quant_config.get("version", "gemm")
|
||||
|
||||
if quant_method != "awq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or has_zp is None):
|
||||
return False
|
||||
|
||||
if num_bits not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
if version not in cls.VERSION:
|
||||
return False
|
||||
|
||||
        supported, _ = check_mlu_supported(
            quant_type=cls.TYPE_MAP[num_bits][has_zp],
            group_size=group_size,
            has_zp=has_zp)
        return supported
|
||||
|
||||
class AWQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: AWQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
scales = GroupQuantScaleParameter(data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
packed_qweight, scale_zeros = self.extract_autoawq(layer)
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.zero_point and not self.quant_config.support_scale_zeros:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
def extract_autoawq(self, layer: torch.nn.Module):
|
||||
qweight = layer.qweight.data
|
||||
qzeros = layer.qzeros.data
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
|
||||
# Reverse the order of the iweight and izeros tensors
|
||||
iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)
|
||||
|
||||
# overflow checks
|
||||
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
|
||||
if izeros is not None:
|
||||
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
|
||||
|
||||
if self.quant_config.zero_point and (not self.quant_config.support_scale_zeros):
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if self.quant_config.zero_point and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
def unpack_awq_int32_into_int8(self, qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
# unpacking columnwise
|
||||
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
iweights = iweights.view(iweights.shape[0], -1)
|
||||
if not self.quant_config.zero_point or self.quant_config.support_scale_zeros:
|
||||
iweights = torch.bitwise_and(iweights - 2**(bits - 1), (2 ** bits) - 1)
|
||||
|
||||
# unpacking columnwise
|
||||
if qzeros is not None:
|
||||
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(dtype)
|
||||
izeros = izeros.view(izeros.shape[0], -1)
|
||||
if not self.quant_config.zero_point:
|
||||
izeros = torch.bitwise_and(izeros - 2**(bits - 1), (2 ** bits) - 1)
|
||||
else:
|
||||
izeros = None
|
||||
|
||||
return iweights, izeros
|
||||
|
||||
def reverse_awq_order(self, iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
|
||||
reverse_order_tensor = torch.arange(iweights.shape[-1], dtype=torch.int32, device=iweights.device)
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
|
||||
reverse_order_tensor = reverse_order_tensor[:, self.quant_config.reverse_order_map[bits]]
|
||||
reverse_order_tensor = reverse_order_tensor.view(-1)
|
||||
|
||||
        rweights = iweights[:, reverse_order_tensor]
        # Avoid an unbound name when there are no zero-points to reorder.
        rzeros = izeros[:, reverse_order_tensor] if izeros is not None else None

        return rweights, rzeros
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
        # Ensure both inputs are int8.
        if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
            raise ValueError("Both tensors must be of int8 type.")

        # Extract the low 4 bits of each tensor.
        low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the low 4 bits of tensor_a
        low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the low 4 bits of tensor_b

        # Shift the low 4 bits of tensor_a into the high nibble.
        shifted_low_bits_a = low_bits_a << 4

        # Combine the two nibbles into one byte.
        combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)

        return combined
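A worked example of the nibble packing above (illustrative only): the low 4 bits of tensor_a land in the high nibble and those of tensor_b in the low nibble, so packing 0x1 and 0x2 yields 0x12.

    import torch

    a = torch.tensor([1], dtype=torch.int8)   # goes to the high nibble
    b = torch.tensor([2], dtype=torch.int8)   # goes to the low nibble
    packed = (torch.bitwise_and(a, 0x0F) << 4) | torch.bitwise_and(b, 0x0F)
    assert packed.item() == 0x12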
|
||||
753  vllm_mlu/model_executor/layers/quantization/fp8.py  Normal file
@@ -0,0 +1,753 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
    ACTIVATION_SCHEMES,
    Fp8Config,
    Fp8LinearMethod,
    Fp8MoeBackend,
    Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
validate_fp8_block_shape
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
|
||||
maybe_create_device_identity, Fp8LinearOp)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter, ChannelQuantScaleParameter,
|
||||
ModelWeightParameter, PerTensorScaleParameter)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.deep_gemm import (
|
||||
is_deep_gemm_e8m0_used,
|
||||
is_deep_gemm_supported,
|
||||
)
|
||||
from vllm.utils.flashinfer import has_flashinfer_moe
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
|
||||
import vllm_mlu._mlu_ops as mlu_ops
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
|
||||
"""
|
||||
Select the primary FP8 MoE backend
|
||||
Note: Shape-specific fallbacks may still occur at runtime.
|
||||
"""
|
||||
# Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and (
|
||||
current_platform.is_device_capability(100)
|
||||
or current_platform.is_device_capability(90)
|
||||
)
|
||||
and envs.VLLM_USE_FLASHINFER_MOE_FP8
|
||||
and has_flashinfer_moe()
|
||||
):
|
||||
backend = get_flashinfer_moe_backend()
|
||||
if backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_TRTLLM
|
||||
else:
|
||||
if block_quant and current_platform.is_device_capability(100):
|
||||
raise ValueError(
|
||||
"FlashInfer FP8 MoE throughput backend does not "
|
||||
"support block quantization. Please use "
|
||||
"VLLM_FLASHINFER_MOE_BACKEND=latency "
|
||||
"instead."
|
||||
)
|
||||
logger.info_once("Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100")
|
||||
return Fp8MoeBackend.FLASHINFER_CUTLASS
|
||||
|
||||
# weight-only path for older GPUs without native FP8
|
||||
use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: disable marlin for MLU backend.
|
||||
'''
|
||||
if current_platform.is_rocm() or current_platform.is_out_of_tree():
|
||||
use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if use_marlin:
|
||||
logger.info_once("Using Marlin backend for FP8 MoE")
|
||||
return Fp8MoeBackend.MARLIN
|
||||
|
||||
# deepGEMM on supported platforms with block-quantized weights
|
||||
if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
|
||||
if not has_deep_gemm():
|
||||
logger.warning_once("DeepGEMM backend requested but not available.")
|
||||
elif is_deep_gemm_supported():
|
||||
logger.info_once("Using DeepGEMM backend for FP8 MoE")
|
||||
return Fp8MoeBackend.DEEPGEMM
|
||||
|
||||
# CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights
|
||||
if (
|
||||
current_platform.is_cuda()
|
||||
and current_platform.is_device_capability(100)
|
||||
and block_quant
|
||||
):
|
||||
logger.info_once("Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE")
|
||||
return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
|
||||
|
||||
# default to Triton
|
||||
logger.info_once("Using Triton backend for FP8 MoE")
|
||||
return Fp8MoeBackend.TRITON
|
||||
|
||||
|
||||
Fp8Config____init____org = Fp8Config.__init__
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized: bool = False,
|
||||
activation_scheme: str = "dynamic",
|
||||
ignored_layers: list[str] | None = None,
|
||||
weight_block_size: list[int] | None = None,
|
||||
activation_quant_method: Optional[str] = None,
|
||||
weight_quant_method: Optional[str] = None,
|
||||
) -> None:
|
||||
super(Fp8Config, self).__init__()
|
||||
|
||||
Fp8Config____init____org(
|
||||
self,
|
||||
is_checkpoint_fp8_serialized,
|
||||
activation_scheme,
|
||||
ignored_layers,
|
||||
weight_block_size
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add class members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
self.activation_quant_method = activation_quant_method
|
||||
self.weight_quant_method = weight_quant_method
|
||||
|
||||
    assert (self.weight_block_size
            or (self.activation_quant_method == "per_token"
                and self.weight_quant_method == "per_channel"
                and self.activation_scheme == "dynamic")), (
        "Only block-wise quantization, or dynamic per-token activation with "
        "per-channel weight quantization, is supported for now.")
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
@classmethod
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
|
||||
cls, config: Dict[str, Any]
|
||||
) -> "Fp8Config":
|
||||
quant_method = cls.get_from_keys(config, ["quant_method"])
|
||||
is_checkpoint_fp8_serialized = "fp8" in quant_method
|
||||
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
|
||||
ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
|
||||
weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
|
||||
if not ignored_layers:
|
||||
ignored_layers = cls.get_from_keys_or(
|
||||
config, ["modules_to_not_convert"], None
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add config members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
activation_quant_method = cls.get_from_keys_or(config,
|
||||
["activation_quant_method"],
|
||||
'per_token')
|
||||
weight_quant_method = cls.get_from_keys_or(config,
|
||||
["weight_quant_method"],
|
||||
None)
|
||||
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
|
||||
activation_scheme=activation_scheme,
|
||||
ignored_layers=ignored_layers,
|
||||
weight_block_size=weight_block_size,
|
||||
activation_quant_method=activation_quant_method,
|
||||
weight_quant_method=weight_quant_method)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
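For illustration, a checkpoint quantization_config that the patched from_config accepts; the two MLU-specific keys are the ones added above, the remaining values are assumptions about a typical FP8 checkpoint:

    fp8_cfg = {
        "quant_method": "fp8",
        "activation_scheme": "dynamic",
        "activation_quant_method": "per_token",   # MLU-specific key
        "weight_quant_method": "per_channel",     # MLU-specific key
    }
    config = Fp8Config.from_config(fp8_cfg)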
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
maybe_create_device_identity()
|
||||
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.orig_dtype = params_dtype
|
||||
layer.weight_block_size = None
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group.
|
||||
'''
|
||||
tp_group = extra_weight_attrs.get("tp_group", None)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
layer.weight_block_size = self.weight_block_size
|
||||
validate_fp8_block_shape(
|
||||
layer,
|
||||
input_size,
|
||||
output_size,
|
||||
input_size_per_partition,
|
||||
output_partition_sizes,
|
||||
self.weight_block_size,
|
||||
)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add tp_group.
|
||||
'''
|
||||
# WEIGHT
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
weight = create_fp8_weight_parameter(
|
||||
output_size_per_partition, input_size_per_partition, weight_loader
|
||||
)
|
||||
else:
|
||||
# For non-serialized checkpoints, use original dtype
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# If checkpoint is serialized fp8, load them.
|
||||
# Otherwise, wait until process_weights_after_loading.
|
||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||
# WEIGHT SCALE
|
||||
if not self.block_quant:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Support weight per channel quantization.
|
||||
@brief: Add tp_group to enable custom split.
|
||||
'''
|
||||
if self.weight_per_channel:
|
||||
scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(sum(output_partition_sizes), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
tp_group=tp_group,
|
||||
)
|
||||
else:
|
||||
scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
scale[:] = torch.finfo(torch.float32).min
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
layer.register_parameter("weight_scale", scale)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
scale = create_fp8_scale_parameter(
|
||||
BlockQuantScaleParameter,
|
||||
output_partition_sizes,
|
||||
input_size_per_partition,
|
||||
self.weight_block_size,
|
||||
weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale, {"scale_type": "weight_scale"})
|
||||
# The weight_scale_inv name is intentional for deepseekv3
|
||||
layer.register_parameter("weight_scale_inv", scale)
|
||||
|
||||
# INPUT ACTIVATION SCALE
|
||||
if self.act_q_static:
|
||||
scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
|
||||
set_weight_attrs(scale, {"scale_type": "input_scale"})
|
||||
layer.register_parameter("input_scale", scale)
|
||||
else:
|
||||
layer.register_parameter("input_scale", None)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
|
||||
self,
|
||||
quant_config: Fp8Config
|
||||
):
|
||||
self.quant_config = quant_config
|
||||
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
|
||||
# kernel for fast weight-only FP8 quantization
|
||||
self.use_marlin = (
|
||||
not current_platform.has_device_capability(89)
|
||||
or envs.VLLM_TEST_FORCE_FP8_MARLIN
|
||||
)
|
||||
# Disable marlin for rocm
|
||||
if current_platform.is_rocm():
|
||||
self.use_marlin = False
|
||||
if vllm_is_batch_invariant():
|
||||
self.use_marlin = False
|
||||
|
||||
    # AITER is only supported on ROCm and only for FP8_FNUZ,
    # and at the moment only on the MI300 series.
|
||||
self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
|
||||
self.use_deep_gemm = is_deep_gemm_supported()
|
||||
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant = self.weight_block_size is not None
|
||||
if self.block_quant:
|
||||
# Marlin doesn't support block-wise fp8
|
||||
self.use_marlin = False
|
||||
|
||||
self.act_q_static = self.quant_config.activation_scheme == "static"
|
||||
if self.weight_block_size:
|
||||
self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
|
||||
else:
|
||||
# Use per-token quantization for better perf if dynamic and cutlass
|
||||
if not self.act_q_static and cutlass_fp8_supported():
|
||||
self.act_q_group_shape = GroupShape.PER_TOKEN
|
||||
else:
|
||||
self.act_q_group_shape = GroupShape.PER_TENSOR
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add config members activation_quant_method and weight_quant_method to
|
||||
indicate the granularity of quantization.
|
||||
'''
|
||||
self.weight_per_channel = (self.quant_config.weight_quant_method == 'per_channel')
|
||||
self.activation_per_token = (self.quant_config.activation_quant_method == 'per_token')
|
||||
if self.weight_per_channel and self.activation_per_token:
|
||||
self.use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if self.block_quant:
|
||||
assert not self.act_q_static
|
||||
assert self.weight_block_size is not None
|
||||
self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
|
||||
weight_group_shape=GroupShape(*self.weight_block_size),
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
|
||||
use_aiter_and_is_supported=self.use_aiter_and_is_supported,
|
||||
)
|
||||
else:
|
||||
self.fp8_linear = Fp8LinearOp(
|
||||
act_quant_static=self.act_q_static,
|
||||
act_quant_group_shape=self.act_q_group_shape,
|
||||
)
|
||||
|
||||
|
||||
Fp8LinearMethod__process_weights_after_loading__org = Fp8LinearMethod.process_weights_after_loading
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
|
||||
self,
|
||||
layer: Module,
|
||||
) -> None:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: For dynamic activation and channel-wise weight quantization,
|
||||
additional processing is not needed.
|
||||
'''
|
||||
if (self.quant_config.is_checkpoint_fp8_serialized
|
||||
and self.weight_per_channel
|
||||
and self.quant_config.activation_scheme == "dynamic"):
|
||||
return
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
assert residual is None, "Fp8Linear residual is not supported yet."
|
||||
|
||||
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
|
||||
# we will use BF16 dequant when DeepGEMM is not supported.
|
||||
if vllm_is_batch_invariant():
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
return self.w8a8_block_fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
else:
|
||||
# per-tensor/channel: dequant to BF16 and run GEMM
|
||||
weight_fp8 = layer.weight.to(torch.bfloat16)
|
||||
weight_scale = layer.weight_scale.to(torch.bfloat16)
|
||||
if weight_scale.numel() == 1:
|
||||
# Per-tensor: simple scalar multiplication
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
else:
|
||||
# Multiple scales (fused modules like QKV)
|
||||
# Try to infer correct broadcasting
|
||||
# weight is [K, N], scale could be [num_logical_weights]
|
||||
# Need to figure out how to broadcast - for now just try
|
||||
# direct multiplication
|
||||
if (
|
||||
weight_scale.dim() == 1
|
||||
and weight_scale.shape[0] == weight_fp8.shape[0]
|
||||
):
|
||||
# Per-row scaling
|
||||
weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
|
||||
else:
|
||||
# Fallback
|
||||
weight_bf16 = weight_fp8 * weight_scale
|
||||
return torch.nn.functional.linear(x, weight_bf16.t(), bias)
|
||||
|
||||
if self.use_marlin:
|
||||
return apply_fp8_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size is not None
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
apply_w8a8_block_fp8_linear)
|
||||
return apply_w8a8_block_fp8_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
block_size=self.quant_config.weight_block_size,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use activation per token quantization based on quantization config.
|
||||
'''
|
||||
return self.fp8_linear.apply(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias,
|
||||
weight_per_channel=self.weight_per_channel,
|
||||
activation_per_token=self.activation_per_token)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
|
||||
self,
|
||||
quant_config: Fp8Config,
|
||||
layer: torch.nn.Module
|
||||
):
|
||||
super(Fp8MoEMethod, self).__init__(layer.moe_config)
|
||||
self.layer = layer
|
||||
self.quant_config = quant_config
|
||||
self.weight_block_size = self.quant_config.weight_block_size
|
||||
self.block_quant: bool = self.weight_block_size is not None
|
||||
self.fp8_backend = get_fp8_moe_backend(self.block_quant)
|
||||
|
||||
self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
|
||||
self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
|
||||
if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
|
||||
elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
|
||||
self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
|
||||
if self.block_quant:
|
||||
assert self.weight_block_size == [128, 128], (
|
||||
f"Only support weight_block_size == [128, 128], "
|
||||
f"got {self.weight_block_size}"
|
||||
)
|
||||
self.flashinfer_moe_fn = partial(
|
||||
flashinfer_cutlass_moe_fp8,
|
||||
moe=self.moe,
|
||||
use_deepseek_fp8_block_scale=self.block_quant,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
|
||||
self.allow_cutlass_block_scaled_grouped_gemm = (
|
||||
self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
|
||||
)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
    @brief: On MLU, always set self.use_marlin to False.
|
||||
'''
|
||||
self.use_marlin = False
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if enable_eplb:
|
||||
assert expert_load_view is not None
|
||||
assert logical_to_physical_map is not None
|
||||
assert logical_replica_count is not None
|
||||
assert isinstance(layer, FusedMoE)
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement FusedMoE.select_experts
|
||||
'''
|
||||
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
if scoring_func == "softmax":
|
||||
topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
route_scale=routed_scaling_factor,
|
||||
)
|
||||
elif scoring_func == "sigmoid":
|
||||
topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
|
||||
router_logits,
|
||||
top_k,
|
||||
renormalize,
|
||||
num_expert_group,
|
||||
topk_group,
|
||||
routed_scaling_factor,
|
||||
e_score_correction_bias,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported scoring function: {scoring_func}")
|
||||
|
||||
# gen_idx
|
||||
ori_input_shape = x.shape
|
||||
x = x.reshape(-1, x.size(-1))
|
||||
router_logits = router_logits.reshape(-1, router_logits.size(-1))
|
||||
expert_num = router_logits.size(-1)
|
||||
tokens_num = x.size(0)
|
||||
expert_size = layer.w13_weight.size(0)
|
||||
expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
|
||||
topk_ids, expert_num
|
||||
)
|
||||
|
||||
expand_hidden_states = mlu_ops.moe_expand_input(
|
||||
x, expand_idx, cumsum_token_count, 0, expert_size
|
||||
)
|
||||
quant_input, input_scale = _fp8_quantize(
|
||||
expand_hidden_states, A_scale=None, block_shape=self.quant_config.weight_block_size
|
||||
)
|
||||
gemm1_out = mlu_ops.smooth_quant_group_gemm(
|
||||
quant_input,
|
||||
layer.w13_weight,
|
||||
token_count,
|
||||
expand_idx=None,
|
||||
c=None,
|
||||
alpha=None,
|
||||
beta=None,
|
||||
a_scale=input_scale.T.contiguous(),
|
||||
b_scale=layer.w13_weight_scale_inv,
|
||||
dtype=x.dtype,
|
||||
max_m=tokens_num,
|
||||
)
|
||||
|
||||
act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)
|
||||
act_out_quantize, act_out_scale = _fp8_quantize(
|
||||
act_out, A_scale=None, block_shape=self.quant_config.weight_block_size
|
||||
)
|
||||
|
||||
gemm2_out = mlu_ops.smooth_quant_group_gemm(
|
||||
act_out_quantize,
|
||||
layer.w2_weight,
|
||||
token_count,
|
||||
expand_idx=None,
|
||||
c=None,
|
||||
alpha=None,
|
||||
beta=None,
|
||||
a_scale=act_out_scale.T.contiguous(),
|
||||
b_scale=layer.w2_weight_scale_inv,
|
||||
dtype=x.dtype,
|
||||
max_m=tokens_num,
|
||||
)
|
||||
|
||||
output = mlu_ops.moe_combine_result(
|
||||
gemm2_out,
|
||||
topk_weights,
|
||||
combine_idx,
|
||||
residual=None,
|
||||
cusum_token_count=cumsum_token_count,
|
||||
start_expert_id=0,
|
||||
expert_size=expert_size,
|
||||
bias=None,
|
||||
)
|
||||
return output.view(ori_input_shape)
|
||||
|
||||
"""
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
"""
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.apply,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8Config,
|
||||
Fp8Config.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8Config____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8Config,
|
||||
Fp8Config.from_config,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.create_weights,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearMethod,
|
||||
Fp8LinearMethod.process_weights_after_loading,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8MoEMethod,
|
||||
Fp8MoEMethod.__init__,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__
|
||||
)
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8MoEMethod,
|
||||
Fp8MoEMethod.apply,
|
||||
vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply
|
||||
)
|
||||
440  vllm_mlu/model_executor/layers/quantization/gptq_mlu.py  Normal file
@@ -0,0 +1,440 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]
|
||||
|
||||
# Only GPTQ and AWQ on MLU 300-series and later are supported, and only int4/int8 precision.
|
||||
def query_mlu_supported_quant_types(has_zp: bool,
|
||||
device_capability: Optional[int] = None
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + zero-point
|
||||
return [scalar_types.uint4, scalar_types.uint8]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
return [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
|
||||
|
||||
def check_mlu_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = current_platform.get_device_capability()
|
||||
device_capability = major * 10 + minor
|
||||
|
||||
supported_types = query_mlu_supported_quant_types(
|
||||
has_zp, device_capability)
|
||||
|
||||
    if quant_type not in supported_types:
        return (False, f"MLU does not support quant_type = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"device_capability = {device_capability}, zp = {has_zp}).")
    if group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES:
        return (False, f"MLU does not support group_size = {group_size}. "
                f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    return (True, None)
|
||||
|
||||
|
||||
# @register_quantization_config("gptq_mlu")
|
||||
class GPTQMluConfig(QuantizationConfig):
|
||||
"""Config class for GPTQMlu.
|
||||
|
||||
Reference: https://arxiv.org/abs/2210.17323
|
||||
"""
|
||||
|
||||
# (num_bits, is_sym) -> quant_type
|
||||
TYPE_MAP = {
|
||||
(4, True): scalar_types.uint4b8,
|
||||
(8, True): scalar_types.uint8b128,
|
||||
(4, False): scalar_types.uint4b8,
|
||||
(8, False): scalar_types.uint8b128,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
is_sym: bool,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.group_size = group_size
|
||||
self.desc_act = desc_act
|
||||
self.is_sym = is_sym
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
self.pack_factor = Fraction(32, self.weight_bits)
|
||||
self.support_scale_zeros = False
|
||||
self.use_native = self.desc_act or (not self.is_sym and not self.support_scale_zeros)
|
||||
|
||||
if self.weight_bits not in [4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 4/8-bit weight quantization is "
|
||||
f"supported for GPTQMlu, but got {self.weight_bits} bits.")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQMluConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}),"
|
||||
f"lm_head_quantized={self.lm_head_quantized}")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "gptq_mlu"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16, torch.float32]
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quant_config.json", "quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "GPTQMluConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
is_sym = cls.get_from_keys(config, ["sym"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, is_sym, lm_head_quantized)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQMluLinearMethod"]:
|
||||
if (isinstance(layer, LinearBase) or
|
||||
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
|
||||
return GPTQMluLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
@classmethod
|
||||
def is_gptq_mlu_compatible(cls, quant_config: Dict[str, Any]):
|
||||
# Extract data from quant config.
|
||||
quant_method = quant_config.get("quant_method", "").lower()
|
||||
num_bits = quant_config.get("bits", None)
|
||||
group_size = quant_config.get("group_size", None)
|
||||
sym = quant_config.get("sym", None)
|
||||
desc_act = quant_config.get("desc_act", None)
|
||||
|
||||
if quant_method != "gptq":
|
||||
return False
|
||||
|
||||
# If we cannot find the info needed in the config, cannot convert.
|
||||
if (num_bits is None or group_size is None or sym is None
|
||||
or desc_act is None):
|
||||
return False
|
||||
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
        supported, _ = check_mlu_supported(
            quant_type=cls.TYPE_MAP[(num_bits, sym)],
            group_size=group_size, has_zp=False)
        return supported
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg,
|
||||
user_quant) -> Optional[str]:
|
||||
can_convert = cls.is_gptq_mlu_compatible(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
|
||||
or user_quant == "gptq_mlu")
|
||||
|
||||
if can_convert and is_valid_user_quant:
|
||||
msg = ("The model is convertible to {} during runtime."
|
||||
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
class GPTQMluLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQMlu.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQMlu quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMluConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition) and (self.quant_config.group_size !=
|
||||
-1) and (not self.quant_config.desc_act):
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.device = layer.qweight.data.device
|
||||
packed_qweight, scale_zeros = self.extract_autogptq(layer)
|
||||
if self.quant_config.use_native:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
layer.qzeros = None
|
||||
layer.scales = None
|
||||
else:
|
||||
layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(), requires_grad=False)
|
||||
if scale_zeros is not None:
|
||||
layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(), requires_grad=False)
|
||||
else:
|
||||
layer.qzeros = None
|
||||
layer.scales = torch.nn.Parameter(layer.scales.transpose(0, 1).contiguous(), requires_grad=False)
|
||||
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.quant_config.use_native:
|
||||
output = mlu_ops.matmul(x, layer.qweight, bias)
|
||||
if residual is not None:
|
||||
output = output + residual
|
||||
else:
|
||||
output = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
layer.qzeros,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def extract_autogptq(self, layer: torch.nn.Module):
|
||||
scales = layer.scales.data
|
||||
bits = self.quant_config.weight_bits
|
||||
group_size = self.quant_config.group_size
|
||||
# Unpack the qweight and qzeros tensors
|
||||
iweight = self.unpack_gptq_qweight_int32_into_int8(layer.qweight.data, bits)
|
||||
izeros = self.unpack_gptq_qzeros_int32_into_int8(layer.qzeros.data, bits)
|
||||
|
||||
if self.quant_config.use_native:
|
||||
if self.quant_config.desc_act:
|
||||
scales = torch.index_select(scales, 0, layer.g_idx)
|
||||
if izeros is not None:
|
||||
izeros = torch.index_select(izeros, 0, layer.g_idx)
|
||||
else:
|
||||
scales = scales.repeat_interleave(group_size, dim=0)
|
||||
if izeros is not None:
|
||||
izeros = izeros.repeat_interleave(group_size, dim=0)
|
||||
|
||||
if izeros is not None:
|
||||
fweight = (iweight - izeros) * scales
|
||||
else:
|
||||
fweight = iweight * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
fweight = fweight.transpose(0, 1)
|
||||
|
||||
return fweight, None
|
||||
|
||||
if not self.quant_config.is_sym and self.quant_config.support_scale_zeros and izeros is not None:
|
||||
scale_zeros = izeros.to(scales.dtype) * -1 * scales
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
scale_zeros = scale_zeros.transpose(0, 1)
|
||||
else:
|
||||
# is_sym is true here, so shift iweight to a signed representation and ignore qzeros
|
||||
iweight = torch.bitwise_and(iweight - 2**(bits - 1), (2 ** bits) - 1)
|
||||
scale_zeros = None
|
||||
|
||||
# transpose [ci, co] -> [co, ci]
|
||||
iweight = iweight.to(torch.int8).transpose(0, 1)
|
||||
|
||||
if bits == 4:
|
||||
higher_bit_tensor = iweight[:, 1::2]
|
||||
lower_bit_tensor = iweight[:, 0::2]
|
||||
packed_qweight = self.combine_low_bits(higher_bit_tensor, lower_bit_tensor)
|
||||
else:
|
||||
packed_qweight = iweight
|
||||
|
||||
return packed_qweight, scale_zeros
|
||||
|
||||
|
||||
def unpack_gptq_qweight_int32_into_int8(self, qweight: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qweight.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
weight = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qweight, 1).expand(-1, 32 // bits, -1),
|
||||
shifts.unsqueeze(-1),
|
||||
).to(dtype)
|
||||
weight = torch.bitwise_and(weight, (2**bits) - 1)
|
||||
weight = weight.reshape(-1, weight.shape[-1])
|
||||
|
||||
return weight
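# Illustrative sketch (not part of the original code): how a single packed
# int32 value unpacks with bits=4. The value below is made up for illustration.
#
#     packed = torch.tensor([[0x76543210]], dtype=torch.int32)      # shape [1, 1]
#     unpacked = self.unpack_gptq_qweight_int32_into_int8(packed, bits=4)
#     # unpacked has shape [8, 1] and holds [0, 1, 2, 3, 4, 5, 6, 7]:
#     # each 4-bit nibble of the int32 becomes one int8 row, lowest nibble first.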
|
||||
|
||||
|
||||
def unpack_gptq_qzeros_int32_into_int8(self, qzeros: torch.Tensor, bits: int):
|
||||
shifts = torch.arange(0, 32, bits, device=qzeros.device).unsqueeze(0)
|
||||
dtype = torch.int16 if bits == 8 else torch.int8
|
||||
zeros = torch.bitwise_right_shift(
|
||||
torch.unsqueeze(qzeros, 2).expand(-1, -1, 32 // bits),
|
||||
shifts.unsqueeze(0),
|
||||
).to(dtype)
|
||||
|
||||
zeros = zeros + 1
|
||||
|
||||
zeros = torch.bitwise_and(zeros, (2**bits) - 1)
|
||||
zeros = zeros.reshape(qzeros.shape[0], -1)
|
||||
|
||||
return zeros
|
||||
|
||||
|
||||
def combine_low_bits(self, tensor_a, tensor_b):
|
||||
"""
|
||||
Combine the lower 4 bits of two int8 tensors into a new int8 tensor.
|
||||
|
||||
Args:
|
||||
tensor_a (torch.Tensor): First tensor of type int8.
|
||||
tensor_b (torch.Tensor): Second tensor of type int8.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: New tensor of type int8, combining lower 4 bits of tensor_a and tensor_b.
|
||||
"""
|
||||
# Ensure both inputs are of int8 type
|
||||
if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
|
||||
raise ValueError("Both tensors must be of int8 type.")
|
||||
|
||||
# Extract the lower 4 bits of each tensor
|
||||
low_bits_a = torch.bitwise_and(tensor_a, 0x0F)  # keep the lower 4 bits of tensor_a
|
||||
low_bits_b = torch.bitwise_and(tensor_b, 0x0F)  # keep the lower 4 bits of tensor_b
|
||||
|
||||
# Shift the lower 4 bits of tensor_a left by 4 (into the high nibble)
|
||||
shifted_low_bits_a = low_bits_a << 4
|
||||
|
||||
# Combine the lower 4 bits of the two tensors into one byte
|
||||
combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)
|
||||
|
||||
return combined
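# Illustrative sketch (not part of the original code): packing two int4 values
# that already live in int8 storage. The inputs below are made up for illustration.
#
#     a = torch.tensor([0x05], dtype=torch.int8)   # goes into the high nibble
#     b = torch.tensor([0x03], dtype=torch.int8)   # goes into the low nibble
#     packed = self.combine_low_bits(a, b)
#     # packed == 0x53 (83): bits 7..4 come from tensor_a (the iweight[:, 1::2]
#     # elements above) and bits 3..0 from tensor_b (the iweight[:, 0::2] elements).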
|
||||
337
vllm_mlu/model_executor/layers/quantization/smoothquant.py
Executable file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
RowvLLMParameter)
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
|
||||
str_dtype_to_bits,
|
||||
is_fp8_str_dtype)
|
||||
|
||||
|
||||
# @register_quantization_config("smoothquant")
|
||||
class SmoothQuantConfig(QuantizationConfig):
|
||||
"""Config class for SmoothQuant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
quant_mode: str, # smoothquant
|
||||
input_quant_method: str, # per token/per tensor
|
||||
group_size: int,
|
||||
weight_precision: str,
|
||||
activation_precision: str,
|
||||
only_expert_per_group: bool,
|
||||
expert_weight_precision: str,
|
||||
expert_activation_precision: str,
|
||||
force_use_weightonly_except_expert: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_mode = quant_mode
|
||||
self.input_quant_method = input_quant_method
|
||||
self.group_size = group_size
|
||||
self.weight_precision = weight_precision
|
||||
self.activation_precision = activation_precision
|
||||
self.only_expert_per_group = only_expert_per_group
|
||||
self.expert_weight_precision = expert_weight_precision
|
||||
self.expert_activation_precision = expert_activation_precision
|
||||
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
|
||||
|
||||
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
|
||||
self.weight_bits = str_dtype_to_bits(self.weight_precision)
|
||||
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
|
||||
if self.weight_precision == 'int4':
|
||||
self.weight_dtype = torch.int8
|
||||
else:
|
||||
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
|
||||
if self.expert_weight_precision == 'int4':
|
||||
self.expert_weight_dtype = torch.int8
|
||||
else:
|
||||
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
|
||||
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
|
||||
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
|
||||
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
self.expert_pack_factor = 8 // self.expert_weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
|
||||
f"quant_mode={self.quant_mode}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"weight_precision={self.weight_precision}, "
|
||||
f"activation_precision={self.activation_precision}, "
|
||||
f"only_expert_per_group={self.only_expert_per_group}, "
|
||||
f"expert_weight_precision={self.expert_weight_precision}, "
|
||||
f"expert_activation_precision={self.expert_activation_precision}, "
|
||||
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "SmoothQuant"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
|
||||
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
|
||||
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
|
||||
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
|
||||
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
|
||||
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
|
||||
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
|
||||
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
|
||||
|
||||
if expert_weight_precision is None:
|
||||
expert_weight_precision = weight_precision
|
||||
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
|
||||
weight_precision = 'int8'
|
||||
|
||||
if expert_activation_precision is None:
|
||||
expert_activation_precision = activation_precision
|
||||
|
||||
return cls(quant_mode=quant_mode,
|
||||
input_quant_method=input_quant_method,
|
||||
group_size=group_size,
|
||||
weight_precision=weight_precision,
|
||||
activation_precision=activation_precision,
|
||||
only_expert_per_group=only_expert_per_group,
|
||||
expert_weight_precision=expert_weight_precision,
|
||||
expert_activation_precision=expert_activation_precision,
|
||||
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
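# Illustrative sketch (not part of the original code): a hypothetical
# quantize_config.json that from_config() above would accept. All values are
# made up for illustration; only "quant_mode" and "input_quant_method" are
# required, the remaining keys fall back to the defaults shown above.
#
#     {
#       "quant_mode": "SmoothQuant",
#       "input_quant_method": "per_token",
#       "group_size": 1,
#       "weight_precision": "int8",
#       "activation_precision": "int8",
#       "only_expert_per_group": false,
#       "expert_weight_precision": "int8",
#       "expert_activation_precision": "int8",
#       "force_use_weightonly_except_expert": false
#     }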
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return SmoothQuantLinearMethod(self, prefix)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class SmoothQuantLinearMethod(LinearMethodBase):
|
||||
"""Linear method for SmoothQuant.
|
||||
|
||||
Args:
|
||||
quant_config: The SmoothQuant quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
|
||||
self.quant_config = quant_config
|
||||
# For the per-tensor case, we can skip quantizing the input of the first attn/ffn linear
|
||||
# and fuse this step into layernorm for better performance
|
||||
self.skip_quant_input = False
|
||||
self.compute_dtype = torch.get_default_dtype()
|
||||
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
|
||||
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
|
||||
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
|
||||
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
|
||||
|
||||
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
|
||||
self.is_group_quant = True
|
||||
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
|
||||
self.is_group_quant = True
|
||||
else:
|
||||
self.is_group_quant = False
|
||||
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
|
||||
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor != 0):
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
group_num = 1
|
||||
if self.is_group_quant:
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
f"The input size {input_size_per_partition} is not aligned with the quantized "
|
||||
f"weight shape. This can be caused by too large "
|
||||
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
|
||||
|
||||
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
|
||||
if input_size_per_partition != input_size:
|
||||
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
|
||||
|
||||
qweight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.pack_factor,
|
||||
device="mlu",
|
||||
dtype=self.weight_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.is_group_quant:
|
||||
per_channel_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
group_num,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
per_channel_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("per_channel_scale", per_channel_scale)
|
||||
|
||||
if self.has_smooth:
|
||||
smooth = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
set_weight_attrs(smooth, {
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("smooth", smooth)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
scale_to_int = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=torch.float32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
set_weight_attrs(scale_to_int, {
|
||||
"ignore_warning": True,
|
||||
})
|
||||
layer.register_parameter("scale_to_int", scale_to_int)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if self.has_smooth and layer.smooth.dtype != torch.float:
|
||||
layer.smooth = layer.smooth.to(torch.float)
|
||||
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
|
||||
layer.scale_to_int = layer.scale_to_int.to(torch.float)
|
||||
if layer.per_channel_scale.dtype != torch.float:
|
||||
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
|
||||
|
||||
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
|
||||
if self.has_smooth:
|
||||
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
|
||||
if self.quant_config.input_quant_method == "per_tensor":
|
||||
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
use_tp_weight: bool = False,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
layer_smooth = layer.smooth if self.has_smooth else None
|
||||
layer_qweight = layer.qweight
|
||||
layer_per_channel_scale = layer.per_channel_scale
|
||||
if use_tp_weight:
|
||||
if hasattr(layer, 'tp_smooth'):
|
||||
layer_smooth = layer.tp_smooth
|
||||
if hasattr(layer, 'tp_qweight'):
|
||||
layer_qweight = layer.tp_qweight
|
||||
if hasattr(layer, 'tp_per_channel_scale'):
|
||||
layer_per_channel_scale = layer.tp_per_channel_scale
|
||||
|
||||
quant_input = None
|
||||
if self.skip_quant_input:
|
||||
quant_input = x
|
||||
elif self.quant_config.input_quant_method == "per_token":
|
||||
if self.is_fp8:
|
||||
quant_input, input_scale = mlu_ops.scaled_quantize(x,
|
||||
layer_smooth,
|
||||
quant_type=self.weight_dtype,
|
||||
quant_mode='dynamic_per_token')
|
||||
else:
|
||||
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
|
||||
elif self.quant_config.input_quant_method == "per_tensor":
|
||||
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Currently, only per_token or per_tensor input quantization is supported for "
|
||||
f"SmoothQuant, but got {self.input_quant_method}.")
|
||||
quant_input_shape = quant_input.shape
|
||||
if len(quant_input_shape) > 2:
|
||||
quant_input = quant_input.view(-1, quant_input_shape[-1])
|
||||
input_scale = input_scale.view(-1)
|
||||
if residual is not None and len(residual.shape) > 2:
|
||||
residual = residual.view(-1, residual.shape[-1])
|
||||
if self.is_fp8:
|
||||
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
|
||||
layer_per_channel_scale,
|
||||
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
|
||||
bias,
|
||||
c=residual, act_mode="none", quant_bit_size=8,
|
||||
alpha=1.0, beta=1.0, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None)
|
||||
if output is not None:
|
||||
out = out.view(output.shape)
|
||||
output.copy_(out)
|
||||
out = output
|
||||
else:
|
||||
if output is not None:
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
|
||||
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
|
||||
else:
|
||||
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
|
||||
layer_per_channel_scale, self.compute_dtype, bias, residual)
|
||||
if len(quant_input_shape) > 2:
|
||||
out = out.view(*quant_input_shape[:-1], out.shape[-1])
|
||||
return out
|
||||
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
import torch
|
||||
|
||||
QUANTIZATION_CHOICES = ['int8', 'int4', 'e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
|
||||
INTERGER_DTYPES = [torch.uint8, torch.uint16, torch.uint32, torch.uint64, torch.int8, torch.int16, torch.short,
|
||||
torch.int32, torch.int, torch.int64, torch.long]
|
||||
FLOAT_DTYPES = [torch.float32, torch.float, torch.float64, torch.double, torch.float16, torch.bfloat16,
|
||||
torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.half]
|
||||
FP8_DTYPE = [torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz]
|
||||
FP8_STR_DTYPE = ['e4m3fn', 'e4m3fnuz', 'e5m2', 'e5m2fnuz']
|
||||
GEMM_GROUP_SIZE = [64, 128, 256, 512]
|
||||
|
||||
_STR_TO_TORCH_DTYPE_DICT = dict(
|
||||
bfloat16=torch.bfloat16,
|
||||
float16=torch.float16,
|
||||
float32=torch.float32,
|
||||
int64=torch.int64,
|
||||
int32=torch.int32,
|
||||
int8=torch.int8,
|
||||
bool=torch.bool,
|
||||
e4m3fn=torch.float8_e4m3fn,
|
||||
e4m3fnuz=torch.float8_e4m3fnuz,
|
||||
e5m2=torch.float8_e5m2,
|
||||
e5m2fnuz=torch.float8_e5m2fnuz,
|
||||
)
|
||||
|
||||
TORCH_DTYPE_TO_STR_DICT = {
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.float16: "float16",
|
||||
torch.float32: "float32",
|
||||
torch.int64: "int64",
|
||||
torch.int32: "int32",
|
||||
torch.int8: "int8",
|
||||
torch.bool: "bool",
|
||||
torch.float8_e4m3fn: "e4m3fn",
|
||||
torch.float8_e4m3fnuz: "e4m3fnuz",
|
||||
torch.float8_e5m2: "e5m2",
|
||||
torch.float8_e5m2fnuz: "e5m2fnuz",
|
||||
}
|
||||
|
||||
STR_DTYPE_TO_BITS_DICT = {
|
||||
"bfloat16": 16,
|
||||
"float16": 16,
|
||||
"float32": 32,
|
||||
"int64": 64,
|
||||
"int32": 32,
|
||||
"int8": 8,
|
||||
'int4': 4,
|
||||
"bool": 1,
|
||||
"e4m3fn": 8,
|
||||
"e4m3fnuz": 8,
|
||||
"e5m2": 8,
|
||||
"e5m2fnuz": 8,
|
||||
}
|
||||
|
||||
|
||||
def str_dtype_to_torch(str_dtype: str):
|
||||
'''
|
||||
Convert a str dtype to the corresponding torch dtype (falls back to float16).
|
||||
'''
|
||||
ret = _STR_TO_TORCH_DTYPE_DICT.get(str_dtype)
|
||||
dtype = ret if ret is not None else torch.float16
|
||||
return dtype
|
||||
|
||||
|
||||
def torch_dtype_to_str(dtype: torch.dtype):
|
||||
'''
|
||||
Convert a torch dtype to the corresponding str dtype (falls back to "float16").
|
||||
'''
|
||||
ret = TORCH_DTYPE_TO_STR_DICT.get(dtype)
|
||||
str_dtype = ret if ret is not None else "float16"
|
||||
return str_dtype
|
||||
|
||||
|
||||
def str_dtype_to_bits(str_dtype):
|
||||
'''
|
||||
Convert a str dtype to its bit width (falls back to 8).
|
||||
'''
|
||||
ret = STR_DTYPE_TO_BITS_DICT.get(str_dtype)
|
||||
bits = ret if ret is not None else 8
|
||||
return bits
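# Expected behaviour of the helpers above (unknown inputs fall back to
# float16 / "float16" / 8 respectively):
#
#     str_dtype_to_torch("e4m3fn")     # -> torch.float8_e4m3fn
#     torch_dtype_to_str(torch.int8)   # -> "int8"
#     str_dtype_to_bits("int4")        # -> 4
#     str_dtype_to_bits("unknown")     # -> 8 (fallback)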
|
||||
|
||||
|
||||
def is_integer_dtype(dtype: torch.dtype):
|
||||
'''
|
||||
Check whether the given torch dtype is an integer dtype.
|
||||
'''
|
||||
return dtype in INTERGER_DTYPES
|
||||
|
||||
|
||||
def is_float_dtype(dtype: torch.dtype):
|
||||
'''
|
||||
Check whether the given torch dtype is a floating-point dtype.
|
||||
'''
|
||||
return dtype in FLOAT_DTYPES
|
||||
|
||||
|
||||
def is_fp8_dtype(dtype: torch.dtype):
|
||||
'''
|
||||
Check whether the given torch dtype is an fp8 dtype.
|
||||
'''
|
||||
return dtype in FP8_DTYPE
|
||||
|
||||
|
||||
def is_fp8_str_dtype(str_dtype: str):
|
||||
'''
|
||||
Check whether the given str dtype is an fp8 dtype.
|
||||
'''
|
||||
return str_dtype in FP8_STR_DTYPE
|
||||
424
vllm_mlu/model_executor/layers/quantization/utils/fp8_utils.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from https://github.com/sgl-project/sglang/pull/2575
|
||||
import functools
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
CUTLASS_BLOCK_FP8_SUPPORTED)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
_per_token_group_quant_fp8_colmajor)
|
||||
from vllm.platforms import current_platform
from vllm import _custom_ops as ops
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: get the total core count used to split the triton kernel
|
||||
'''
|
||||
|
||||
import triton.backends.mlu.driver as driver
|
||||
|
||||
_devprob = driver.BangUtils().get_device_properties(torch.mlu.current_device())
|
||||
TOTAL_CLUSTER_NUM = _devprob.get("cluster_num")
|
||||
TOTAL_CORE_NUM = TOTAL_CLUSTER_NUM * _devprob.get("core_num_per_cluster")
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
|
||||
def apply_w8a8_block_fp8_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
block_size: List[int],
|
||||
weight_scale: torch.Tensor,
|
||||
input_scale: Optional[torch.Tensor] = None,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED,
|
||||
use_aiter_and_is_supported: bool = False,
|
||||
) -> torch.Tensor:
|
||||
assert input_scale is None
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
|
||||
shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
|
||||
and weight.shape[1] % 128 == 0)
|
||||
if current_platform.is_rocm():
|
||||
# TODO this is never used, as cutlass_block_fp8_supported is False
|
||||
scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) +
|
||||
input_2d.shape[:-1])[::-1]
|
||||
scale_b_shape = (weight_scale.view(-1, 1)
|
||||
if weight_scale.dim() <= 1 else weight_scale.T).shape
|
||||
ar, ac = scale_a_shape
|
||||
br, bc = scale_b_shape
|
||||
if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0])
|
||||
or br not in (1, weight.shape[0])):
|
||||
shape_supported_by_cutlass = False
|
||||
if cutlass_block_fp8_supported and shape_supported_by_cutlass:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=True)
|
||||
output = ops.cutlass_scaled_mm(q_input,
|
||||
weight.T,
|
||||
out_dtype=input.dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale.T)
|
||||
else:
|
||||
q_input, x_scale = per_token_group_quant_fp8(input_2d,
|
||||
block_size[1],
|
||||
column_major_scales=False)
|
||||
output = w8a8_block_fp8_matmul(q_input,
|
||||
weight,
|
||||
x_scale,
|
||||
weight_scale,
|
||||
block_size,
|
||||
output_dtype=input.dtype)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(dtype=input.dtype).view(*output_shape)
|
||||
|
||||
def per_token_group_quant_fp8(
|
||||
x: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
column_major_scales: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Function to perform per-token-group quantization on an input tensor `x`.
|
||||
It converts the tensor values into signed float8 values and returns the
|
||||
quantized tensor along with the scaling factor used for quantization.
|
||||
Args:
|
||||
x: The input tensor with ndim >= 2.
|
||||
group_size: The group size used for quantization.
|
||||
eps: The minimum to avoid dividing zero.
|
||||
dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn`
|
||||
is supported for now.
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
|
||||
scaling factor for quantization.
|
||||
"""
|
||||
dtype = current_platform.fp8_dtype() if dtype is None else dtype
|
||||
assert (x.shape[-1] % group_size == 0), (
|
||||
f"the last dimension of `x` {x.shape[-1]} must be divisible "
|
||||
f"by `group_size` {group_size}")
|
||||
assert x.stride(-1) == 1, "`x` groups must be contiguous"
|
||||
|
||||
finfo = torch.finfo(dtype)
|
||||
fp8_min = finfo.min
|
||||
fp8_max = finfo.max
|
||||
|
||||
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
|
||||
M = x.numel() // group_size
|
||||
N = group_size
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: split to limit the memory usage (65536)
|
||||
'''
|
||||
group_per_block = 1
|
||||
while M >= 65536:
|
||||
group_per_block *= 2
|
||||
M = x.numel() // (group_size * group_per_block)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
if column_major_scales:
|
||||
shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
|
||||
x_s = torch.empty(shape, device=x.device,
|
||||
dtype=torch.float32).permute(-1, -2)
|
||||
else:
|
||||
shape = x.shape[:-1] + (x.shape[-1] // group_size, )
|
||||
x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
|
||||
|
||||
BLOCK = triton.next_power_of_2(N)
|
||||
# heuristics for number of warps
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: set num_warps to 1 for triton-mlu
|
||||
'''
|
||||
num_warps = 1
|
||||
num_stages = 1
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
if column_major_scales:
|
||||
_per_token_group_quant_fp8_colmajor[(M, )](
|
||||
x,
|
||||
x_q,
|
||||
x_s,
|
||||
group_size,
|
||||
x.shape[1],
|
||||
x.stride(0),
|
||||
x_s.stride(1),
|
||||
eps,
|
||||
fp8_min=fp8_min,
|
||||
fp8_max=fp8_max,
|
||||
BLOCK=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
else:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use the 'scaled_quantize' kernel from the 'tmo' library instead of the
|
||||
'_per_token_group_quant_fp8' triton kernel
|
||||
'''
|
||||
# Check if x is contiguous, if not, create a new tensor for contiguous x
|
||||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
x_origin_shape = x.shape
|
||||
x = x.reshape(*x.shape[:-1], -1, group_size)
|
||||
x_q, x_s = mlu_ops.scaled_quantize(x,
|
||||
None,
|
||||
quant_type=dtype,
|
||||
quant_mode='dynamic_per_token')
|
||||
x_q = x_q.reshape(x_origin_shape)
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return x_q, x_s
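# Illustrative sketch (not part of the original code): shapes produced for a
# hypothetical activation tensor; the device and fp8 dtype are assumptions.
#
#     x = torch.randn(4, 256, dtype=torch.bfloat16, device="mlu")
#     x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
#     # x_q: [4, 256] in the platform fp8 dtype (float8_e4m3fn assumed here)
#     # x_s: one float32 scale per 128-element group per row, i.e. [4, 2] here
#     # (the exact layout on the MLU path depends on mlu_ops.scaled_quantize);
#     # with column_major_scales=True the scale buffer is allocated as [2, 4]
#     # and returned as a [4, 2] view with column-major strides.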
|
||||
|
||||
@triton.jit
|
||||
def _w8a8_block_fp8_matmul(
|
||||
# Pointers to inputs and output
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
# Shape for matmul
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
# Block size for block-wise quantization
|
||||
group_n,
|
||||
group_k,
|
||||
# Stride for inputs and output
|
||||
stride_am,
|
||||
stride_ak,
|
||||
stride_bk,
|
||||
stride_bn,
|
||||
stride_cm,
|
||||
stride_cn,
|
||||
stride_As_m,
|
||||
stride_As_k,
|
||||
stride_Bs_k,
|
||||
stride_Bs_n,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
GROUP_SIZE_M: tl.constexpr,
|
||||
):
|
||||
"""Triton-accelerated function used to perform linear operations (dot
|
||||
product) on input tensors `A` and `B` with block-wise quantization, and
|
||||
store the result in output tensor `C`.
|
||||
"""
|
||||
|
||||
pid = tl.program_id(axis=0)
|
||||
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
|
||||
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: split to limit the memory usage (65536)
|
||||
'''
|
||||
num_block_size_all = num_pid_m * num_pid_n
|
||||
num_block_size_per = num_block_size_all // tl.num_programs(axis=0)
|
||||
num_block_size_rem = num_block_size_all % tl.num_programs(axis=0)
|
||||
|
||||
core_deal_num_block_size = num_block_size_per + (pid < num_block_size_rem)
|
||||
core_deal_num_block_start = num_block_size_per * pid + min(num_block_size_rem, pid)
|
||||
|
||||
for pid_i in range(0, core_deal_num_block_size):
|
||||
pid_in_core_deal_block = core_deal_num_block_start + pid_i
|
||||
|
||||
num_pid_in_group = GROUP_SIZE_M * num_pid_n
|
||||
group_id = pid_in_core_deal_block // num_pid_in_group
|
||||
first_pid_m = group_id * GROUP_SIZE_M
|
||||
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
|
||||
pid_m = first_pid_m + (pid_in_core_deal_block % group_size_m)
|
||||
pid_n = (pid_in_core_deal_block % num_pid_in_group) // group_size_m
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
|
||||
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
|
||||
|
||||
As_ptrs = As + offs_am * stride_As_m
|
||||
offs_bsn = offs_bn // group_n
|
||||
Bs_ptrs = Bs + offs_bsn * stride_Bs_n
|
||||
|
||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
|
||||
other=0.0)
|
||||
|
||||
k_start = k * BLOCK_SIZE_K
|
||||
offs_ks = k_start // group_k
|
||||
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
|
||||
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
|
||||
|
||||
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
if C.dtype.element_ty == tl.bfloat16:
|
||||
c = accumulator.to(tl.bfloat16)
|
||||
elif C.dtype.element_ty == tl.float16:
|
||||
c = accumulator.to(tl.float16)
|
||||
else:
|
||||
c = accumulator.to(tl.float32)
|
||||
|
||||
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
|
||||
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
|
||||
tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
def w8a8_block_fp8_matmul(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
block_size: List[int],
|
||||
output_dtype: torch.dtype = torch.float16,
|
||||
) -> torch.Tensor:
|
||||
"""This function performs matrix multiplication with block-wise
|
||||
quantization.
|
||||
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
|
||||
The output is returned in the specified `output_dtype`.
|
||||
Args:
|
||||
A: The input tensor, e.g., activation.
|
||||
B: The input tensor, e.g., weight.
|
||||
As: The per-token-group quantization scale for `A`.
|
||||
Bs: The per-block quantization scale for `B`.
|
||||
block_size: The block size for per-block quantization. It should
|
||||
be 2-dim, e.g., [128, 128].
|
||||
output_dtype: The dtype of the returned tensor.
|
||||
Returns:
|
||||
torch.Tensor: The result of matmul.
|
||||
"""
|
||||
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: use the 'scaled_matmul' kernel from the 'tmo' library when n and k are
|
||||
multiples of 128, falling back to the '_w8a8_block_fp8_matmul' triton kernel
|
||||
'''
|
||||
|
||||
assert len(block_size) == 2
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
|
||||
assert A.shape[-1] == B.shape[-1]
|
||||
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
|
||||
assert B.ndim == 2 and Bs.ndim == 2
|
||||
|
||||
if (B.shape[0] % 128 == 0) and (B.shape[1] % 128 == 0):
|
||||
C = mlu_ops.scaled_matmul(A, B, As, Bs, output_dtype, bias=None, c=None, act_mode="none",
|
||||
quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None)
|
||||
else:
|
||||
# NOTE(wulingchao): the underlying scaled_matmul kernel only supports n and k that are multiples of 128
|
||||
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
|
||||
M = A.numel() // A.shape[-1]
|
||||
|
||||
assert B.ndim == 2 and Bs.ndim == 2
|
||||
N, K = B.shape
|
||||
assert triton.cdiv(N, block_n) == Bs.shape[0]
|
||||
assert triton.cdiv(K, block_k) == Bs.shape[1]
|
||||
|
||||
C_shape = A.shape[:-1] + (N, )
|
||||
C = A.new_empty(C_shape, dtype=output_dtype)
|
||||
|
||||
# Default config
|
||||
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
|
||||
# BLOCK_SIZE_K must be divisible by block_size[1]
|
||||
config = {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": block_size[0],
|
||||
"BLOCK_SIZE_K": block_size[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 1,
|
||||
"num_stages": 1,
|
||||
}
|
||||
|
||||
def grid(META):
|
||||
return (TOTAL_CORE_NUM, )
|
||||
|
||||
_w8a8_block_fp8_matmul[grid](
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
As,
|
||||
Bs,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_n,
|
||||
block_k,
|
||||
A.stride(-2),
|
||||
A.stride(-1),
|
||||
B.stride(1),
|
||||
B.stride(0),
|
||||
C.stride(-2),
|
||||
C.stride(-1),
|
||||
As.stride(-2),
|
||||
As.stride(-1),
|
||||
Bs.stride(1),
|
||||
Bs.stride(0),
|
||||
**config,
|
||||
)
|
||||
|
||||
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
|
||||
return C
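# Illustrative sketch (not part of the original code): the shape contract of
# w8a8_block_fp8_matmul for a hypothetical block_size of [128, 128].
#
#     A : [M, K] fp8 activations,  As : [M, ceil(K / 128)]             float32 scales
#     B : [N, K] fp8 weights,      Bs : [ceil(N / 128), ceil(K / 128)] float32 scales
#     C = w8a8_block_fp8_matmul(A, B, As, Bs, [128, 128], output_dtype=torch.bfloat16)
#     # C : [M, N]; when both N and K are multiples of 128 the tmo scaled_matmul
#     # kernel is used, otherwise the triton fallback above is taken.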
|
||||
178
vllm_mlu/model_executor/layers/quantization/utils/w8a8_utils.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Callable
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, USE_ROWWISE_TORCH_SCALED_MM, cutlass_w8a8_scaled_mm,
|
||||
flashinfer_w8a8_scaled_mm, rocm_per_tensor_w8a8_scaled_mm,
|
||||
torch_per_tensor_w8a8_scaled_mm, torch_per_token_w8a8_scaled_mm,
|
||||
torch_channelwise_w8a8_scaled_mm)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
|
||||
def mlu_w8a8_scaled_mm(
|
||||
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
|
||||
output_shape: list, **kwargs
|
||||
) -> torch.Tensor:
|
||||
output = mlu_ops.scaled_matmul(
|
||||
qinput, # a
|
||||
weight, # b
|
||||
scale_a, # a_scale
|
||||
scale_b, # b_scale
|
||||
out_dtype, # output_dtype
|
||||
bias, # bias
|
||||
c=None, act_mode="none", quant_bit_size=8, alpha=1, beta=1, use_hp_active=False,
|
||||
a_quant_bit_size=8, a_calib=None, b_calib=None
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
|
||||
|
||||
def dispatch_w8a8_scaled_mm(
|
||||
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool,
|
||||
weight_per_channel: bool, activation_per_token: bool
|
||||
) -> Callable[..., torch.Tensor]:
|
||||
if per_tensor_weights and per_tensor_activations:
|
||||
if preferred_backend == "rocm":
|
||||
return rocm_per_tensor_w8a8_scaled_mm
|
||||
if preferred_backend == "flashinfer":
|
||||
return flashinfer_w8a8_scaled_mm
|
||||
if preferred_backend == "cutlass":
|
||||
return cutlass_w8a8_scaled_mm
|
||||
return torch_per_tensor_w8a8_scaled_mm
|
||||
|
||||
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
|
||||
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
|
||||
return cutlass_w8a8_scaled_mm
|
||||
|
||||
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
|
||||
if (
|
||||
not per_tensor_weights
|
||||
and not per_tensor_activations
|
||||
and USE_ROWWISE_TORCH_SCALED_MM
|
||||
):
|
||||
return torch_per_token_w8a8_scaled_mm
|
||||
# Normally, torch.scaled_mm supports per tensor weights + activations only
|
||||
# so fallback to naive if per channel or per token
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: dispatch to mlu_w8a8_scaled_mm
|
||||
'''
|
||||
if weight_per_channel and activation_per_token:
|
||||
return mlu_w8a8_scaled_mm
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
return torch_channelwise_w8a8_scaled_mm
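# Illustrative sketch (not part of the original code): how the MLU branch of the
# dispatcher above is reached. The argument values are hypothetical and assume
# USE_ROWWISE_TORCH_SCALED_MM is False on this platform.
#
#     func = dispatch_w8a8_scaled_mm(
#         preferred_backend="torch",      # anything other than cutlass/flashinfer/rocm
#         per_tensor_weights=False,
#         per_tensor_activations=False,
#         weight_per_channel=True,
#         activation_per_token=True,
#     )
#     # func is mlu_w8a8_scaled_mm, which forwards to mlu_ops.scaled_matmul.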
|
||||
|
||||
|
||||
def vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out_dtype: torch.dtype | None = None,
|
||||
input_scale: torch.Tensor | None = None,
|
||||
input_scale_ub: torch.Tensor | None = None,
|
||||
bias: torch.Tensor | None = None,
|
||||
weight_per_channel: bool = True,
|
||||
activation_per_token: bool = True,
|
||||
) -> torch.Tensor:
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_scale computed from x.
|
||||
# If static, layer.input_scale is scalar and x_scale is input_scale.
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: add mlu_fp8_supported
|
||||
'''
|
||||
self.mlu_fp8_supported = False
|
||||
if weight_per_channel and activation_per_token:
|
||||
self.mlu_fp8_supported = True
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
# View input as 2D matrix for fp8 methods
|
||||
input_2d = input.view(-1, input.shape[-1])
|
||||
output_shape = [*input.shape[:-1], weight.shape[1]]
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = input.dtype
|
||||
|
||||
if self.mlu_fp8_supported:
|
||||
'''
|
||||
=============================
|
||||
Modify by vllm_mlu
|
||||
=============================
|
||||
@brief: Add support for activation-per-token weight-per-channel quantization.
|
||||
'''
|
||||
qinput, x_scale = mlu_ops.scaled_quantize(
|
||||
input_2d,# x
|
||||
None, # scale
|
||||
None, # zero
|
||||
None, # scale_ub
|
||||
quant_type=torch.float8_e4m3fn,
|
||||
quant_mode='dynamic_per_token'
|
||||
)
|
||||
output_shape = [*input.shape[:-1], weight.shape[0]]
|
||||
'''
|
||||
==================
|
||||
End of MLU Hijack
|
||||
==================
|
||||
'''
|
||||
else:
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
if input.dtype != current_platform.fp8_dtype():
|
||||
qinput, x_scale = self.quant_fp8(
|
||||
input_2d,
|
||||
input_scale,
|
||||
input_scale_ub,
|
||||
)
|
||||
else:
|
||||
qinput, x_scale = input_2d, input_scale
|
||||
|
||||
# Must have dim() conditions
|
||||
# In the per-token quant scenario, when the number of tokens is 1,
|
||||
# the scale will only have 1 element.
|
||||
# Without checking the dim(),
|
||||
# we cannot distinguish between per-tensor and per-token quant.
|
||||
# Example:
|
||||
# When the number of tokens is 1, the per-token scale is [[1]],
|
||||
# while a per-tensor scale is [1] or ().
|
||||
per_tensor_weights = weight_scale.numel() == 1
|
||||
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
|
||||
|
||||
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
|
||||
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
|
||||
self.preferred_backend, per_tensor_weights, per_tensor_activations,
|
||||
weight_per_channel, activation_per_token)
|
||||
return w8a8_scaled_mm_func(
|
||||
qinput=qinput,
|
||||
weight=weight,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=x_scale,
|
||||
scale_b=weight_scale,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
|
||||
MluHijackObject.apply_hijack(
|
||||
Fp8LinearOp,
|
||||
Fp8LinearOp.apply,
|
||||
vllm__model_executor__layers__quantization__utils__w8a8_util__Fp8LinearOp__apply
|
||||
)
|
||||
150
vllm_mlu/model_executor/layers/quantization/weightonly.py
Executable file
@@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization import register_quantization_config
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# @register_quantization_config("weightonly")
|
||||
class WeightOnlyConfig(QuantizationConfig):
|
||||
"""Config class for WeightOnly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
quant_mode: str, # weight_only
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weight_bits = weight_bits
|
||||
self.quant_mode = quant_mode
|
||||
|
||||
if quant_mode == "WeightOnly" and (self.weight_bits != 8 and self.weight_bits != 4):
|
||||
raise ValueError(
|
||||
"Currently, only 8/4-bit weight quantization is supported for "
|
||||
f"weight_only, but got {self.weight_bits} bits.")
|
||||
self.pack_factor = 8 // self.weight_bits
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
|
||||
f"quant_mode={self.quant_mode})")
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "WeightOnly"
|
||||
|
||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
try:
|
||||
quant_mode = cls.get_from_keys(config, ["quant_mode"])
|
||||
except Exception:
|
||||
quant_mode = "WeightOnly"
|
||||
return cls(weight_bits, quant_mode)
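# Illustrative sketch (not part of the original code): a hypothetical
# quantize_config.json accepted by from_config() above.
#
#     {"bits": 4, "quant_mode": "WeightOnly"}
#
# "bits" is required; "quant_mode" falls back to "WeightOnly" when absent.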
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["WeightOnlyLinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return WeightOnlyLinearMethod(self)
|
||||
return None
|
||||
|
||||
def get_scaled_act_names(self) -> List[str]:
|
||||
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||||
|
||||
|
||||
class WeightOnlyLinearMethod(LinearMethodBase):
|
||||
"""Linear method for WeightOnly.
|
||||
|
||||
Args:
|
||||
quant_config: The WeightOnly quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: WeightOnlyConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> Dict[str, Any]:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if self.quant_config.quant_mode == "WeightOnly":
|
||||
scale_and_zero_input_dim = None
|
||||
if output_size != output_size_per_partition:
|
||||
scale_and_zero_input_dim = 0
|
||||
qweight = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
device="mlu",
|
||||
dtype=torch.int8,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(qweight, {
|
||||
"input_dim": 1,
|
||||
"output_dim": 0,
|
||||
})
|
||||
scales = Parameter(
|
||||
torch.empty(
|
||||
output_size_per_partition,
|
||||
device="mlu",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(scales, {
|
||||
"input_dim": scale_and_zero_input_dim,
|
||||
"output_dim": 0,
|
||||
})
|
||||
layer.register_parameter("qweight", qweight)
|
||||
set_weight_attrs(qweight, extra_weight_attrs)
|
||||
layer.register_parameter("scales", scales)
|
||||
set_weight_attrs(scales, extra_weight_attrs)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if layer.scales.dtype != torch.float:
|
||||
layer.scales = Parameter(layer.scales.to(torch.float), requires_grad=False)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
residual: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
x_shape = x.shape
|
||||
if len(x_shape) > 2:
|
||||
x = x.view(-1, x_shape[-1])
|
||||
out = mlu_ops.weight_only_quant_matmul(x,
|
||||
layer.qweight,
|
||||
layer.scales,
|
||||
None,
|
||||
bias,
|
||||
residual,
|
||||
"none",
|
||||
self.quant_config.weight_bits)
|
||||
if len(x_shape) > 2:
|
||||
out = out.view(*x_shape[:-1], out.shape[-1])
|
||||
return out
|
||||