Add support for Qwen3 MoE+GPTQ

This commit is contained in:
2025-11-15 20:14:45 +08:00
parent b296c44ae0
commit 8152e24cb2
35 changed files with 6468 additions and 574 deletions

View File

@@ -4,25 +4,43 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter)
get_linear_quant_method,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
)
from vllm.transformers_utils.config import get_safetensors_params_metadata
from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
else:
QuantizationMethods = str
logger = init_logger(__name__)
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
@@ -35,7 +53,10 @@ class GPTQConfig(QuantizationConfig):
group_size: int,
desc_act: bool,
lm_head_quantized: bool,
dynamic: dict[str, dict[str, Union[int, bool]]],
dynamic: dict[str, dict[str, int | bool]],
autoround_version: str = "",
modules_in_block_to_quantize: list[str] | None = None,
checkpoint_format: str = "",
) -> None:
# GPTQModel use `dynamic` config property to allow per module
# quantization config so each module can be individually optimized.
@@ -71,23 +92,44 @@ class GPTQConfig(QuantizationConfig):
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")
f"supported for GPTQ, but got {self.weight_bits} bits."
)
# Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future.
# For now, show a warning, since gptq_marlin will be used by default.
if self.weight_bits == 4:
logger.warning_once(
"Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
"Please switch to gptq_marlin or gptq_bitblas."
)
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version
# GPTQ v1 and v2 format deals with zero points differently.
# Currently GPTQModel stores v1 format checkpoints by default,
# but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
self.checkpoint_format = checkpoint_format
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}), "
f"dynamic={self.dynamic}")
return (
f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
f"checkpoint_format={self.checkpoint_format})"
)
@classmethod
def get_name(cls) -> QuantizationMethods:
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
return [torch.half]
@classmethod
# Need to figure it out
@@ -106,18 +148,77 @@ class GPTQConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized,
dynamic)
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
autoround_version = cls.get_from_keys_or(
config, ["autoround_version"], default=""
)
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None
)
checkpoint_format = cls.get_from_keys_or(
config, ["checkpoint_format"], default=""
)
return cls(
weight_bits,
group_size,
desc_act,
lm_head_quantized,
dynamic,
autoround_version,
modules_in_block_to_quantize,
checkpoint_format,
)
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Union["GPTQLinearMethod", "QuantizeMethodBase"] | None:
if isinstance(layer, FusedMoE):
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
from .moe_wna16 import MoeWNA16Config
print("Using MoeWNA16Config for GPTQ MoE layer quantization.")
# TODO: maybe update this for GPTQv2 format checkpoints
config = {
"quant_method": "gptq",
"bits": self.weight_bits,
"group_size": self.group_size,
"sym": True, # GPTQ typically uses symmetric quantization
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQLinearMethod"]:
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_in_block_to_quantize:
if is_list_of(self.modules_in_block_to_quantize, list):
# original modules_in_block_to_quantize: list[list[str]]
# flatten original modules_in_block_to_quantize
self.modules_in_block_to_quantize = [
item
for sublist in self.modules_in_block_to_quantize
for item in sublist
]
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name, revision=revision)
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get("dtype", None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_in_block_to_quantize = list(quant_layers)
class ExllamaState(Enum):
UNUSED = enum.auto()
UNINITIALIZED = enum.auto()
READY = enum.auto()
@@ -133,6 +234,9 @@ class GPTQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config
# GPTQ v1 and v2 format deals with zero points differently
self.use_v2_format = quant_config.checkpoint_format == "gptq_v2"
def create_weights(
self,
layer: torch.nn.Module,
@@ -149,14 +253,15 @@ class GPTQLinearMethod(LinearMethodBase):
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
"tensor parallel size."
)
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
"tensor parallel size."
)
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
@@ -165,8 +270,10 @@ class GPTQLinearMethod(LinearMethodBase):
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
if (
input_size != input_size_per_partition
and self.quant_config.group_size != -1
):
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
@@ -185,56 +292,56 @@ class GPTQLinearMethod(LinearMethodBase):
output_dim=1,
packed_dim=0,
packed_factor=self.quant_config.pack_factor,
weight_loader=weight_loader)
weight_loader=weight_loader,
)
g_idx = RowvLLMParameter(data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
g_idx = RowvLLMParameter(
data=torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader,
)
qzeros_args = {
"data":
torch.empty(
"data": torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
"weight_loader":
weight_loader
"weight_loader": weight_loader,
}
weight_scale_args = {
"data":
torch.empty(
"data": torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
"weight_loader": weight_loader,
}
if scale_and_zero_input_dim is None:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
qzeros = PackedColumnParameter(
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
**qzeros_args,
)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)
qzeros = PackedvLLMParameter(
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
**qzeros_args)
**qzeros_args,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
@@ -252,79 +359,23 @@ class GPTQLinearMethod(LinearMethodBase):
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if self.quant_config.group_size == 128 or self.quant_config.group_size == 64:
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.g_idx.data = torch.empty(
(0,), dtype=torch.int, device=layer.g_idx.device
)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
if layer.scales.dtype != torch.bfloat16:
perm_space = torch.empty(0)
temp_space = torch.empty(0)
if self.quant_config.weight_bits == 4:
# warmup
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
self.quant_config.group_size,
perm_space, temp_space,
False)
if self.quant_config.weight_bits == 8:
# warmup
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
self.quant_config.group_size,
perm_space, temp_space,
False)
else:
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
dtype=torch.int,
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
"""
perm_space = torch.empty(0)
if self.quant_config.weight_bits == 4:
# warmup
reshaped_x = torch.randn(1, layer.qweight.shape[0]*8, dtype=layer.scales.dtype, device="cuda")
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
self.quant_config.group_size,
perm_space)
if self.quant_config.weight_bits == 8:
# warmup
reshaped_x = torch.randn(1, layer.qweight.shape[0]*4, dtype=layer.scales.dtype, device="cuda")
_ = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits,
self.quant_config.group_size,
perm_space)
"""
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1], )
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
reshaped_x = x.reshape(-1, x.shape[-1])
perm_space = torch.empty(0)
@@ -334,11 +385,12 @@ class GPTQLinearMethod(LinearMethodBase):
if self.quant_config.desc_act:
perm_space = torch.empty(reshaped_x.shape[0], reshaped_x.shape[1],
dtype=torch.float16, device="cuda")
if reshaped_x.dtype == torch.bfloat16:
temp_space = torch.zeros(reshaped_x.shape[0], layer.qweight.shape[1],
dtype=torch.float32, device="cuda")
# GPTQ v1 and v2 format checkpoints deals with zero points differently,
# and require different gemm kernels.
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
@@ -348,4 +400,4 @@ class GPTQLinearMethod(LinearMethodBase):
True if reshaped_x.dtype == torch.bfloat16 else False)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
return output.reshape(out_shape)

View File

@@ -298,6 +298,10 @@ class MoeWNA16Method(FusedMoEMethodBase):
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
expert_load_view: torch.Tensor | None = None,
logical_to_physical_map: torch.Tensor | None = None,
logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
assert activation == "silu", "Only SiLU activation is supported."