# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import fnmatch
from typing import TYPE_CHECKING, Any, Optional, cast

import torch

from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
    QuarkMoEMethod,
)
from vllm.model_executor.layers.quantization.quark.schemes import (
    QuarkOCP_MX,
    QuarkScheme,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
from vllm.model_executor.layers.quantization.quark.utils import (
    deep_compare,
    should_ignore_layer,
)
from vllm.platforms import current_platform

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

__all__ = ["QuarkLinearMethod"]

logger = init_logger(__name__)


class QuarkConfig(QuantizationConfig):
    def __init__(
        self,
        quant_config: dict[str, Any],
        kv_cache_group: list[str] | None = None,
        kv_cache_config: dict[str, Any] | None = None,
        pack_method: str = "reorder",
    ):
        super().__init__()
        if kv_cache_group is None:
            kv_cache_group = []
        self.quant_config = quant_config
        self.kv_cache_group = kv_cache_group
        self.kv_cache_config = kv_cache_config
        self.pack_method = pack_method

    def get_linear_method(self) -> "QuarkLinearMethod":
        return QuarkLinearMethod(self)

    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> QuantizationMethods:
        return "quark"

    def apply_vllm_mapper(  # noqa: B027
        self, hf_to_vllm_mapper: "WeightsMapper"
    ):
        """
        Interface for models to update module names referenced in
        quantization configs in order to reflect the vllm model structure

        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
            structure of the qconfig) to vllm model structure
        """
        quant_config_with_hf_to_vllm_mapper = {}

        for k, v in self.quant_config.items():
            if isinstance(v, list):
                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_list(v)
            elif isinstance(v, dict):
                quant_config_with_hf_to_vllm_mapper[k] = hf_to_vllm_mapper.apply_dict(v)
            else:
                if isinstance(v, str):
                    mapped_v_list = hf_to_vllm_mapper.apply_list([v])
                    if mapped_v_list:
                        quant_config_with_hf_to_vllm_mapper[k] = mapped_v_list[0]
                else:
                    quant_config_with_hf_to_vllm_mapper[k] = v

        self.quant_config = quant_config_with_hf_to_vllm_mapper

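    # Illustrative example for apply_vllm_mapper (names assumed, not taken
    # from a real checkpoint): with a WeightsMapper that renames the
    # "transformer." prefix to "model.", an exclude entry such as
    # "transformer.*.lm_head" in the Quark config would be rewritten to
    # "model.*.lm_head" so that it matches vLLM module names later on.
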
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        # Check if the layer is skipped for quantization.
        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
        if should_ignore_layer(
            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            layer.scheme = scheme
            return QuarkLinearMethod(self)
        if isinstance(layer, Attention):
            return QuarkKVCacheMethod(self)

        if isinstance(layer, FusedMoE):
            return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
        return None

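    # Illustrative example for get_quant_method (pattern assumed): a Quark
    # checkpoint that keeps the LM head unquantized would list a pattern like
    # "lm_head" under quant_config["exclude"], in which case the method above
    # returns UnquantizedLinearMethod for that layer.
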
    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
        export_config = config.get("export")
        if export_config is None:
            raise ValueError(
                "The export key should be included in "
                "the configuration of a Quark quantized model"
            )
        kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
        pack_method = cast(str, export_config.get("pack_method"))

        # In a Quark-exported model, the kv_cache quantization configuration
        # is stored in layer_quant_config. First check whether kv_cache_group
        # is set, then check whether layer_quant_config contains a
        # quantization configuration that matches the kv_cache entries.
        if len(kv_cache_group) == 0:
            kv_cache_config = None
        else:
            kv_cache_set = set(kv_cache_group)
            layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
            layer_quant_names = list(layer_quant_config.keys())
            layer_quant_set = set(layer_quant_names)

            if not (
                kv_cache_set.issubset(layer_quant_set)
                or any(
                    fnmatch.fnmatchcase(layer_quant, pat)
                    for layer_quant in list(layer_quant_set)
                    for pat in list(kv_cache_set)
                )
            ):
                raise ValueError(
                    "The Quark quantized model has the "
                    "kv_cache_group parameter setting, "
                    "but no kv_cache quantization settings "
                    "were found in the quantization "
                    "configuration."
                )

            q_configs = [
                quant_cfg
                for name, quant_cfg in layer_quant_config.items()
                if any(fnmatch.fnmatchcase(name, pattern) for pattern in kv_cache_group)
            ]

            if not all(
                deep_compare(q_config["output_tensors"], q_configs[0]["output_tensors"])
                for q_config in q_configs
            ):
                raise ValueError(
                    "The quantization method used for kv_cache should "
                    "be the same, but the quantization method for the "
                    "kv_cache layer in the config is different."
                )
            kv_cache_config = q_configs[0].get("output_tensors")
            if kv_cache_config is None:
                raise ValueError("The kv_cache quantization configuration is empty.")

            # Since we have already set kv_cache quantization configurations,
            # we will remove the quantization configuration for the
            # output_tensors corresponding to the kv_cache layer.
            for q_config in q_configs:
                q_config["output_tensors"] = None

            # In case q_proj output is also quantized, remove the configuration
            # to keep qkv consistency.
            q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
            if q_proj_q_config is not None:
                q_proj_q_config["output_tensors"] = None

        return cls(
            quant_config=config,
            kv_cache_group=kv_cache_group,
            kv_cache_config=kv_cache_config,
            pack_method=pack_method,
        )

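    # Illustrative sketch (field values assumed) of the config keys consumed
    # by from_config() above:
    #   {
    #       "export": {"kv_cache_group": ["*k_proj", "*v_proj"],
    #                  "pack_method": "reorder"},
    #       "layer_quant_config": {"*k_proj": {"output_tensors": {...}}, ...},
    #       "global_quant_config": {...},
    #       "exclude": [...],
    #   }
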
    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
        capability_tuple = current_platform.get_device_capability()

        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for "
                    f"the current GPU. Min capability: {min_capability}. "
                    f"Current capability: {capability}."
                )
            return supported
        else:
            return False

    def _is_fp8_w8a8(
        self,
        weight_quant: dict[str, Any] | None,
        input_quant: dict[str, Any] | None,
    ) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported.
        is_fp8_dtype = (
            weight_quant.get("dtype") == "fp8_e4m3"
            and input_quant.get("dtype") == "fp8_e4m3"
        )
        is_static_weight = not weight_quant.get("is_dynamic")
        is_per_tensor_or_channel_weight = weight_quant.get("qscheme") in [
            "per_tensor",
            "per_channel",
        ]

        if not (is_fp8_dtype and is_static_weight and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.get("is_dynamic"):
            return True

        # Confirm activation scheme is supported.
        is_per_tensor_activation = input_quant.get("qscheme") == "per_tensor"
        return is_per_tensor_activation

    def _is_static_tensor_w8a8(
        self,
        weight_quant: dict[str, Any] | None,
        input_quant: dict[str, Any] | None,
    ) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            return False

        is_int8_dtype = (
            weight_quant.get("dtype") == "int8" and input_quant.get("dtype") == "int8"
        )

        is_tensor = (
            weight_quant.get("qscheme") in ["per_tensor", "per_channel"]
            and input_quant.get("qscheme") == "per_tensor"
        )

        is_static = not weight_quant.get("is_dynamic") and not input_quant.get(
            "is_dynamic"
        )

        is_weight_symmetric = weight_quant.get("symmetric") is True

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

    def _is_ocp_mx(
        self,
        weight_quant: dict[str, Any] | None,
        input_quant: dict[str, Any] | None,
    ) -> bool:
        # Confirm weights and input quantized.
        if weight_quant is None or input_quant is None:
            logger.debug(
                "Quark model is not in OCP MX format: "
                "weight_quant or input_quant not set"
            )
            return False

        # Input and weight qscheme needs to be per group.
        if (
            weight_quant.get("qscheme") != "per_group"
            or input_quant.get("qscheme") != "per_group"
        ):
            logger.debug("Quark model is not in OCP MX format: not per_group")
            return False

        # Input and weight group size needs to be 32.
        if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
            logger.debug("Quark model is not in OCP MX format: not group_size=32")
            return False

        # Activation and weight scales need to be in e8m0 format.
        if (
            weight_quant.get("scale_format") != "e8m0"
            or input_quant.get("scale_format") != "e8m0"
        ):
            logger.debug("Quark model is not in OCP MX format: not scale_format e8m0")
            return False

        # Input and weight dtypes need to be any of fp4,
        # fp6_e3m2 or fp6_e2m3, possibly mixed.
        if weight_quant.get("dtype") not in {
            "fp4",
            "fp6_e3m2",
            "fp6_e2m3",
        } or input_quant.get("dtype") not in {"fp4", "fp6_e3m2", "fp6_e2m3"}:
            logger.debug(
                "Quark model is not in OCP MX format: dtype not fp4, fp6_e3m2, fp6_e2m3"
            )
            return False

        return True

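    # Illustrative example (values assumed): a weight/input config such as
    #   {"dtype": "fp4", "qscheme": "per_group", "group_size": 32,
    #    "scale_format": "e8m0", ...}
    # for both tensors would pass the _is_ocp_mx() checks above.
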
    def _find_matched_config(
        self, layer_name: str, module: torch.nn.Module
    ) -> dict[str, Any]:
        proj_name = layer_name.split(".")[-1]
        if proj_name in self.packed_modules_mapping:
            shard_proj_names = self.packed_modules_mapping[proj_name]

            # Convert fused_name --> [shard_names]
            shard_names = [
                layer_name.replace(proj_name, shard_proj_name)
                for shard_proj_name in shard_proj_names
            ]
            shard_configs = [
                self._find_matched_config(shard_name, module)
                for shard_name in shard_names
            ]
            if not all(
                deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
            ):
                raise ValueError(
                    f"Found a different quantization configuration for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme."
                )
            return shard_configs[0]
        else:
            layer_quant_config = cast(
                dict[str, Any], self.quant_config.get("layer_quant_config")
            )

            def _matches_pattern(layer_name, pattern):
                if "*" not in pattern:
                    return layer_name in pattern
                return fnmatch.fnmatch(layer_name, pattern)

            for name_pattern, config in layer_quant_config.items():
                if _matches_pattern(layer_name, name_pattern):
                    return config

            layer_type = cast(str, type(module))
            layer_type_quant_config = cast(
                dict[str, Any], self.quant_config.get("layer_type_quant_config")
            )
            if layer_type in layer_type_quant_config:
                return layer_type_quant_config[layer_type]

            global_quant_config = cast(
                dict[str, Any], self.quant_config.get("global_quant_config")
            )
            return global_quant_config

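    # Illustrative example (module names assumed): with a packed_modules_mapping
    # entry like {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}, a fused layer
    # name "model.layers.0.self_attn.qkv_proj" is resolved above by looking up
    # the configs of the three unfused shard names and requiring them to match.
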
    def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
        if config.get("output_tensors") or config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with output_tensors "
                "and bias quantized are not supported"
            )
        weight_config = cast(dict[str, Any], config.get("weight"))
        input_config = cast(dict[str, Any], config.get("input_tensors"))

        if self._is_fp8_w8a8(weight_config, input_config):
            is_fp8_w8a8_supported = self._check_scheme_supported(
                QuarkW8A8Fp8.get_min_capability(), error=False
            )
            if is_fp8_w8a8_supported:
                return QuarkW8A8Fp8(weight_config, input_config)
        elif self._is_static_tensor_w8a8(weight_config, input_config):
            weight_qscheme = cast(str, weight_config.get("qscheme"))
            return QuarkW8A8Int8(
                qscheme=weight_qscheme,
                is_static_input_scheme=True,
                input_symmetric=input_config.get("symmetric"),
            )
        elif self._is_ocp_mx(weight_config, input_config):
            return QuarkOCP_MX(weight_config, input_config)

        raise NotImplementedError(
            "No quark compatible scheme was found. "
            f"Weight config: {weight_config}, "
            f"Input config: {input_config}"
        )

    def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
        layer_quant_config = self._find_matched_config(layer_name, layer)

        # Find the quant_scheme
        scheme = self._get_scheme_from_config(layer_quant_config)
        # Raise error if device does not support the scheme
        # (e.g. fp8 needs Ada Lovelace)
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    def get_cache_scale(self, name: str) -> str | None:
        """
        Check whether the param name matches the format for k/v cache scales
        in quark. If this is the case, return its equivalent param name
        expected by vLLM.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")

        # If no matches, return None
        return None

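    # Example (layer prefix assumed for illustration):
    # get_cache_scale("model.layers.0.self_attn.k_proj.output_scale")
    # returns "model.layers.0.self_attn.attn.k_scale".
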

class QuarkLinearMethod(LinearMethodBase):
    def __init__(self, quantization_config: QuarkConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Use the QuarkScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """
        Use the output of create_weights and the QuarkScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details.
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")

        return scheme.apply_weights(layer, x, bias=bias)


class QuarkKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from quark checkpoints.
    """

    def __init__(self, quant_config: QuarkConfig):
        self.validate_kv_cache_config(quant_config.kv_cache_config)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_config(kv_cache_config: dict[str, Any] | None):
        """
        Validator for the kv cache configuration. Useful for controlling the
        kv cache quantization schemes that are supported in vLLM.
        :param kv_cache_config: the quark kv cache scheme
        """
        if kv_cache_config is None:
            return

        dtype = kv_cache_config.get("dtype")
        if dtype != "fp8_e4m3":
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                f"dtype=fp8_e4m3, however received {dtype}"
            )

        qscheme = kv_cache_config.get("qscheme")
        if qscheme != "per_tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for quark KV cache. "
                f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
            )
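
    # Illustrative example (values assumed): a kv_cache_config such as
    #   {"dtype": "fp8_e4m3", "qscheme": "per_tensor"}
    # passes validate_kv_cache_config() above; any other dtype or qscheme
    # raises NotImplementedError.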