Add support for Qwen3 MoE+GPTQ
This commit is contained in:
@@ -4,25 +4,43 @@
|
||||
import enum
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Any, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
import torch
|
||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
|
||||
from vllm.model_executor.layers.linear import LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
get_linear_quant_method)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter)
|
||||
get_linear_quant_method,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.transformers_utils.config import get_safetensors_params_metadata
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
else:
|
||||
QuantizationMethods = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
@@ -35,7 +53,10 @@ class GPTQConfig(QuantizationConfig):
|
||||
group_size: int,
|
||||
desc_act: bool,
|
||||
lm_head_quantized: bool,
|
||||
dynamic: dict[str, dict[str, Union[int, bool]]],
|
||||
dynamic: dict[str, dict[str, int | bool]],
|
||||
autoround_version: str = "",
|
||||
modules_in_block_to_quantize: list[str] | None = None,
|
||||
checkpoint_format: str = "",
|
||||
) -> None:
|
||||
# GPTQModel use `dynamic` config property to allow per module
|
||||
# quantization config so each module can be individually optimized.
|
||||
@@ -71,23 +92,44 @@ class GPTQConfig(QuantizationConfig):
|
||||
if self.weight_bits not in [2, 3, 4, 8]:
|
||||
raise ValueError(
|
||||
"Currently, only 2/3/4/8-bit weight quantization is "
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits.")
|
||||
f"supported for GPTQ, but got {self.weight_bits} bits."
|
||||
)
|
||||
# Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future.
|
||||
# For now, show a warning, since gptq_marlin will be used by default.
|
||||
if self.weight_bits == 4:
|
||||
logger.warning_once(
|
||||
"Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
|
||||
"Please switch to gptq_marlin or gptq_bitblas."
|
||||
)
|
||||
|
||||
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
|
||||
|
||||
# used to identify GPTQ model quantized by autoround
|
||||
self.autoround_version = autoround_version
|
||||
|
||||
# GPTQ v1 and v2 format deals with zero points differently.
|
||||
# Currently GPTQModel stores v1 format checkpoints by default,
|
||||
# but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
|
||||
self.checkpoint_format = checkpoint_format
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized}), "
|
||||
f"dynamic={self.dynamic}")
|
||||
return (
|
||||
f"GPTQConfig(weight_bits={self.weight_bits}, "
|
||||
f"group_size={self.group_size}, "
|
||||
f"desc_act={self.desc_act}), "
|
||||
f"lm_head_quantized={self.lm_head_quantized}, "
|
||||
f"dynamic={self.dynamic}, "
|
||||
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
|
||||
f"checkpoint_format={self.checkpoint_format})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq"
|
||||
|
||||
@classmethod
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half, torch.bfloat16]
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
@@ -106,18 +148,77 @@ class GPTQConfig(QuantizationConfig):
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
desc_act = cls.get_from_keys(config, ["desc_act"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
|
||||
default=False)
|
||||
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
|
||||
dynamic)
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
autoround_version = cls.get_from_keys_or(
|
||||
config, ["autoround_version"], default=""
|
||||
)
|
||||
modules_in_block_to_quantize = cls.get_from_keys_or(
|
||||
config, ["modules_in_block_to_quantize"], default=None
|
||||
)
|
||||
checkpoint_format = cls.get_from_keys_or(
|
||||
config, ["checkpoint_format"], default=""
|
||||
)
|
||||
return cls(
|
||||
weight_bits,
|
||||
group_size,
|
||||
desc_act,
|
||||
lm_head_quantized,
|
||||
dynamic,
|
||||
autoround_version,
|
||||
modules_in_block_to_quantize,
|
||||
checkpoint_format,
|
||||
)
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None:
|
||||
if isinstance(layer, FusedMoE):
|
||||
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
|
||||
from .moe_wna16 import MoeWNA16Config
|
||||
|
||||
print("Using MoeWNA16Config for GPTQ MoE layer quantization.")
|
||||
# TODO: maybe update this for GPTQv2 format checkpoints
|
||||
config = {
|
||||
"quant_method": "gptq",
|
||||
"bits": self.weight_bits,
|
||||
"group_size": self.group_size,
|
||||
"sym": True, # GPTQ typically uses symmetric quantization
|
||||
"lm_head": False,
|
||||
}
|
||||
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["GPTQLinearMethod"]:
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
if self.modules_in_block_to_quantize is not None:
|
||||
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
|
||||
self.modules_in_block_to_quantize
|
||||
)
|
||||
|
||||
def maybe_update_config(self, model_name: str, revision: str | None = None):
|
||||
if self.modules_in_block_to_quantize:
|
||||
if is_list_of(self.modules_in_block_to_quantize, list):
|
||||
# original modules_in_block_to_quantize: list[list[str]]
|
||||
# flatten original modules_in_block_to_quantize
|
||||
self.modules_in_block_to_quantize = [
|
||||
item
|
||||
for sublist in self.modules_in_block_to_quantize
|
||||
for item in sublist
|
||||
]
|
||||
return
|
||||
|
||||
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
metadata = get_safetensors_params_metadata(model_name, revision=revision)
|
||||
quant_layers: set[str] = {
|
||||
param_name.rsplit(".", 1)[0]
|
||||
for param_name, info in metadata.items()
|
||||
if (dtype := info.get("dtype", None))
|
||||
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
|
||||
}
|
||||
self.modules_in_block_to_quantize = list(quant_layers)
|
||||
|
||||
|
||||
class ExllamaState(Enum):
|
||||
|
||||
UNUSED = enum.auto()
|
||||
UNINITIALIZED = enum.auto()
|
||||
READY = enum.auto()
|
||||
@@ -133,6 +234,9 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
def __init__(self, quant_config: GPTQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
# GPTQ v1 and v2 format deals with zero points differently
|
||||
self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -149,14 +253,15 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
raise ValueError(
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
"tensor parallel size."
|
||||
)
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if (output_size_per_partition % self.quant_config.pack_factor.numerator
|
||||
!= 0):
|
||||
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size.")
|
||||
"tensor parallel size."
|
||||
)
|
||||
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
@@ -165,8 +270,10 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
exllama_state = ExllamaState.UNINITIALIZED
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1):
|
||||
if (
|
||||
input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1
|
||||
):
|
||||
# For act-order models, we cannot use Exllama for row parallel layer
|
||||
if self.quant_config.desc_act:
|
||||
exllama_state = ExllamaState.UNUSED
|
||||
@@ -185,56 +292,56 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader)
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
g_idx = RowvLLMParameter(data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
g_idx = RowvLLMParameter(
|
||||
data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
qzeros_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
"data": torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data":
|
||||
torch.empty(
|
||||
"data": torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader":
|
||||
weight_loader
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1,
|
||||
**weight_scale_args)
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(output_dim=1,
|
||||
input_dim=0,
|
||||
**weight_scale_args)
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args)
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
@@ -252,79 +359,23 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
|
||||
# exllama needs to shuffle the weight after the weight is loaded
|
||||
# here we do the shuffle on first forward pass
|
||||
if self.quant_config.group_size == 128 or self.quant_config.group_size == 64:
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.g_idx.data = torch.empty(
|
||||
(0,), dtype=torch.int, device=layer.g_idx.device
|
||||
)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
|
||||
if layer.scales.dtype != torch.bfloat16:
|
||||
perm_space = torch.empty(0)
|
||||
temp_space = torch.empty(0)
|
||||
if self.quant_config.weight_bits == 4:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space, temp_space,
|
||||
False)
|
||||
if self.quant_config.weight_bits == 8:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space, temp_space,
|
||||
False)
|
||||
else:
|
||||
if layer.exllama_state == ExllamaState.UNINITIALIZED:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty((0, ),
|
||||
dtype=torch.int,
|
||||
device=layer.g_idx.device)
|
||||
layer.exllama_state = ExllamaState.READY
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx,
|
||||
self.quant_config.weight_bits)
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
|
||||
|
||||
"""
|
||||
perm_space = torch.empty(0)
|
||||
if self.quant_config.weight_bits == 4:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space)
|
||||
if self.quant_config.weight_bits == 8:
|
||||
# warmup
|
||||
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
|
||||
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
self.quant_config.weight_bits,
|
||||
self.quant_config.group_size,
|
||||
perm_space)
|
||||
"""
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
perm_space = torch.empty(0)
|
||||
@@ -334,11 +385,12 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
if self.quant_config.desc_act:
|
||||
perm_space = torch.empty(reshaped_x.shape[0], reshaped_x.shape[1],
|
||||
dtype=torch.float16, device="cuda")
|
||||
|
||||
|
||||
if reshaped_x.dtype == torch.bfloat16:
|
||||
temp_space = torch.zeros(reshaped_x.shape[0], layer.qweight.shape[1],
|
||||
dtype=torch.float32, device="cuda")
|
||||
|
||||
# GPTQ v1 and v2 format checkpoints deals with zero points differently,
|
||||
# and require different gemm kernels.
|
||||
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
|
||||
layer.scales, layer.g_idx,
|
||||
layer.exllama_state == ExllamaState.READY,
|
||||
@@ -348,4 +400,4 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
True if reshaped_x.dtype == torch.bfloat16 else False)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
return output.reshape(out_shape)
|
||||
@@ -298,6 +298,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: torch.Tensor | None = None,
|
||||
logical_to_physical_map: torch.Tensor | None = None,
|
||||
logical_replica_count: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
|
||||
Reference in New Issue
Block a user