This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

View File

@@ -0,0 +1,914 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
import torch
from compressed_tensors.config import (
CompressionFormat,
SparsityCompressionConfig,
SparsityStructure,
)
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
LinearBase,
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS,
WNA16_SUPPORTED_BITS,
CompressedTensors24,
CompressedTensorsScheme,
CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Fp8,
CompressedTensorsW4A8Int,
CompressedTensorsW4A16Fp4,
CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8,
CompressedTensorsWNA16,
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod,
get_linear_transform_schemes,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
find_matched_target,
is_activation_quantization_format,
should_ignore_layer,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
)
from vllm.platforms import current_platform
if TYPE_CHECKING:
    from vllm.model_executor.models.utils import WeightsMapper

logger = init_logger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

# Key under which the sparsity section lives inside `quantization_config`.
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
# target name/regex -> {"weights"/"input_activations" -> QuantizationArgs}
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, QuantizationArgs] | None]
class CompressedTensorsConfig(QuantizationConfig):
    """Quantization config for compressed-tensors checkpoints.

    Holds per-target quantization schemes, per-target sparsity schemes,
    an optional KV-cache scheme and an optional transform config.
    """

    def __init__(
        self,
        target_scheme_map: dict[str, Any],
        ignore: list[str],
        quant_format: str,
        sparsity_scheme_map: dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: list[str],
        kv_cache_scheme: dict[str, Any] | None = None,
        config: dict[str, Any] | None = None,
        transform_config: dict[str, Any] | None = None,
    ):
        """
        :param target_scheme_map: map from target layer name/regex to its
            weight / input-activation QuantizationArgs
        :param ignore: layer names excluded from quantization
        :param quant_format: global compression format string
        :param sparsity_scheme_map: map from target to its sparsity config
        :param sparsity_ignore_list: layer names excluded from sparsity
        :param kv_cache_scheme: optional kv-cache quantization scheme dict
        :param config: the raw `quantization_config` dict, kept for later use
            (e.g. sparse decompression)
        :param transform_config: optional transform config dict; validated
            into a TransformConfig pydantic model when present
        """
        super().__init__()
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.kv_cache_scheme = kv_cache_scheme
        self.sparsity_scheme_map = sparsity_scheme_map
        self.sparsity_ignore_list = sparsity_ignore_list
        self.config = config

        if transform_config:
            self.transform_config = TransformConfig.model_validate(transform_config)
        else:
            self.transform_config = None
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
    """Return a linear method bound to this quantization config."""
    return CompressedTensorsLinearMethod(quantization_config=self)
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
    """Activation dtypes this quantization method can run with."""
    # The signature takes `cls` but the original lacked @classmethod, so it
    # only worked when invoked through an instance (where `cls` silently
    # bound to `self`). Adding the decorator also allows class-level calls
    # and stays backward compatible for instance calls.
    return [torch.float32, torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
    """Minimum GPU compute capability, encoded as major * 10 + minor (7.0)."""
    return 70
def get_name(self) -> QuantizationMethods:
    """Identifier of this quantization method as registered in vLLM."""
    return "compressed-tensors"
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
    """Rewrite HF layer names stored in this config to vLLM names."""
    mapper = hf_to_vllm_mapper
    self.target_scheme_map = mapper.apply_dict(self.target_scheme_map)
    self.ignore = mapper.apply_list(self.ignore)
    self.sparsity_scheme_map = mapper.apply_dict(self.sparsity_scheme_map)
    self.sparsity_ignore_list = mapper.apply_list(self.sparsity_ignore_list)
    if self.kv_cache_scheme is not None:
        self.kv_cache_scheme = mapper.apply_dict(self.kv_cache_scheme)
def get_quant_method(
    self,
    layer: torch.nn.Module,
    prefix: str,
) -> Optional["QuantizeMethodBase"]:
    """Select the quantize method for ``layer``.

    Linear layers get a CompressedTensorsLinearMethod (or an unquantized
    fallback), optionally wrapped by a transform method; attention layers
    get the KV-cache method; fused-MoE layers get a MoE method. Returns
    None for any other layer type.
    """
    from vllm.attention.layer import Attention  # Avoid circular import

    if isinstance(layer, LinearBase):
        # collect schemes
        quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
        input_tfms, output_tfms = get_linear_transform_schemes(
            layer, prefix, self.transform_config, self.packed_modules_mapping
        )

        # choose quantization method; get_scheme returning None means
        # the layer is ignored / unquantized
        quant_method: LinearMethodBase = UnquantizedLinearMethod()
        if quant_scheme is not None:
            layer.scheme = quant_scheme
            quant_method = CompressedTensorsLinearMethod(self)

        # choose transform method: wrap the quant method whenever any
        # input/output transforms apply to this layer
        if any((input_tfms, output_tfms)):
            return CompressedTensorsLinearTransformMethod.from_schemes(
                quant_method, quant_scheme, input_tfms, output_tfms
            )
        else:
            return quant_method

    if isinstance(layer, Attention):
        return CompressedTensorsKVCacheMethod(self)
    if isinstance(layer, FusedMoE):
        return CompressedTensorsMoEMethod.get_moe_method(self, layer)
    return None
@classmethod
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
    """Build a config from the HF `quantization_config` dictionary."""
    ignore: list[str] = cast(list[str], config.get("ignore", []))
    quant_format = cast(str, config.get("format"))
    target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
    sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
        config=config
    )
    transform_config = config.get("transform_config")

    # The raw dict is kept as `config` so downstream consumers (e.g.
    # sparse decompression) can re-read it.
    return cls(
        target_scheme_map=target_scheme_map,
        ignore=ignore,
        quant_format=quant_format,
        sparsity_scheme_map=sparsity_scheme_map,
        sparsity_ignore_list=sparsity_ignore_list,
        config=config,
        transform_config=transform_config,
    )
@classmethod
def _parse_sparsity_config(
    cls, config: dict[str, Any]
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
    """
    :param config: The `quantization_config` dictionary from config.json
    :return: A tuple with two elements
        1. A dictionary mapping target layer names to their corresponding
            sparsity_config
        2. A list of layer names to ignore for sparsity
    """
    raw_sparsity = config.get(SPARSITY_CONFIG_NAME)
    if not raw_sparsity:
        return {}, []

    parsed = SparsityCompressionConfig.model_validate(raw_sparsity)
    # Every listed target shares the same (single) sparsity config object.
    targets = parsed.targets or []
    scheme_map: dict[str, SparsityCompressionConfig] = {
        target: parsed for target in targets
    }
    return scheme_map, parsed.ignore or []
@classmethod
def _quantization_scheme_map_from_config(
    cls, config: dict[str, Any]
) -> QUANTIZATION_SCHEME_MAP_TYPE:
    """Parse `config_groups` into a per-target scheme map.

    :param config: The `quantization_config` dictionary from config.json
    :return: A dictionary mapping target layer names to their corresponding
        quantization_args for weights and input activations
    """
    target_scheme_map: dict[str, Any] = dict()
    quant_format = cast(str, config.get("format"))

    # The quant_config has multiple config_groups, each containing
    # an input_activations key with details about how the activations are
    # quantized, a weights key indicating how the weights are quantized,
    # and a list of targets under the `targets` key, dictating which
    # layers are impacted by the quantization details. The quantization
    # details follow the structure defined by the QuantizationArgs
    # pydantic model, which is used to verify the structure of the
    # quant_config and also store the details for later use.
    config_groups = config.get("config_groups", dict())
    for _, quant_config in config_groups.items():
        targets = quant_config.get("targets")
        for target in targets:
            target_scheme_map[target] = {}
            target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
                quant_config.get("weights")
            )
            # Filled in below only when activation quant actually applies.
            target_scheme_map[target]["input_activations"] = None
            target_scheme_map[target]["format"] = quant_config.get("format")
            format = target_scheme_map[target].get("format")
            # If no per-config format defined, use global format in config
            act_quant_format = (
                is_activation_quantization_format(format)
                if format is not None
                else is_activation_quantization_format(quant_format)
            )
            # TODO(czhu): w4a8fp8 is in packed-quantized format
            # but needs input activation quantization
            input_activations = quant_config.get("input_activations")
            if act_quant_format or input_activations:
                # The only case where we have activation quant supported
                # but no input_activations provided in the config
                # should be w8a16fp8 w8a16fp8 can also run for cases where
                # there is an input_quant but it is ignored
                if not input_activations:
                    assert (
                        target_scheme_map[target]["weights"].type
                        == QuantizationType.FLOAT
                    )
                else:
                    target_scheme_map[target]["input_activations"] = (
                        QuantizationArgs.model_validate(
                            quant_config.get("input_activations")
                        )
                    )
    return target_scheme_map
@classmethod
def get_config_filenames(cls) -> list[str]:
    """No extra config files are required beyond the HF model config."""
    return []
def _check_scheme_supported(
    self, min_capability: int, error: bool = True, match_exact: bool = False
) -> bool:
    """Check whether the current device satisfies ``min_capability``.

    :param min_capability: required capability as major * 10 + minor
    :param error: raise RuntimeError when unsupported instead of returning
    :param match_exact: require capability == min_capability rather than >=
    :return: True when supported; False when unsupported (with error=False)
        or when the platform reports no device capability
    """
    capability_tuple = current_platform.get_device_capability()
    if capability_tuple is None:
        # e.g. platforms without a reported device capability
        return False

    capability = capability_tuple.to_int()
    if match_exact:
        supported = capability == min_capability
        requirement = f"Required capability: {min_capability}."
    else:
        supported = capability >= min_capability
        requirement = f"Min capability: {min_capability}."
    if error and not supported:
        # BUGFIX: the original passed several comma-separated strings to
        # RuntimeError, so the message rendered as a tuple of fragments.
        # Join into a single message string instead.
        raise RuntimeError(
            "Quantization scheme is not supported for the current GPU. "
            f"{requirement} Current capability: {capability}."
        )
    return supported
def _is_fp4a4_nvfp4(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
):
    """True when both weights and activations are symmetric 4-bit float
    with TENSOR_GROUP strategy and group size 16 (NVFP4)."""
    if weight_quant is None or input_quant is None:
        return False

    tensor_group = QuantizationStrategy.TENSOR_GROUP.value
    return all(
        (
            weight_quant.strategy == tensor_group,
            input_quant.strategy == tensor_group,
            weight_quant.type == QuantizationType.FLOAT,
            input_quant.type == QuantizationType.FLOAT,
            weight_quant.num_bits == 4,
            input_quant.num_bits == 4,
            weight_quant.group_size == 16,
            input_quant.group_size == 16,
            weight_quant.symmetric,
            input_quant.symmetric,
        )
    )
def _is_fp4a16_nvfp4(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
):
    """True for weight-only symmetric 4-bit float weights with
    TENSOR_GROUP strategy and group size 16 (NVFP4 weights, 16-bit acts)."""
    # Weight-only: weights quantized, activations untouched. Returning
    # early here also fixes a latent bug: the original evaluated
    # weight_quant.strategy even when weight_quant was None, raising
    # AttributeError instead of returning False.
    is_weight_only = weight_quant is not None and input_quant is None
    if not is_weight_only:
        return False

    is_tensor_group_quant = (
        weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
    )
    is_symmetric = weight_quant.symmetric
    is_group_size_16 = weight_quant.group_size == 16
    is_float_type = weight_quant.type == QuantizationType.FLOAT
    is_4_bits = weight_quant.num_bits == 4
    return (
        is_tensor_group_quant
        and is_float_type
        and is_4_bits
        and is_group_size_16
        and is_symmetric
    )
def _is_static_tensor_w8a8(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True for static per-tensor 8-bit activations with symmetric
    channel/group-wise 8-bit weights."""
    both_8_bit = weight_quant.num_bits == 8 and input_quant.num_bits == 8
    weight_strategy_ok = weight_quant.strategy in (
        QuantizationStrategy.CHANNEL.value,
        QuantizationStrategy.GROUP.value,
    )
    tensor_input = input_quant.strategy == QuantizationStrategy.TENSOR.value
    fully_static = not (weight_quant.dynamic or input_quant.dynamic)
    # Both symmetric and asymmetric input quantization supported.
    # Only symmetric weight quantization supported.
    return (
        both_8_bit
        and weight_strategy_ok
        and tensor_input
        and weight_quant.symmetric
        and fully_static
    )
def _is_dynamic_token_w8a8(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True for dynamic per-token 8-bit activations with symmetric
    channel/group-wise 8-bit weights."""
    both_8_bit = weight_quant.num_bits == 8 and input_quant.num_bits == 8
    weight_strategy_ok = weight_quant.strategy in (
        QuantizationStrategy.CHANNEL.value,
        QuantizationStrategy.GROUP.value,
    )
    token_input = input_quant.strategy == QuantizationStrategy.TOKEN.value
    dynamic_input = input_quant.dynamic and not weight_quant.dynamic
    # Both symmetric and asymmetric input quantization supported.
    # Only symmetric weight quantization supported.
    return (
        both_8_bit
        and weight_strategy_ok
        and token_input
        and weight_quant.symmetric
        and dynamic_input
    )
def _is_dynamic_token_w4a8_int(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True for dynamic per-token 8-bit activations with symmetric
    channel/group-wise 4-bit integer weights."""
    bit_widths_ok = weight_quant.num_bits == 4 and input_quant.num_bits == 8
    weight_strategy_ok = weight_quant.strategy in (
        QuantizationStrategy.GROUP.value,
        QuantizationStrategy.CHANNEL.value,
    )
    token_input = input_quant.strategy == QuantizationStrategy.TOKEN.value
    dynamic_input = input_quant.dynamic and not weight_quant.dynamic
    # Both symmetric and asymmetric input quantization supported.
    # Only symmetric weight quantization supported.
    return (
        bit_widths_ok
        and weight_strategy_ok
        and token_input
        and weight_quant.symmetric
        and dynamic_input
    )
def _is_fp8_w8a8(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True when both weights and activations use a supported float
    (FP8-style) quantization scheme."""
    # Both sides must be quantized at all.
    if weight_quant is None or input_quant is None:
        return False

    # Weight side: float type on both sides, symmetric, static, and a
    # tensor/channel/block strategy.
    weights_supported = (
        weight_quant.type == QuantizationType.FLOAT
        and input_quant.type == QuantizationType.FLOAT
        and weight_quant.symmetric
        and not weight_quant.dynamic
        and weight_quant.strategy
        in (
            QuantizationStrategy.TENSOR,
            QuantizationStrategy.CHANNEL,
            QuantizationStrategy.BLOCK,
        )
    )
    if not weights_supported:
        return False

    # Dynamic quantization is always supported if weights supported.
    if input_quant.dynamic:
        return True

    # Static activations must be symmetric and per-tensor.
    return (
        input_quant.symmetric
        and input_quant.strategy == QuantizationStrategy.TENSOR
    )
def _is_fp8_w4a8(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True for per-group symmetric 4-bit weights with dynamic per-token
    symmetric 8-bit activations."""
    if not weight_quant or not input_quant:
        return False

    # Only per-group symmetric weight (4bit)
    # + per-tok symmetric activation (8bit) quantization supported.
    bit_widths_ok = weight_quant.num_bits == 4 and input_quant.num_bits == 8
    group_weights = weight_quant.strategy == QuantizationStrategy.GROUP.value
    token_input = input_quant.strategy == QuantizationStrategy.TOKEN.value
    dynamic_input = input_quant.dynamic and not weight_quant.dynamic
    symmetric_both = weight_quant.symmetric and input_quant.symmetric
    return (
        bit_widths_ok
        and group_weights
        and token_input
        and symmetric_both
        and dynamic_input
    )
def _is_fp8_w4a8_sm90(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """W4A8-FP8 check restricted to devices with capability exactly 90."""
    if not self._check_scheme_supported(90, error=False, match_exact=True):
        return False
    return self._is_fp8_w4a8(weight_quant, input_quant)
def _is_fp8_w8a8_sm90(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """W8A8-FP8 check restricted to devices with capability exactly 90."""
    if not self._check_scheme_supported(90, error=False, match_exact=True):
        return False
    return self._is_fp8_w8a8(weight_quant, input_quant)
def _is_fp8_w8a8_sm100(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """W8A8-FP8 check restricted to devices with capability exactly 100."""
    if not self._check_scheme_supported(100, error=False, match_exact=True):
        return False
    return self._is_fp8_w8a8(weight_quant, input_quant)
def _is_fp8_w8a16(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True for float weight-only quantization (activations untouched by
    this check; ``input_quant`` is intentionally unused)."""
    # Need quantized float weights at minimum.
    if weight_quant is None:
        return False
    if weight_quant.type != QuantizationType.FLOAT:
        return False

    # Weight scheme must be symmetric, static, and use a
    # tensor/channel/block strategy.
    supported_strategies = (
        QuantizationStrategy.TENSOR,
        QuantizationStrategy.CHANNEL,
        QuantizationStrategy.BLOCK,
    )
    return (
        weight_quant.symmetric
        and not weight_quant.dynamic
        and weight_quant.strategy in supported_strategies
    )
def _is_wNa16_group_channel(
    self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
) -> bool:
    """True for static weight-only quantization with a channel- or
    group-wise strategy (activations unquantized)."""
    no_input_quant = input_quant is None
    strategy_ok = weight_quant.strategy in (
        QuantizationStrategy.CHANNEL.value,
        QuantizationStrategy.GROUP.value,
    )
    static_weights = not weight_quant.dynamic
    return strategy_ok and no_input_quant and static_weights
def _get_scheme_from_parts(
    self,
    weight_quant: QuantizationArgs,
    input_quant: QuantizationArgs,
    format: str | None = None,
) -> "CompressedTensorsScheme":
    """Dispatch to a concrete CompressedTensorsScheme.

    Checks are ordered from most to least specific; the first predicate
    that matches wins.

    :param weight_quant: weight quantization args for the matched target
    :param input_quant: input-activation quantization args (may be None)
    :param format: per-config-group format; falls back to the global one
    :raises NotImplementedError: if no scheme matches
    """
    # use the per-layer format if defined, otherwise, use global format
    format = format if format is not None else self.quant_format

    # Detect If Mixed Precision
    if self._is_fp4a16_nvfp4(weight_quant, input_quant):
        return CompressedTensorsW4A16Fp4()

    if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
        return CompressedTensorsW4A8Fp8(
            num_bits=weight_quant.num_bits,
            strategy=weight_quant.strategy,
            symmetric=weight_quant.symmetric,
            group_size=weight_quant.group_size,
            actorder=weight_quant.actorder,
        )

    if self._is_wNa16_group_channel(weight_quant, input_quant):
        # Weight-only path; the compression format decides the kernel.
        if (
            format == CompressionFormat.marlin_24.value
            and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
        ):
            assert weight_quant.symmetric
            return CompressedTensorsW4A16Sparse24(
                strategy=weight_quant.strategy,
                num_bits=weight_quant.num_bits,
                group_size=weight_quant.group_size,
            )
        if (
            format == CompressionFormat.pack_quantized.value
            and weight_quant.num_bits in WNA16_SUPPORTED_BITS
        ):
            return CompressedTensorsWNA16(
                num_bits=weight_quant.num_bits,
                strategy=weight_quant.strategy,
                symmetric=weight_quant.symmetric,
                group_size=weight_quant.group_size,
                actorder=weight_quant.actorder,
            )

    act_quant_format = is_activation_quantization_format(format)
    if act_quant_format:
        if self._is_fp4a4_nvfp4(weight_quant, input_quant):
            if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
                return CompressedTensorsW4A4Fp4()
            else:
                # Fall back to weight-only FP4 on unsupported hardware.
                logger.warning_once(
                    "Current platform does not support cutlass NVFP4."
                    " Running CompressedTensorsW4A16Fp4."
                )
                return CompressedTensorsW4A16Fp4(has_input_global_scale=True)

        if self._is_fp8_w8a8(weight_quant, input_quant):
            is_fp8_w8a8_supported = self._check_scheme_supported(
                CompressedTensorsW8A8Fp8.get_min_capability(), error=False
            )
            if is_fp8_w8a8_supported:
                return CompressedTensorsW8A8Fp8(
                    weight_quant=weight_quant,
                    is_static_input_scheme=(
                        input_quant and not input_quant.dynamic
                    ),
                )
            else:
                # note: input_quant will be present for converted models;
                # will be ignored during inference post loading
                return CompressedTensorsW8A16Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=not input_quant.dynamic,
                )

        # note: input_quant can be None
        if self._is_fp8_w8a16(weight_quant, input_quant):
            is_static_input_scheme = input_quant and not input_quant.dynamic
            return CompressedTensorsW8A16Fp8(
                strategy=weight_quant.strategy,
                is_static_input_scheme=is_static_input_scheme,
            )

        if self._is_static_tensor_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Int8(
                strategy=weight_quant.strategy,
                is_static_input_scheme=True,
                input_symmetric=input_quant.symmetric,
            )

        if self._is_dynamic_token_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Int8(
                strategy=weight_quant.strategy,
                is_static_input_scheme=False,
                input_symmetric=input_quant.symmetric,
            )

        if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
            is_static_input_scheme = input_quant and not input_quant.dynamic
            return CompressedTensorsW4A8Int(
                num_bits=weight_quant.num_bits,
                strategy=weight_quant.strategy,
                group_size=weight_quant.group_size,
                is_static_input_scheme=is_static_input_scheme,
                input_symmetric=input_quant.symmetric,
            )

    raise NotImplementedError("No compressed-tensors compatible scheme was found.")
def get_scheme(
    self, layer: torch.nn.Module, layer_name: str | None = None
) -> Optional["CompressedTensorsScheme"]:
    """
    compressed-tensors supports non uniform in the following way:

    targets of config_groups: There can be N config_groups which each
    have a quantization scheme. Each config_group has a list of targets
    which can be a full layer_name, a regex for a layer_name, or
    an nn.Module name.

    Detect whether a layer_name is found in any target and
    use the quantization scheme corresponding to the matched target
    to select the CompressedTensorsScheme used for inference.

    Returns None for ignored or non-quantized layers.
    """

    # Find the "target" in the compressed-tensors config
    # that our layer conforms to.
    # TODO (@kylesayrs): support ignore module names with ct matching utils
    if should_ignore_layer(
        layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
    ):
        return None

    # Will be empty for models with only sparsity
    weight_quant = input_quant = None
    # `format` is only bound on this branch, but it is only read on the
    # quantized path below, which requires weight_quant to be set here.
    if self.target_scheme_map:
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=self.target_scheme_map.keys(),
            fused_mapping=self.packed_modules_mapping,
        )

        scheme_dict = self.target_scheme_map[matched_target]
        weight_quant = scheme_dict.get("weights")
        input_quant = scheme_dict.get("input_activations")
        format = scheme_dict.get("format")

    # Find the sparsity scheme of the layer
    # assume that fused layers inherit first component's sparsity scheme
    sparsity_targets = self.sparsity_scheme_map.keys() - set(
        self.sparsity_ignore_list
    )
    sparsity_scheme: SparsityCompressionConfig | None = None
    # find_matched_target raising ValueError means "no sparsity target".
    with suppress(ValueError):
        matched_target = find_matched_target(
            layer_name=layer_name,
            module=layer,
            targets=sparsity_targets,
            fused_mapping=self.packed_modules_mapping,
        )
        sparsity_scheme = self.sparsity_scheme_map[matched_target]

    if self.supports_cutlass_24(
        weight_quant=weight_quant,
        input_quant=input_quant,
        sparsity_scheme=sparsity_scheme,
    ):
        # Have a valid sparsity scheme
        # Validate layer is supported by Cutlass 2:4 Kernel
        model_compression_config = (
            None
            if sparsity_scheme is None or sparsity_scheme.format == "dense"
            else self.config
        )

        scheme = CompressedTensors24(
            quantized=weight_quant is not None or input_quant is not None,
            weight_quant=weight_quant,
            input_quant=input_quant,
            model_compression_config=model_compression_config,
        )
    elif weight_quant is None:
        logger.warning_once(
            "Acceleration for non-quantized schemes is "
            "not supported by Compressed Tensors. "
            "Falling back to UnquantizedLinearMethod"
        )
        return None
    else:
        # Find the quant_scheme
        scheme = self._get_scheme_from_parts(  # type: ignore
            weight_quant=weight_quant, input_quant=input_quant, format=format
        )

    # Raise error if device does not support the scheme
    # (e.g. fp8 needs ada lovelace)
    self._check_scheme_supported(scheme.get_min_capability())
    logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
    return scheme
def get_cache_scale(self, name: str) -> str | None:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM

    :param name: param name
    :return: matching param name for KV cache scale in vLLM
    """
    if name.endswith(".output_scale"):
        replacements = (
            (".k_proj", ".k_proj.output_scale", ".attn.k_scale"),
            (".v_proj", ".v_proj.output_scale", ".attn.v_scale"),
        )
        for marker, old, new in replacements:
            if marker in name:
                return name.replace(old, new)
    # If no matches, return None
    return None
def has_blocked_weights(self) -> bool:
    """True if any target's weight quantization uses the BLOCK strategy."""
    weight_args = (
        scheme.get("weights") for scheme in self.target_scheme_map.values()
    )
    return any(
        args is not None and args.strategy == QuantizationStrategy.BLOCK
        for args in weight_args
    )
@staticmethod
def supports_cutlass_24(
    weight_quant: QuantizationArgs | None,
    input_quant: QuantizationArgs | None,
    sparsity_scheme: SparsityCompressionConfig | None = None,
) -> bool:
    """
    Check if the layer is supported by the Cutlass 2:4 Kernel
    Conditions:
        - Overarching condition: Sparsity Structure is 2:4
        - Unquantized cases are supported
        - Weight only quantization is not-supported
        - Supported weight quantization strategies are TENSOR and CHANNEL
        - Supported input quantization strategies are TENSOR and TOKEN
        - Only 8 bit quantization is supported

    :return: True if the layer is supported by the Cutlass 2:4 Kernel
        False otherwise
    """
    if sparsity_scheme is None:
        return False

    is_valid_sparsity_structure: bool = (
        sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value
    )

    # Only dense and sparse-24-bitmask on-disk formats are handled.
    valid_compressors = {
        CompressionFormat.dense.value,
        CompressionFormat.sparse_24_bitmask.value,
    }

    is_valid_sparsity = (
        is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors
    )

    if not is_valid_sparsity:
        return False

    # Unquantized cases are supported
    if weight_quant is None and input_quant is None:
        return True

    # Weight only quantization is not-supported
    if weight_quant is not None and input_quant is None:
        return False

    supported_weight_quant_strategies = [
        QuantizationStrategy.TENSOR.value,
        QuantizationStrategy.CHANNEL.value,
    ]

    assert weight_quant is not None
    assert input_quant is not None
    if weight_quant.strategy not in supported_weight_quant_strategies:
        return False

    supported_input_quant_strategies = [
        QuantizationStrategy.TENSOR.value,
        QuantizationStrategy.TOKEN.value,
    ]

    if input_quant.strategy not in supported_input_quant_strategies:
        return False

    # Only 8 bit quantization is supported
    return weight_quant.num_bits == input_quant.num_bits == 8
class CompressedTensorsLinearMethod(LinearMethodBase):
    """Linear method that delegates weight creation and the forward pass
    to the CompressedTensorsScheme attached to each layer."""

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Let the layer's scheme post-process its weights after loading."""
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
        """
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=extra_weight_attrs.get("weight_loader"),
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from compressed-tensors
    checkpoints.
    """

    def __init__(self, quant_config: CompressedTensorsConfig):
        # Validate the scheme up front so misconfigured checkpoints fail fast.
        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None):
        """
        Validator for the kv cache scheme. Useful for controlling the
        kv cache quantization schemes, that are being supported in vLLM

        :param kv_cache_scheme: the compressed-tensors kv cache scheme
        :raises NotImplementedError: for unsupported type/strategy/symmetry
        """
        if kv_cache_scheme is None:
            return

        type_ = kv_cache_scheme.get("type")
        num_bits = kv_cache_scheme.get("num_bits")

        # NOTE(review): with `and`, this raises only when BOTH type and
        # num_bits are wrong (e.g. type="int", num_bits=8 slips through).
        # If the intent is "require float with 8 bits", the condition
        # should use `or` — confirm before changing, since tightening it
        # could reject checkpoints that currently load.
        if type_ != "float" and num_bits != 8:
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                "num_bits=8, type=float, however "
                f"received num_bits={num_bits}, type={type_}"
            )

        strategy = kv_cache_scheme.get("strategy")
        if strategy != "tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for compressed-tensors KV cache. "
                f"Expected strategy: tensor, found strategy: {strategy}"
            )

        is_symmetric = kv_cache_scheme.get("symmetric")
        if not is_symmetric:
            raise NotImplementedError(
                "Only support symmetric scaling factor "
                "for compressed-tensors KV cache. "
                f"However found symmetric: {is_symmetric}"
            )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,35 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
from .compressed_tensors_w4a16_24 import (
W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24,
)
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
# This avoids circular import error
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
# Public API of the compressed-tensors schemes subpackage.
__all__ = [
    "CompressedTensorsScheme",
    "CompressedTensorsWNA16",
    "CompressedTensorsW8A16Fp8",
    "CompressedTensorsW4A16Sparse24",
    "CompressedTensorsW8A8Int8",
    "CompressedTensorsW8A8Fp8",
    "WNA16_SUPPORTED_BITS",
    "W4A16SPARSE24_SUPPORTED_BITS",
    "CompressedTensors24",
    "CompressedTensorsW4A16Fp4",
    "CompressedTensorsW4A4Fp4",
    "CompressedTensorsW4A8Int",
    "CompressedTensorsW4A8Fp8",
]

View File

@@ -0,0 +1,392 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any
import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)
from compressed_tensors.utils import combine_shards
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
sparse_cutlass_supported,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensors24"]
from vllm.platforms import current_platform
class CompressedTensors24(CompressedTensorsScheme):
    def __init__(
        self,
        quantized: bool = False,
        weight_quant: QuantizationArgs | None = None,
        input_quant: QuantizationArgs | None = None,
        model_compression_config: dict[str, Any] | None = None,
    ):
        """2:4 sparse (optionally quantized) linear scheme.

        :param quantized: whether the layer is quantized in addition to
            being 2:4 sparse
        :param weight_quant: weight quantization args, if any
        :param input_quant: input-activation quantization args, if any
        :param model_compression_config: compression config used to build a
            ModelCompressor for decompressing sparse-24-bitmask checkpoints
        """
        self.quantized = quantized
        self.weight_quant = weight_quant
        self.input_quant = input_quant
        model_compressor = ModelCompressor.from_compression_config(
            model_compression_config
        )
        # Decompression on load is only needed for the bitmask on-disk
        # format; dense checkpoints load as-is.
        self.do_sparse_decompress = (
            model_compressor is not None
            and model_compressor.sparsity_config.format
            == CompressionFormat.sparse_24_bitmask.value
        )
        if self.do_sparse_decompress:
            self.model_compressor = model_compressor

        if (
            quantized
            and input_quant is not None
            and self._get_quant_dtype() == current_platform.fp8_dtype()
        ):
            # FP8 activation quant: static (per-tensor) when the checkpoint
            # carries input scales, otherwise dynamic per-token.
            static = not input_quant.dynamic
            g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
            self.quant_fp8 = QuantFP8(static, g_shape)
@classmethod
def get_min_capability(cls) -> int:
    """Minimum device capability (9.0) required by this scheme."""
    # Only cutlass 3.x kernels are implemented so far
    return 90
def create_weights(
self,
layer: torch.nn.Module,
input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
if not sparse_cutlass_supported():
raise ValueError(
"Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature"
)
layer.logical_widths = output_partition_sizes
layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.do_sparse_decompress:
assert all(
partition_size % 8 == 0 for partition_size in output_partition_sizes
), "All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape = BasevLLMParameter(
data=torch.empty(2, 1, dtype=torch.int64),
weight_loader=weight_loader,
)
compressed_weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
bitmask = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed_weight)
layer.register_parameter("bitmask", bitmask)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (
self.weight_quant
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value
):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32
),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert (
self.weight_quant
and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(
data=torch.ones(1, dtype=torch.float32), requires_grad=False
)
input_scale = torch.nn.Parameter(
data=torch.ones(1, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
"""
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
compressed=layer.compressed,
bitmask=layer.bitmask,
layer=layer,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del layer.compressed
del layer.bitmask
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(
convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
),
requires_grad=False,
)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
# Set all negative zero values to 0 prior to compression
if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2:
layer.weight.data[layer.weight.data == -0.0] = 0.0
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = getattr(layer, "input_scale", None)
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
q_input, input_scale = self.quant_fp8(x, scale=scale)
else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(
a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
)
assert out.is_contiguous()
return out
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
return self._get_quant_dtype()
def _get_quant_dtype(self) -> torch.dtype:
assert self.quantized
assert self.weight_quant is not None
assert self.input_quant is not None
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
if not is_8_bits:
raise ValueError("Cutlass only supports 8-bit quantization")
if (
self.weight_quant.type == QuantizationType.FLOAT
and self.input_quant.type == QuantizationType.FLOAT
):
return torch.float8_e4m3fn
if (
self.weight_quant.type == QuantizationType.INT
and self.input_quant.type == QuantizationType.INT
):
return torch.int8
raise ValueError("Quantization type not supported by Cutlass")
def _decompress_bitmask_compressed_weight(
self,
compressed: torch.Tensor,
bitmask: torch.Tensor,
layer: torch.nn.Module,
) -> torch.Tensor:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
sparsity_compressor = self.model_compressor.sparsity_compressor
def _process_split(
bitmask_compressed_weight: torch.Tensor,
shape,
bitmask: torch.Tensor,
) -> torch.Tensor:
weight_data = dict(
compressed=bitmask_compressed_weight,
shape=shape,
bitmask=bitmask,
)
return sparsity_compressor.decompress_weight(weight_data)
split_weights: list[torch.Tensor] = []
split_bitmask: list[torch.Tensor] = []
split_shape: list[tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_shape = [
(out, layer.input_size_per_partition) for out in layer.logical_widths
]
if split_weights:
decompressed_shards = [
_process_split(compressed_weight, shape, bitmask)
for compressed_weight, shape, bitmask in zip(
split_weights, split_shape, split_bitmask
)
]
decompressed = combine_shards(decompressed_shards)
else:
decompressed = sparsity_compressor.decompress_weight(
dict(
compressed=compressed,
shape=(
layer.logical_widths[0],
layer.input_size_per_partition,
),
bitmask=bitmask,
)
)
return decompressed

View File

@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
__all__ = ["CompressedTensorsScheme"]
class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError

View File

@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL,
GPTQ_MARLIN_24_MIN_THREAD_N,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter,
)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
    """4-bit weight-only scheme with 2:4 sparsity, executed through the
    GPTQ Marlin 2:4 kernel."""

    def __init__(self, strategy: str, num_bits: int, group_size: int | None = None):
        self.tile_size = 16
        self.strategy = strategy
        self.group_size = group_size

        if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}"
            )
        self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]

        if self.strategy == "group" and self.group_size is None:
            raise ValueError("group_size must be given when using strategy group")

    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere or newer.
        return 80

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap loaded tensors as plain Parameters (torch.compile needs
        torch.nn.Parameter attributes)."""
        for attr in ("weight_packed", "scale_packed", "meta"):
            loaded = getattr(layer, attr).data
            setattr(layer, attr, Parameter(loaded, requires_grad=False))

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register packed weights, scales, shape metadata, 2:4 meta tensor,
        and a marlin workspace on the layer."""
        assert params_dtype == torch.float16, (
            "float16 is required for marlin24 compressed models. Set dtype=torch.float16"  # noqa: E501
        )

        pack_factor = 32 // self.quant_type.size_bits
        total_out = sum(output_partition_sizes)

        # 4-bit values packed into int32 words, laid out in marlin tiles.
        packed_weight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.tile_size // 2,
                total_out * self.tile_size // pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            marlin_tile_size=self.tile_size,
            weight_loader=weight_loader,
        )

        # One scale row per quantization group (a single row when
        # channelwise, i.e. group_size is None).
        if self.group_size is None:
            num_scale_rows = 1
        else:
            num_scale_rows = input_size_per_partition // self.group_size

        scale_data = torch.empty(
            num_scale_rows,
            total_out,
            dtype=params_dtype,
        )
        if self.group_size is not None:
            scale_param = GroupQuantScaleParameter(
                data=scale_data,
                output_dim=1,
                input_dim=0,
                weight_loader=weight_loader,
            )
        else:
            scale_param = ChannelQuantScaleParameter(
                data=scale_data,
                output_dim=1,
                weight_loader=weight_loader,
            )

        # Original (pre-packing) 2D weight shape.
        shape_param = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        # 2:4 sparsity metadata consumed by the marlin24 kernel.
        meta_param = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // 8 // 2 // 2,
                total_out * 2,
                dtype=torch.int16,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=1,
            marlin_tile_size=2,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_packed", packed_weight)
        layer.register_parameter("weight_shape", shape_param)
        layer.register_parameter("scale_packed", scale_param)
        layer.register_parameter("meta", meta_param)

        # Scratch space sized for the kernel's maximum parallelism.
        max_workspace_size = (
            total_out // GPTQ_MARLIN_24_MIN_THREAD_N
        ) * GPTQ_MARLIN_24_MAX_PARALLEL
        layer.workspace = Parameter(
            torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False
        )

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the marlin24 GEMM on a 2D view of ``x`` and restore the
        original leading dimensions."""
        flat_x = x.view(-1, x.shape[-1])
        scales = layer.scale_packed

        flat_out = ops.gptq_marlin_24_gemm(
            flat_x,
            layer.weight_packed,
            layer.meta,
            scales,
            layer.workspace,
            self.quant_type,
            flat_x.shape[0],
            scales.shape[1],
            flat_x.shape[1],
        )
        output = flat_out.view(x.shape[:-1] + (flat_out.shape[1],))

        if bias is not None:
            output.add_(bias)  # In-place add

        return output

View File

@@ -0,0 +1,124 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
apply_fp4_marlin_linear,
prepare_fp4_layer_for_marlin,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensorsW4A16Fp4"]
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
    """Weight-only NVFP4 scheme (4-bit fp weights, group size 16) executed
    via the Marlin FP4 kernel after repacking."""

    def __init__(self, has_input_global_scale: bool = False):
        self.has_input_global_scale = has_input_global_scale
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        # Not restricted further: emulation covers older GPUs.
        return 80

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register packed weights plus global and per-group scales."""
        total_out = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        # Two fp4 values per uint8 byte.
        packed = ModelWeightParameter(
            data=torch.empty(
                total_out,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_packed", packed)

        # One global (per-tensor) scale per fused shard.
        layer.register_parameter(
            "weight_global_scale",
            PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            ),
        )

        # fp8 scale for every group of 16 input elements.
        layer.register_parameter(
            "weight_scale",
            GroupQuantScaleParameter(
                data=torch.empty(
                    total_out,
                    input_size_per_partition // self.group_size,
                    dtype=torch.float8_e4m3fn,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        if self.has_input_global_scale:
            layer.register_parameter(
                "input_global_scale",
                PerTensorScaleParameter(
                    data=torch.empty(
                        len(output_partition_sizes), dtype=torch.float32
                    ),
                    weight_loader=weight_loader,
                ),
            )

    def process_weights_after_loading(self, layer) -> None:
        """Rename/repack the loaded parameters into the layout Marlin wants."""
        # Marlin reads "weight", not "weight_packed".
        layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
        del layer.weight_packed

        # compressed-tensors stores the inverse of the global scale the
        # marlin kernel consumes; collapse shard maxima into weight_scale_2.
        global_scale = layer.weight_global_scale.max().to(torch.float32)
        layer.weight_scale_2 = Parameter(1 / global_scale, requires_grad=False)
        del layer.weight_global_scale

        if self.has_input_global_scale:
            layer.input_global_scale = torch.nn.Parameter(
                layer.input_global_scale.data, requires_grad=False
            )

        prepare_fp4_layer_for_marlin(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Delegate the forward pass to the Marlin FP4 GEMM."""
        return apply_fp4_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            weight_scale_2=layer.weight_scale_2,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )

View File

@@ -0,0 +1,218 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
run_nvfp4_emulations,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
cutlass_fp4_supported,
swizzle_blockscale,
)
from vllm.model_executor.parameter import (
GroupQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A4Fp4"]
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
    """NVFP4 scheme: 4-bit fp weights AND activations, group size 16.

    The GEMM is dispatched to one of several backends (flashinfer variants,
    cutlass, fbgemm) resolved once at construction time; an emulation path
    exists for platforms without native FP4 support.
    """

    def __init__(self):
        # Backend resolution: when VLLM_NVFP4_GEMM_BACKEND is unset,
        # auto-detect in the order flashinfer-cutlass > cutlass > fbgemm
        # (fbgemm only if VLLM_USE_FBGEMM is set); when it is set, honor it
        # and assert the required library is available.
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
            elif envs.VLLM_USE_FBGEMM:
                self.backend = "fbgemm"
                # Import check only — the op is reached via torch.ops later.
                try:
                    import fbgemm_gpu  # noqa: F401
                except ImportError as exc:
                    raise ImportError(
                        "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                        "Please install with: pip install fbgemm-gpu-genai"
                    ) from exc
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"
        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )
        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
        # NVFP4 quantizes in fixed groups of 16 along the input dimension.
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        # Emulation mode runs on Ampere; native FP4 needs Blackwell (SM100).
        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
            return 80
        return 100

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register packed fp4 weights plus global and per-group scales.

        :param layer: module to register the parameters on
        :param output_partition_sizes: output widths of the fused shards
        :param input_size_per_partition: input size of this TP partition
        :param params_dtype: model compute dtype (scales for weights are
            fp8/fp32 regardless)
        :param weight_loader: callback used to load checkpoint shards
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Weight: two fp4 values packed per uint8 byte.
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_packed", weight)

        # Global Weight Scale: one per-tensor scale per fused shard.
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_global_scale", weight_global_scale)

        # Per Group Weight Scale: fp8 scale per group of 16 input elements.
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // self.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # Static global scale for quantizing the activations.
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_global_scale", input_global_scale)

    def process_weights_after_loading(self, layer) -> None:
        """Collapse shard scales and repack weights/scales for the backend."""
        # Fused shards share one global scale: keep the max over shards.
        global_input_scale = layer.input_global_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(global_input_scale, requires_grad=False)

        layer.weight_global_scale = Parameter(
            layer.weight_global_scale.max().to(torch.float32), requires_grad=False
        )

        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight_packed.data
            weight_scale = layer.weight_scale.data
            epilogue_tile_m = 128

            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight_packed = Parameter(weight, requires_grad=False)
        else:
            # cutlass / flashinfer-cutlass / fbgemm expect swizzled block
            # scales; fbgemm additionally wants them flattened as uint8.
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            if self.backend == "fbgemm":
                swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight_packed = Parameter(
                layer.weight_packed.data, requires_grad=False
            )

        # Combined dequant factor applied by the GEMM epilogue.
        layer.alpha = Parameter(
            1 / (layer.input_global_scale * layer.weight_global_scale),
            requires_grad=False,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize the activations to fp4 and run the selected FP4 GEMM."""
        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
            # Emulation path: dequantize-and-matmul in high precision.
            out = run_nvfp4_emulations(
                x=x,
                input_global_scale=layer.input_global_scale,
                weight=layer.weight_packed,
                weight_scale_swizzled=layer.weight_scale,
                weight_global_scale=layer.weight_global_scale,
            )
            if bias is not None:
                out = out + bias
            return out

        output_dtype = x.dtype
        # NOTE(review): assumes x is 2D (tokens, hidden) here — confirm
        # callers always flatten before this point.
        output_shape = [x.shape[0], layer.weight_packed.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

        mm_args = (
            x_fp4,
            layer.weight_packed,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        elif self.backend == "fbgemm":
            out = torch.ops.fbgemm.f4f4bf16(
                x_fp4,
                layer.weight_packed,
                x_blockscale.view(-1).view(torch.uint8),
                layer.weight_scale,
                layer.alpha,
                use_mx=False,
            ).to(output_dtype)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)

View File

@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter,
)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Fp8"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
    """4-bit weight / fp8 activation scheme (group size 128).

    Delegates weight repacking and the forward pass to a mixed-precision
    linear kernel selected by :func:`choose_mp_linear_kernel`.
    """

    # Shared across instances so each backend is logged only once.
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        strategy: str,
        num_bits: int,
        group_size: int | None = None,
        symmetric: bool | None = True,
        actorder: ActivationOrdering | None = None,
    ):
        """
        :param strategy: quantization strategy; only "group" is supported
        :param num_bits: weight bit-width; only 4 is supported
        :param group_size: quantization group size; must be 128
        :param symmetric: whether weight quantization is symmetric
        :param actorder: activation ordering; GROUP enables g_idx handling
        :raises ValueError: on unsupported strategy/group size/bit width
        """
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size != 128 or self.strategy != "group":
            raise ValueError(
                "W4A8 kernels require group quantization with group size 128"
            )

        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
            # Report the supported-bits list (not dict_keys(...)), matching
            # the other schemes' error messages.
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A8_SUPPORTED_BITS}"
            )

        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        # hopper
        return 90

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register packed weights and scales, and instantiate the kernel.

        :param layer: module to register the parameters on
        :param output_size: full (unpartitioned) output size
        :param input_size: full (unpartitioned) input size
        :param output_partition_sizes: output widths of the fused shards
        :param input_size_per_partition: input size of this TP partition
        :param params_dtype: output/compute dtype of the layer
        :param weight_loader: callback used to load checkpoint shards
        """
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=torch.float8_e4m3fn,  # always use fp8(e4m3)
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx,
            out_type=params_dtype,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW4A8Fp8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = input_size != input_size_per_partition
        # Scales are replicated on all ranks unless marlin says this layout
        # partitions them alongside the weights.
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel
        )

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        # 4-bit weights packed 8-per-int32 along the input dimension.
        weight = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            packed_factor=self.pack_factor,
            packed_dim=1,
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
        )

        # TODO(czhu): allocate the packed fp8 scales memory here?
        # the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=torch.float8_e4m3fn,
            ),
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        # per-channel scales
        weight_chan_scale = ChannelQuantScaleParameter(
            data=torch.empty((output_size_per_partition, 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)
        layer.register_parameter("weight_chan_scale", weight_chan_scale)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="weight_packed",
            w_s_param_name="weight_scale",
            w_zp_param_name="weight_zero_point",
            w_gidx_param_name="weight_g_idx",
        )

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the forward pass through the selected MP linear kernel."""
        return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsW4A8Int"]
W4A8_SUPPORTED_TYPES_MAP = {
4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
_kernel_backends_being_used: set[str] = set()
def __init__(
self,
strategy: str,
num_bits: int,
group_size: int | None = None,
is_static_input_scheme: bool = False,
input_symmetric: bool = True,
):
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}."
f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}"
)
self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
return 1
def create_weights(
self,
layer: torch.nn.Module,
output_size: int,
input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int,
params_dtype: torch.dtype,
weight_loader: Callable,
**kwargs,
):
output_size_per_partition = sum(output_partition_sizes)
row_parallel = input_size != input_size_per_partition
# Compute effective group_size
if self.group_size == -1:
effective_group_size = (
input_size_per_partition if row_parallel else input_size
)
else:
effective_group_size = self.group_size
# Ensure group_size divides input_size_per_partition
assert input_size_per_partition % effective_group_size == 0, (
f"input_size_per_partition {input_size_per_partition}"
f" not divisible by group_size {effective_group_size}"
)
# Determine scale partitioning
is_channelwise = self.group_size == -1
repeat_scales = is_channelwise and row_parallel
partition_scales = not repeat_scales
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=(
input_size_per_partition,
output_size_per_partition,
),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=effective_group_size,
zero_points=False,
has_g_idx=False,
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsW4A8Int", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
scales_and_zp_size = input_size_per_partition // effective_group_size
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition, input_size_per_partition, dtype=torch.int8
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
weight_scale_args = {
"weight_loader": weight_loader,
"data": torch.empty(
output_size_per_partition, scales_and_zp_size, dtype=params_dtype
),
}
if partition_scales:
weight_scale = GroupQuantScaleParameter(
output_dim=0, input_dim=1, **weight_scale_args
)
else:
weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
self.kernel = kernel_type(
mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name=None,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # Checkpoints are serialized in compressed-tensors format, which may
    # differ from the layout the selected kernel wants; the kernel handles
    # any repacking here.
    self.kernel.process_weights_after_loading(layer)
def apply_weights(
    self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
) -> torch.Tensor:
    # Delegate the quantized matmul (and optional bias add) to the
    # mixed-precision kernel chosen in create_weights.
    return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,138 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensorsW8A16Fp8"]

# Weight-scale strategies this scheme accepts (checked in create_weights).
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
    """FP8 weights with 16-bit activations, executed via the Marlin FP8 kernel.

    Weight scales may be per-tensor or per-channel; per-tensor scales of a
    fused module (QKV, MLP) are expanded to channelwise before kernel prep.
    """

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def process_weights_after_loading(self, layer) -> None:
        """Convert loaded parameters into the layout Marlin expects.

        W8A8-Fp8 kernels support only per-tensor and per-channel cases, so
        for a fused module (QKV, MLP) with per-tensor scales, each scale is
        expanded to cover its shard's channels.
        """
        if self.strategy == QuantizationStrategy.TENSOR:
            expanded_scale = convert_to_channelwise(
                layer.weight_scale, layer.logical_widths
            )
            layer.weight_scale = torch.nn.Parameter(
                expanded_scale, requires_grad=False
            )
        else:
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )

        # Weights must be transposed for marlin
        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)

        if self.is_static_input_scheme:
            # required by torch.compile to be torch.nn.Parameter
            layer.input_scale = torch.nn.Parameter(
                layer.input_scale.data, requires_grad=False
            )

        prepare_fp8_layer_for_marlin(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate the FP8 weight and its scale parameters on ``layer``."""
        total_out_size = sum(output_partition_sizes)
        num_partitions = len(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out_size
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT: row-major [out_features, in_features] in fp8
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_out_size,
                    input_size_per_partition,
                    dtype=torch.float8_e4m3fn,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # WEIGHT SCALE: one per output channel, or one per logical partition
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((total_out_size, 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        elif self.strategy == QuantizationStrategy.TENSOR:
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(num_partitions, dtype=torch.float32),
                weight_loader=weight_loader,
            )
        else:
            raise ValueError(
                f"Unsupported weight strategy={self.strategy}, "
                f"supported strategies are {SUPPORTED_STRATEGIES}"
            )

        # pre-fill with float32 min before checkpoint loading
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (to deal with converted checkpoints)
        if self.is_static_input_scheme:
            layer.register_parameter(
                "input_scale",
                PerTensorScaleParameter(
                    data=torch.empty(num_partitions, dtype=torch.float32),
                    weight_loader=weight_loader,
                ),
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin FP8 GEMM with an optional bias."""
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )

View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from torch.nn import Parameter
from vllm._aiter_ops import rocm_aiter_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
create_fp8_input_scale,
create_fp8_scale_parameter,
create_fp8_weight_parameter,
maybe_post_process_fp8_weight_block,
process_fp8_weight_block_strategy,
process_fp8_weight_channel_strategy,
process_fp8_weight_tensor_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported,
maybe_create_device_identity,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensorsW8A8Fp8"]

# Maps each weight-scale strategy to the vLLM parameter class used to hold
# its scale tensor (consumed by create_fp8_scale_parameter).
strategy_to_parameter_type = {
    QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
    QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
    QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
    """W8A8 FP8 scheme: FP8 weights with FP8 activations.

    Supports TENSOR, CHANNEL and BLOCK weight-scale strategies. Activation
    quantization is static per-tensor, dynamic per-token, or (for block
    weights) dynamic per-group, selected at construction time.
    """

    def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
        self.weight_quant = weight_quant
        self.strategy = weight_quant.strategy
        self.out_dtype = torch.get_default_dtype()
        self.is_static_input_scheme = is_static_input_scheme
        # block_structure is set only for block quantization (None otherwise)
        self.weight_block_size = self.weight_quant.block_structure
        if self.weight_block_size is not None:
            # group activations per (1 x block) to match the weight blocks
            # (assumes block_structure[0] is the relevant dim — TODO confirm)
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            self.act_q_group_shape = (
                GroupShape.PER_TENSOR
                if is_static_input_scheme
                else GroupShape.PER_TOKEN
            )
        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        # NOTE(review): "enaled" spelling mirrors the rocm_aiter_ops API name
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()
        if self.weight_block_size is not None:
            # block-quantized weights imply dynamic activation quantization
            assert not self.is_static_input_scheme
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.is_static_input_scheme,
                act_quant_group_shape=self.act_q_group_shape,
            )

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate FP8 weight, weight-scale, and (for static activation
        quantization) input-scale parameters on ``layer``."""
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.weight_block_size = None
        layer.orig_dtype = params_dtype

        if self.strategy == QuantizationStrategy.BLOCK:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            # Validate block quantization shapes
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (parameter class depends on the strategy)
        weight_scale = create_fp8_scale_parameter(
            strategy_to_parameter_type[self.strategy],
            output_partition_sizes,
            input_size_per_partition,
            layer.weight_block_size,
            weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (only for static activation quantization)
        if self.is_static_input_scheme:
            input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer) -> None:
        """Repack checkpoint tensors into the layout the FP8 kernels expect."""
        if self.strategy == QuantizationStrategy.TENSOR:
            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                layer.weight,
                layer.weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            weight = weight.t()

        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
                layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
            )
            weight = weight.t()

        elif self.strategy == QuantizationStrategy.BLOCK:
            # block quantization never uses static activation scales
            assert self.is_static_input_scheme is False
            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale
            )
            input_scale = None

        else:
            raise ValueError(f"Unknown quantization strategy {self.strategy}")

        # required by torch.compile to be torch.nn.Parameter
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale.data, requires_grad=False)

        # INPUT SCALE: collapse per-partition static scales to a single max
        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None

        if self.strategy == QuantizationStrategy.BLOCK:
            maybe_post_process_fp8_weight_block(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the FP8 linear op chosen at construction time."""
        if self.weight_block_size is not None:
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )

View File

@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig,
choose_scaled_mm_linear_kernel,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
    """W8A8 INT8 scheme: int8 weights with int8 activations via a
    scaled_mm linear kernel."""

    # kernel backends already announced, so each backend is logged only once
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
    ):
        import vllm.envs as env

        # NOTE(review): VLLM_MIX_QUANTIZATION_TYPE globally overrides the
        # checkpoint's weight-scale strategy — confirm this matches how the
        # loaded checkpoints were produced before relying on it.
        if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
            self.strategy = QuantizationStrategy.TENSOR
        elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
            self.strategy = QuantizationStrategy.CHANNEL
        else:
            self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate int8 weight, weight scale and (optional) static input
        qparams on ``layer`` and construct the scaled_mm kernel."""
        layer.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
            is_static_input_scheme=self.is_static_input_scheme,
            input_symmetric=self.input_symmetric,
        )

        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # pad the K dimension up to a multiple of 64
        remainder = input_size_per_partition % 64
        if remainder != 0:
            input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
        else:
            input_size_per_partition_padded = input_size_per_partition

        # WEIGHT
        # NOTE(review): the parameter is allocated with the padded K size while
        # checkpoint tensors are unpadded — verify the weight loader only fills
        # the first input_size_per_partition columns and that the padded tail
        # is initialized before use (torch.empty leaves it arbitrary).
        weight = ModelWeightParameter(data=torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition_padded,
            dtype=torch.int8),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE: per output channel, or one per logical partition
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (static activation quantization only)
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)

            if not self.input_symmetric:
                # Note: compressed-tensors stores the zp using the same dtype
                # as the weights
                # AZP loaded as int8 but used as int32
                input_zero_point = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
                )
                layer.register_parameter("input_zero_point", input_zero_point)

        self.kernel = kernel_type(
            c=scaled_mm_linear_kernel_config,
            w_q_param_name="weight",
            w_s_param_name="weight_scale",
            i_s_param_name="input_scale",
            i_zp_param_name="input_zero_point",
            azp_adj_param_name="azp_adj",
        )

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the int8 matmul through the selected kernel."""
        return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch
from compressed_tensors.quantization import ActivationOrdering
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)

__all__ = ["CompressedTensorsWNA16"]

# num_bits -> packed scalar type for symmetric checkpoints
# (b8/b128 presumably denote bias offsets — see vllm.scalar_type)
WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
# num_bits -> scalar type used when an explicit zero point is stored
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
    """Weight-only N-bit (4/8) quantization with 16-bit activations.

    Weights are packed into int32 words and executed through a
    mixed-precision (Marlin-style) linear kernel.
    """

    # kernel backends already announced, so each backend is logged only once
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        strategy: str,
        num_bits: int,
        group_size: int | None = None,
        symmetric: bool | None = True,
        actorder: ActivationOrdering | None = None,
    ):
        # number of quantized values packed into one int32 word
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        # -1 encodes channelwise (a single group spanning the input dim)
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError(
                "Marlin kernels require group quantization or "
                "channelwise quantization, but found no group "
                "size and strategy is not channelwise."
            )

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}"
            )

        # asymmetric checkpoints store an explicit zero point
        self.quant_type = (
            WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
            if not self.symmetric
            else WNA16_SUPPORTED_TYPES_MAP[num_bits]
        )

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate packed weights, scales, (optional) zero points and
        group indices, and construct the mixed-precision kernel."""
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = input_size != input_size_per_partition
        # in some TP configurations Marlin replicates scales on all ranks
        # instead of partitioning them — see marlin_repeat_scales_on_all_ranks
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel
        )

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        # packed along the input dim: pack_factor quantized values per int32
        weight = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            packed_factor=self.pack_factor,
            packed_dim=1,
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
        )

        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
        }

        zeros_args = {
            "weight_loader": weight_loader,
            "data": torch.zeros(
                output_size_per_partition // self.pack_factor,
                scales_and_zp_size,
                dtype=torch.int32,
            ),
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)

            if not self.symmetric:
                qzeros = PackedColumnParameter(
                    output_dim=0,
                    packed_dim=0,
                    packed_factor=self.pack_factor,
                    **zeros_args,
                )
        else:
            weight_scale = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )
            if not self.symmetric:
                qzeros = PackedvLLMParameter(
                    input_dim=1,
                    output_dim=0,
                    packed_dim=0,
                    packed_factor=self.pack_factor,
                    **zeros_args,
                )

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        if not self.symmetric:
            layer.register_parameter("weight_zero_point", qzeros)

        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    dtype=torch.int32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_g_idx", weight_g_idx)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="weight_packed",
            w_s_param_name="weight_scale",
            w_zp_param_name="weight_zero_point",
            w_gidx_param_name="weight_g_idx",
        )

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the quantized matmul through the selected kernel."""
        return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,260 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable, Generator
from itertools import accumulate
import torch
from compressed_tensors.transform import (
TransformArgs,
TransformConfig,
TransformLocation,
TransformScheme,
)
from compressed_tensors.utils import is_match
from vllm.model_executor.layers.linear import (
WEIGHT_LOADER_V2_SUPPORTED,
LinearMethodBase,
)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
HadamardTransform,
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple,
)
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
    """
    Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
    input and output transforms to either side of the original apply method
    """

    @classmethod
    def from_schemes(
        cls,
        quant_method: LinearMethodBase,
        quant_scheme: CompressedTensorsScheme | None,
        input_tfms: dict[int, TransformTuple],
        output_tfms: dict[int, TransformTuple],
    ) -> "CompressedTensorsLinearTransformMethod":
        """Select the transform-method implementation for the given schemes."""
        # local import, presumably to avoid a circular module dependency
        from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import (  # noqa: E501
            QutlassNvFP4LinearMethod,
            is_qutlass_fp4_scheme,
        )

        assert input_tfms or output_tfms

        if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
            return QutlassNvFP4LinearMethod(quant_method, input_tfms, output_tfms)

        # hadacore or dense gemm is selected by Transform module
        return cls(quant_method, input_tfms, output_tfms)

    def __init__(
        self,
        quant_method: LinearMethodBase,
        input_tfms: dict[int, TransformTuple],
        output_tfms: dict[int, TransformTuple],
    ):
        self.quant_method = quant_method
        self.input_tfms = input_tfms
        self.output_tfms = output_tfms
        # transform submodules are attached later, in create_weights
        self.input_transform: HadamardTransform | None = None
        self.output_transform: HadamardTransform | None = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the wrapped method's weights plus transform submodules."""
        # get weight loader for transforms
        weight_loader: Callable = extra_weight_attrs.get("weight_loader")  # type: ignore[assignment]

        # HACK: UnquantizedLinearMethod does not support weight loader v2, but
        # transforms (specifically SharedWeightParameter) requires
        # weight loader v2. Until UnquantizedLinearMethod supports v2, we must
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
            layer=layer,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            input_size=input_size,
            output_size=output_size,
            params_dtype=params_dtype,
            **extra_weight_attrs,
        )

        # validate schemes
        num_partitions = len(output_partition_sizes)
        self._validate_tfm_schemes(num_partitions)

        # create submodules for weight loading
        if len(self.input_tfms) > 0:
            # all input transforms share one scheme (validated above)
            scheme_name = list(self.input_tfms.values())[0].scheme_name
            location = list(self.input_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"

            transform = HadamardTransform(
                self.input_tfms,
                layer,
                weight_loader,
                input_size_per_partition,
                output_partition_sizes,
            )
            layer.register_module(transform_name, transform)
            self.input_transform = transform

        if len(self.output_tfms) > 0:
            # output transforms share scheme name/location (validated above)
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"

            transform = HadamardTransform(
                self.output_tfms,
                layer,
                weight_loader,
                input_size_per_partition,
                output_partition_sizes,
            )
            layer.register_module(transform_name, transform)
            self.output_transform = transform

        # compute partition ranges for slicing activations
        starts = [0] + list(accumulate(output_partition_sizes))[:-1]
        self.partition_ranges = list(zip(starts, output_partition_sizes))

    def process_weights_after_loading(self, layer):
        # finalize the wrapped method first, then each transform submodule
        self.quant_method.process_weights_after_loading(layer)

        for submodule in layer.children():
            if isinstance(submodule, HadamardTransform):
                submodule.process_weights_after_loading()

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply input transform, wrapped linear method, then per-partition
        output transforms."""
        if self.input_transform is not None:
            x = self.input_transform(x)
            # bias together with an input transform is not supported here
            assert bias is None

        x = self.quant_method.apply(layer, x, bias)

        # In most cases, input transforms are preferred over output transforms
        # (@ksayers): confirm that this is done concurrently
        if self.output_transform is not None:
            # each fused partition's slice is transformed in place
            for part_id, (start, length) in enumerate(self.partition_ranges):
                x[:, start : start + length] = self.output_transform(
                    x[:, start : start + length].clone(), part_id=part_id
                )

        return x

    def _validate_tfm_schemes(self, num_partitions: int):
        """Check that fused-partition transforms are mutually consistent.

        NOTE(review): the return value is unused by create_weights.
        """
        if len(self.input_tfms) > 0:
            # every partition must carry one identical input transform
            if 0 not in self.input_tfms:
                raise ValueError("Must have same input")
            for part_index in range(num_partitions):
                if self.input_tfms[part_index] != self.input_tfms[0]:
                    raise ValueError("Must have same input")

        if len(self.output_tfms) > 0:
            # output transforms may differ per partition, but must agree on
            # scheme name and location
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location
            for tfm in self.output_tfms.values():
                if tfm.scheme_name != scheme_name:
                    raise ValueError("Must have same scheme name")
                if tfm.args.location != location:
                    raise ValueError("Must have same location")

        return self.input_tfms, self.output_tfms
def get_linear_transform_schemes(
    layer: torch.nn.Module,
    layer_name: str,
    transform_config: TransformConfig | None,
    packed_modules_mapping: dict[str, list[str]],
) -> tuple[
    dict[int, TransformTuple], dict[int, TransformTuple]
]:  # [input_transform, [output_transform, ...]]
    """
    Collect the online transform schemes that apply to ``layer``.

    Returns ``(input_tfms, output_tfms)``, each mapping a partition index of
    the (possibly fused) layer to the matching transform.
    """
    # there can only be one transform input scheme per (fused) module
    input_tfms = {}
    output_tfms = {}

    partition_names = get_layer_partition_names(layer_name, packed_modules_mapping)
    for scheme_name, scheme, args in get_schemes_args(transform_config):
        for part_index, part_name in enumerate(partition_names):
            # only online (runtime) transforms targeting this partition apply
            if (
                is_match(part_name, layer, args.targets, args.ignore)
                and args.is_online()
            ):
                if args.location == TransformLocation.INPUT:
                    input_tfms[part_index] = TransformTuple(scheme_name, scheme, args)
                elif args.location == TransformLocation.OUTPUT:
                    output_tfms[part_index] = TransformTuple(scheme_name, scheme, args)
                else:
                    raise ValueError(
                        f"Cannot apply `{args.location}` transform to `{layer_name}`"
                    )

    return (input_tfms, output_tfms)
def get_schemes_args(
    transform_config: TransformConfig | None,
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
    """Yield ``(scheme_name, scheme, args)`` for every apply-entry in the
    config; yields nothing when no config is given."""
    if transform_config is None:
        return

    for scheme_name, scheme in transform_config.config_groups.items():
        yield from ((scheme_name, scheme, args) for args in scheme.apply)
def get_layer_partition_names(
    layer_name: str, packed_modules_mapping: dict[str, list[str]]
) -> list[str]:
    """
    Get all partition names associated with this layer.
    Names are returned in order of their partition indices.

    ```python
    mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}
    assert get_layer_partition_names("mlp.gate_up_proj", mapping) == [
        "mlp.gate_proj",
        "mlp.up_proj",
    ]
    assert get_layer_partition_names("mlp.down_proj", mapping) == ["mlp.down_proj"]
    ```
    """
    # NOTE: matching is first-come-first-served over the mapping; a fused
    # suffix that is also a suffix of another fused name (e.g. "up_proj" vs
    # "gate_up_proj") could shadow it — keep mapping suffixes unambiguous.
    for fused_suffix, part_suffixes in packed_modules_mapping.items():
        if layer_name.endswith(fused_suffix):
            # replace the fused suffix with each partition suffix, keeping
            # the module-path prefix (e.g. "mlp.") intact
            return [
                layer_name.removesuffix(fused_suffix) + part_suffix
                for part_suffix in part_suffixes
            ]

    # not a fused module: the layer is its own single partition
    return [layer_name]

View File

@@ -0,0 +1,173 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Callable, Hashable
import torch
from compressed_tensors.transform import (
TransformArgs,
TransformLocation,
TransformScheme,
)
from torch import Tensor
import vllm._custom_ops as ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
TransformTuple,
)
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.parameter import SharedWeightParameter
class HadamardTransform(torch.nn.Module):
"""
Class which handles weight loading, postprocessing, and application of
transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
and attention transforms method (not implemented yet)
"""
transforms: dict[int, TransformTuple] # info parsed from transforms config
weight: SharedWeightParameter # container for shared tensors
scales: dict[int, float] # hadamard scale, usually sqrt(matrix.size(0))
def __init__(
self,
transforms: dict[int, TransformTuple],
layer: torch.nn.Module,
weight_loader: Callable,
input_size_per_partition: int,
output_partition_sizes: list[int],
):
super().__init__()
self.transforms = transforms
self.scales = {}
if get_tensor_model_parallel_world_size() > 1:
raise NotImplementedError(
"Online transforms with tensor parallelism is not supported"
)
# Similar to row/col parallel params, but tensors are separate
# to allow for loading with shared memory
self.weight = SharedWeightParameter(weight_loader=weight_loader)
# create shared partition data for each partition of the original weight
input_size = input_size_per_partition
for part_index, (_scheme_name, scheme, args) in self.transforms.items():
output_size = output_partition_sizes[part_index]
weight_size = self._get_weight_size(
layer, scheme, args, input_size, output_size
)
data_key = self._get_data_key(scheme, weight_size)
self.weight.add_partition(
part_index,
data_key,
size=(weight_size, weight_size),
dtype=scheme.precision,
)
# validate that shared tensors and schemes are correct
self._validate_input_transforms()
def process_weights_after_loading(self):
for part_id in self.weight.partitions:
data = self.weight.partitions[part_id].data
# required by torch.compile
self.weight.process_weights_after_loading()
# precompute scale as a runtime multiply, not division
# do not fold into weight in order to utilize FWHT
self.scales[part_id] = 1 / math.sqrt(data.size(0))
# FUTURE: avoid runtime transpose by processing weights
# prior to apply
def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
if part_id not in self.weight.partitions:
return value
# use hadacore if possible
if self.transforms[part_id].scheme.type == "hadamard":
if self.transforms[part_id].scheme.head_dim is not None:
weight_size = self.transforms[part_id].scheme.head_dim
value = value.unflatten(-1, (-1, weight_size))
value = ops.hadacore_transform(value)
value = value.flatten(-2, -1)
return value
# sylvester transforms are symmetric, inv => transpose => original
return ops.hadacore_transform(value)
# fall back to dense
else:
weight = self.weight.partitions[part_id]
weight = (
weight if self.transforms[part_id].args.inverse else weight.T
) # linear := x(W.T)
scale = self.scales[part_id]
if self.transforms[part_id].scheme.head_dim is not None:
value = value.unflatten(-1, (-1, weight.size(0)))
value = (
dispatch_unquantized_gemm()(
self, value.to(weight.dtype), weight, None
).to(value.dtype)
* scale
)
value = value.flatten(-2, -1)
return value
return (
dispatch_unquantized_gemm()(
self, value.to(weight.dtype), weight, None
).to(value.dtype)
* scale
)
def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable:
return (id(scheme), weight_size)
def _get_weight_size(
self,
layer: torch.nn.Module,
scheme: TransformScheme,
args: TransformArgs,
input_size: int,
output_size: int,
) -> int:
if scheme.head_dim is not None:
return scheme.head_dim
if isinstance(layer, LinearBase):
if args.location == TransformLocation.INPUT:
return input_size
elif args.location == TransformLocation.OUTPUT:
return output_size
elif isinstance(layer, VocabParallelEmbedding):
if args.location == TransformLocation.INPUT:
return output_size
elif args.location == TransformLocation.OUTPUT:
return input_size
raise ValueError()
def _validate_input_transforms(self):
    """Validate that input-location transforms share one weight tensor.

    All partitions of a fused layer see the same input, so an
    input-side transform must be identical (the same storage) for every
    partition.

    :raises ValueError: if partitions reference different transform data
    """
    assert len(self.transforms) > 0
    location = list(self.transforms.values())[0].args.location
    if location == TransformLocation.INPUT:
        first_data = self.weight.partitions[0].data
        for partition in self.weight.partitions.values():
            # all partitions must alias the exact same tensor storage
            if partition.data.data_ptr() != first_data.data_ptr():
                # Previously raised ValueError("") with no message.
                raise ValueError(
                    "Input transforms of fused layers must share the "
                    "same transform weight across all partitions"
                )

View File

@@ -0,0 +1,64 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsScheme,
CompressedTensorsW4A4Fp4,
)
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
CompressedTensorsLinearTransformMethod,
TransformTuple,
)
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]


def is_qutlass_fp4_scheme(
    quant_scheme: CompressedTensorsScheme | None,
    input_tfms: dict[int, TransformTuple],
) -> bool:
    """Return True if the quant scheme / input transform pair can use the
    qutlass nvfp4 path.

    Requires a W4A4 fp4 quantization scheme combined with exactly one
    input transform whose head_dim equals the quantization group size.

    :param quant_scheme: resolved quantization scheme of the layer
    :param input_tfms: input-location transforms keyed by partition id
    """
    if not isinstance(quant_scheme, (CompressedTensorsW4A4Fp4,)):
        return False
    if len(input_tfms) != 1:
        return False
    # Fix: the single transform is not necessarily keyed by 0; take the
    # only value instead of indexing input_tfms[0] (avoids KeyError).
    only_tfm = next(iter(input_tfms.values()))
    return only_tfm.scheme.head_dim == quant_scheme.group_size
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
    """Linear method combining a single input transform with a W4A4 fp4
    (nvfp4) quantization scheme for the qutlass fused path.

    NOTE(review): ``apply`` is not implemented yet; this class currently
    only creates weights and validates transform/group-size agreement.
    """

    def create_weights(
        self,
        layer,
        input_size_per_partition,
        output_partition_sizes,
        input_size,
        output_size,
        params_dtype,
        **extra_weight_attrs,
    ):
        # initializes fp4 qparams
        assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4,))
        ret = super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )
        # The base class is expected to have set up exactly one input
        # transform whose size matches the fp4 quantization group size
        # (guaranteed by is_qutlass_fp4_scheme gating this method).
        assert self.input_transform is not None
        assert len(self.input_transform.weight) == 1
        assert self.input_transform.weight[0].size(0) == layer.scheme.group_size
        return ret

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # fused transform + fp4 GEMM kernel not yet wired up
        raise NotImplementedError()

View File

@@ -0,0 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import NamedTuple
from compressed_tensors.transform import TransformArgs, TransformScheme
__all__ = ["TransformTuple"]


class TransformTuple(NamedTuple):
    """Bundles a named transform scheme with its application arguments."""

    # name of the scheme group as it appears in the transform config
    scheme_name: str
    # transform parameters (e.g. type, head_dim)
    scheme: TransformScheme
    # how the transform is applied (e.g. location, inverse)
    args: TransformArgs

View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
def is_weak_contiguous(x: torch.Tensor):
    """Return True iff the 2-D tensor ``x`` has a unit stride along one
    dimension and a valid (possibly padded) stride along the other,
    i.e. it is laid out contiguously in either orientation."""
    stride0, stride1 = x.stride()
    rows, cols = x.shape
    # unit stride along dim 0: column-major ("not transpose" orientation)
    col_major = stride0 == 1 and stride1 >= max(1, rows)
    # unit stride along dim 1: row-major ("transpose" orientation)
    row_major = stride1 == 1 and stride0 >= max(1, cols)
    return col_major or row_major
# Tiled quantized matmul: C = (scale_a * A) @ (scale_b * B) + bias,
# where A is (M, K), B is (K, N), C is (M, N). Scales are either scalar
# (per-tensor) or per-row (A) / per-column (B) vectors; accumulation is
# done in ACCUMULATOR_DTYPE (int32 for int8 inputs, float32 for floats).
@triton.jit
def scaled_mm_kernel(
    a_ptr,
    b_ptr,
    scale_a_ptr,
    scale_b_ptr,
    c_ptr,
    bias_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    ACCUMULATOR_DTYPE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_SCALE_A: tl.constexpr,
    BLOCK_SIZE_SCALE_B: tl.constexpr,
):
    # Map the 1-D program id to a (pid_m, pid_n) output tile.
    pid = tl.program_id(axis=0)

    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = ACCUMULATOR_DTYPE
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
    # eventually occur.

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]

    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
    # appropriate offsets and masks for each case. Same goes for
    # BLOCK_SIZE_SCALE_B.
    offsets_scale_am = (
        tl.arange(0, BLOCK_SIZE_SCALE_A)
        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
    )
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (
        tl.arange(0, BLOCK_SIZE_SCALE_B)
        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
    )
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = scale_a_ptr + offsets_scale_am
    scale_b_ptrs = scale_b_ptr + offsets_scale_bn

    # Main K loop: load masked A/B tiles and accumulate partial products.
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        # Advance to the next K tile.
        offsets_k += BLOCK_SIZE_K
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply scale at end.
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    # Need to broadcast to the appropriate size, if scale_a is already
    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
    # for scale_b below.
    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
    accumulator = scale_a * accumulator.to(tl.float32)

    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
    # scale_b is a column over N; transpose to scale along the N axis.
    accumulator = scale_b.T * accumulator.to(tl.float32)

    # Convert to output format.
    c = accumulator.to(c_ptr.type.element_ty)

    # Add bias, it's already in output format, so add it after conversion.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    # Save output
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)
# input - [M, K]
# weight - [K, N]
def triton_scaled_mm(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
    block_size_m: int = 32,
    block_size_n: int = 32,
    block_size_k: int = 32,
    use_heuristic: bool = True,
) -> torch.Tensor:
    """Compute ``(scale_a * input) @ (scale_b * weight) + bias``.

    :param input: quantized activations of shape [M, K]
    :param weight: quantized weights of shape [K, N]
    :param scale_a: per-tensor (1x1) or per-row (Mx1) dequant scale
    :param scale_b: per-tensor (1x1) or per-column (Nx1) dequant scale
    :param out_dtype: floating-point dtype of the result. Note: this is a
        ``torch.dtype`` *instance* (e.g. ``torch.float16``); the previous
        annotation ``type[torch.dtype]`` was wrong, since
        ``out_dtype.is_floating_point`` below is an instance attribute.
    :param bias: optional bias of shape [N], already in output dtype
    :param block_size_m/n/k: tile sizes used when ``use_heuristic`` is off
    :param use_heuristic: derive tile sizes from M and N instead of the
        ``block_size_*`` arguments
    :return: result tensor of shape [M, N] in ``out_dtype``
    """
    M, K = input.shape
    N = weight.shape[1]

    assert N > 0 and K > 0 and M > 0
    assert weight.shape[0] == K
    assert input.dtype == weight.dtype

    # Normalize scalar / 1-D scales to column vectors so the kernel can
    # treat per-tensor and per-row/col scales uniformly.
    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
    assert out_dtype.is_floating_point
    assert bias is None or bias.is_floating_point()
    assert is_weak_contiguous(input)
    assert is_weak_contiguous(weight)

    # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile.
    def grid(meta):
        return (
            triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
        )

    result = torch.empty((M, N), dtype=out_dtype, device=input.device)

    # True when a scale is per-tensor (a single scalar).
    def has_scalar(x: torch.Tensor) -> bool:
        return x.shape[0] == 1 and x.shape[1] == 1

    if use_heuristic:
        is_small_N = N < 8192
        next_power_of_2_M = max(32, triton.next_power_of_2(M))
        if next_power_of_2_M <= 32:
            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
        elif next_power_of_2_M <= 64:
            tile_shape = (64, 64, 256)
        elif next_power_of_2_M <= 128:
            tile_shape = (64, 128, 128)
        else:
            tile_shape = (128, 128, 128)
        block_size_m, block_size_n, block_size_k = tile_shape

    # Scalar scales load one element per program; vector scales load a
    # full block row/column.
    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
    block_size_sb = 1 if has_scalar(scale_b) else block_size_n

    # int8 inputs accumulate in int32; floating inputs in float32.
    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32

    # A = input, B = weight, C = result
    # A = M x K, B = K x N, C = M x N
    scaled_mm_kernel[grid](
        input,
        weight,
        scale_a,
        scale_b,
        result,
        bias,
        M,
        N,
        K,
        input.stride(0),
        input.stride(1),
        weight.stride(0),
        weight.stride(1),
        result.stride(0),
        result.stride(1),
        accumulator_dtype,
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        BLOCK_SIZE_SCALE_A=block_size_sa,
        BLOCK_SIZE_SCALE_B=block_size_sb,
    )

    # result was allocated in out_dtype; the old trailing .to(out_dtype)
    # was a no-op.
    return result

View File

@@ -0,0 +1,216 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping
from types import MappingProxyType
import regex as re
from compressed_tensors import CompressionFormat
from torch.nn import Module
def is_activation_quantization_format(format: str) -> bool:
    """Return True if ``format`` is a compression format that also
    quantizes activations (not weight-only)."""
    return format in (
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value,
        CompressionFormat.nvfp4_pack_quantized.value,
    )
def should_ignore_layer(
    layer_name: str | None,
    ignore: Iterable[str] = tuple(),
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    """Return True if ``layer_name`` matches the config's ignore list.

    Fused modules (e.g. qkv_proj) are expanded into their checkpoint
    shard names; all shards must agree on whether they are ignored.
    """
    if layer_name is None:
        return False

    # e.g. "model.layers.0.self_attn.qkv_proj" -> proj_name "qkv_proj"
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping and layer_name not in ignore:
        shard_proj_names = fused_mapping[proj_name]

        # evaluate the ignore rule on every shard name
        shard_flags = [
            check_equal_or_regex_match(
                layer_name=layer_name.replace(proj_name, shard_proj_name),
                targets=ignore,
            )
            for shard_proj_name in shard_proj_names
        ]

        # shards must be uniformly ignored or uniformly kept
        if len(set(shard_flags)) > 1:
            raise ValueError(
                f"Found a different quantization schemes for "
                f"{shard_proj_names} in {layer_name}. vLLM "
                "requires all to use the same scheme."
            )

        assert len(shard_flags) > 0
        return shard_flags[0]

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    return check_equal_or_regex_match(layer_name=layer_name, targets=ignore)
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
    """
    Checks whether a layer_name is exactly equal or a regex match for
    if target starts with 're:' to any target in list.
    """
    for candidate in targets:
        if _is_equal_or_regex_match(layer_name, candidate):
            return True
    return False
def find_matched_target(
    layer_name: str | None,
    module: Module,
    targets: Iterable[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> str:
    """
    Helper function to look up which "target" in the compressed-tensors
    config that a layer corresponds to.

    Recall that a compressed-tensors config has a concept of
    config_groups, where each layer can be quantized with a different
    scheme.

    targets in each config_group will be a list of either layer names
    (or regexes corresponding to layer names) or names of torch Modules.

    First, we try to match the layer_name with a target
    Second, we try to match the module's name with a target
    Third, we try to map the layer_name to a list of fused module names.
    *All* component module names must match in order for a match to be
    successful. A successful match returns the first component target

    :param layer_name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
    :param fused_mapping: map from fused layer names to its components
    :raises ValueError: if no target matches the layer
    :return: the matched target string
    """
    if layer_name is None:
        layer_name = ""

    # try exact/regex name match, then module class-name (substring)
    # match, then fused-layer expansion where all components must match
    matched_target = (
        _find_first_match(layer_name, targets)
        or _find_first_match(module.__class__.__name__, targets, True)
        or _match_fused_layer(layer_name, targets, fused_mapping)
    )

    if matched_target is None:
        raise ValueError(
            f"Unable to find matching target for {layer_name} in the "
            "compressed-tensors config."
        )

    return matched_target
def _find_first_match(
    value: str, targets: Iterable[str], check_contains: bool = False
) -> str | None:
    """
    Returns first element of target that matches value either
    exactly or as a regex after 're:'. If check_contains is set to True,
    additionally checks if the target string is contained within the value.

    :param value: string to compare the list of targets against
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """
    return next(
        (
            target
            for target in targets
            if _is_equal_or_regex_match(value, target, check_contains=check_contains)
        ),
        None,
    )
def _is_equal_or_regex_match(
value: str, target: str, check_contains: bool = False
) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False
def _match_fused_layer(
    layer_name: str,
    target_layers: Iterable[str],
    fused_mapping: Mapping[str, list[str]],
) -> str | None:
    """
    Match a fused layer name to its corresponding individual layer in
    target_layers. Returns first value in fused_mapping which matches targets

    Implements an "all" matching strategy where a fused layer matches iff
    "all" of its components match

    :param layer_name: layer name
    :param target_layers: list of targets to match the layer against
    :param fused_mapping: map from fused layer names to its components

    Examples:
        layer_name = "model.layers.0.self_attn.qkv_proj"
        target_layers = ["model.layers.0.self_attn.q_proj",
                        "model.layers.0.self_attn.k_proj",
                        "model.layers.0.self_attn.v_proj"]
    """
    # which fused suffix (if any) does this layer name end with?
    fused = next((key for key in fused_mapping if layer_name.endswith(key)), None)
    if fused is None:
        return None

    # expand to the paths of the unfused components
    component_paths = [
        layer_name.replace(fused, unfused) for unfused in fused_mapping[fused]
    ]

    # find the first matching target for each component (None if absent)
    component_matches = [
        _find_first_match(path, target_layers) for path in component_paths
    ]

    # "all" strategy: every component must match; return first target
    return component_matches[0] if all(component_matches) else None