v1.0
This commit is contained in:
@@ -0,0 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,914 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import suppress
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional, cast
|
||||
|
||||
import torch
|
||||
from compressed_tensors.config import (
|
||||
CompressionFormat,
|
||||
SparsityCompressionConfig,
|
||||
SparsityStructure,
|
||||
)
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType,
|
||||
)
|
||||
from compressed_tensors.transform import TransformConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
LinearMethodBase,
|
||||
UnquantizedLinearMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
|
||||
CompressedTensorsMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS,
|
||||
WNA16_SUPPORTED_BITS,
|
||||
CompressedTensors24,
|
||||
CompressedTensorsScheme,
|
||||
CompressedTensorsW4A4Fp4,
|
||||
CompressedTensorsW4A8Fp8,
|
||||
CompressedTensorsW4A8Int,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
CompressedTensorsWNA16,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod,
|
||||
get_linear_transform_schemes,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
|
||||
find_matched_target,
|
||||
is_activation_quantization_format,
|
||||
should_ignore_layer,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.model_executor.models.utils import WeightsMapper
|
||||
|
||||
logger = init_logger(__name__)

# Public API of this module.
__all__ = ["CompressedTensorsLinearMethod"]

# Key under which the sparsity section lives in the `quantization_config`
# dict of a model's config.json.
SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
# Shape of the parsed scheme map: target-layer pattern -> per-tensor-kind
# (weights / input_activations / ...) QuantizationArgs.
QUANTIZATION_SCHEME_MAP_TYPE = dict[str, dict[str, QuantizationArgs] | None]
|
||||
|
||||
|
||||
class CompressedTensorsConfig(QuantizationConfig):
|
||||
def __init__(
|
||||
self,
|
||||
target_scheme_map: dict[str, Any],
|
||||
ignore: list[str],
|
||||
quant_format: str,
|
||||
sparsity_scheme_map: dict[str, SparsityCompressionConfig],
|
||||
sparsity_ignore_list: list[str],
|
||||
kv_cache_scheme: dict[str, Any] | None = None,
|
||||
config: dict[str, Any] | None = None,
|
||||
transform_config: dict[str, Any] | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.ignore = ignore
|
||||
self.quant_format = quant_format
|
||||
# Map from [target -> scheme]
|
||||
self.target_scheme_map = target_scheme_map
|
||||
self.kv_cache_scheme = kv_cache_scheme
|
||||
self.sparsity_scheme_map = sparsity_scheme_map
|
||||
self.sparsity_ignore_list = sparsity_ignore_list
|
||||
self.config = config
|
||||
|
||||
if transform_config:
|
||||
self.transform_config = TransformConfig.model_validate(transform_config)
|
||||
else:
|
||||
self.transform_config = None
|
||||
|
||||
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
|
||||
return CompressedTensorsLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.float32, torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "compressed-tensors"
|
||||
|
||||
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
|
||||
self.target_scheme_map = hf_to_vllm_mapper.apply_dict(self.target_scheme_map)
|
||||
self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
|
||||
self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
|
||||
self.sparsity_scheme_map
|
||||
)
|
||||
self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
|
||||
self.sparsity_ignore_list
|
||||
)
|
||||
if self.kv_cache_scheme is not None:
|
||||
self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(self.kv_cache_scheme)
|
||||
|
||||
    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for `layer`, or None if not handled.

        Linear layers get a quantization scheme (possibly wrapped by a
        transform method), attention layers get the kv-cache method, and
        fused-MoE layers get a MoE method.
        """
        from vllm.attention.layer import Attention  # Avoid circular import

        if isinstance(layer, LinearBase):
            # collect schemes
            # quant_scheme is None when this layer is ignored/unquantized.
            quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
            input_tfms, output_tfms = get_linear_transform_schemes(
                layer, prefix, self.transform_config, self.packed_modules_mapping
            )

            # choose quantization method
            quant_method: LinearMethodBase = UnquantizedLinearMethod()
            if quant_scheme is not None:
                # The scheme is attached to the layer; the linear method
                # later dispatches through `layer.scheme`.
                layer.scheme = quant_scheme
                quant_method = CompressedTensorsLinearMethod(self)

            # choose transform method
            # If any input/output transforms apply, wrap the quant method.
            if any((input_tfms, output_tfms)):
                return CompressedTensorsLinearTransformMethod.from_schemes(
                    quant_method, quant_scheme, input_tfms, output_tfms
                )

            else:
                return quant_method

        if isinstance(layer, Attention):
            return CompressedTensorsKVCacheMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMoEMethod.get_moe_method(self, layer)
        return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "CompressedTensorsConfig":
|
||||
ignore: list[str] = cast(list[str], config.get("ignore", []))
|
||||
quant_format = cast(str, config.get("format"))
|
||||
target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
|
||||
sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
|
||||
config=config
|
||||
)
|
||||
transform_config = config.get("transform_config")
|
||||
|
||||
return cls(
|
||||
target_scheme_map=target_scheme_map,
|
||||
ignore=ignore,
|
||||
quant_format=quant_format,
|
||||
sparsity_scheme_map=sparsity_scheme_map,
|
||||
sparsity_ignore_list=sparsity_ignore_list,
|
||||
config=config,
|
||||
transform_config=transform_config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _parse_sparsity_config(
|
||||
cls, config: dict[str, Any]
|
||||
) -> tuple[dict[str, SparsityCompressionConfig], list[str]]:
|
||||
"""
|
||||
:param config: The `quantization_config` dictionary from config.json
|
||||
:return: A tuple with two elements
|
||||
1. A dictionary mapping target layer names to their corresponding
|
||||
sparsity_config
|
||||
2. A list of layer names to ignore for sparsity
|
||||
"""
|
||||
if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
|
||||
return dict(), []
|
||||
|
||||
sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)
|
||||
sparse_scheme_map: dict[str, SparsityCompressionConfig] = {
|
||||
target: sparsity_config for target in sparsity_config.targets or list()
|
||||
}
|
||||
sparsity_ignore_list = sparsity_config.ignore or list()
|
||||
return sparse_scheme_map, sparsity_ignore_list
|
||||
|
||||
    @classmethod
    def _quantization_scheme_map_from_config(
        cls, config: dict[str, Any]
    ) -> QUANTIZATION_SCHEME_MAP_TYPE:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            quantization_args for weights and input activations
        """
        target_scheme_map: dict[str, Any] = dict()
        quant_format = cast(str, config.get("format"))

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key, dictating which
        # layers are impacted by the quantization details. The quantization
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.

        config_groups = config.get("config_groups", dict())
        for _, quant_config in config_groups.items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
                    quant_config.get("weights")
                )

                # Filled in below only when activation quant applies.
                target_scheme_map[target]["input_activations"] = None
                target_scheme_map[target]["format"] = quant_config.get("format")
                format = target_scheme_map[target].get("format")
                # If no per-config format defined, use global format in config
                act_quant_format = (
                    is_activation_quantization_format(format)
                    if format is not None
                    else is_activation_quantization_format(quant_format)
                )
                # TODO(czhu): w4a8fp8 is in packed-quantized format
                # but needs input activation quantization
                input_activations = quant_config.get("input_activations")
                if act_quant_format or input_activations:
                    # The only case where we have activation quant supported
                    # but no input_activations provided in the config
                    # should be w8a16fp8 w8a16fp8 can also run for cases where
                    # there is an input_quant but it is ignored
                    if not input_activations:
                        assert (
                            target_scheme_map[target]["weights"].type
                            == QuantizationType.FLOAT
                        )
                    else:
                        target_scheme_map[target]["input_activations"] = (
                            QuantizationArgs.model_validate(
                                quant_config.get("input_activations")
                            )
                        )
        return target_scheme_map
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(
|
||||
self, min_capability: int, error: bool = True, match_exact: bool = False
|
||||
) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
if match_exact:
|
||||
supported = capability == min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
"the current GPU. Required capability: ",
|
||||
f"{min_capability}. Current capability: {capability}.",
|
||||
)
|
||||
else:
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.",
|
||||
)
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp4a4_nvfp4(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
):
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_tensor_group_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
|
||||
and input_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
|
||||
)
|
||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||
|
||||
is_group_size_16 = (
|
||||
weight_quant.group_size == 16 and input_quant.group_size == 16
|
||||
)
|
||||
is_float_type = (
|
||||
weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT
|
||||
)
|
||||
is_4_bits = weight_quant.num_bits == 4 and input_quant.num_bits == 4
|
||||
|
||||
return (
|
||||
is_tensor_group_quant
|
||||
and is_float_type
|
||||
and is_4_bits
|
||||
and is_group_size_16
|
||||
and is_symmetric
|
||||
)
|
||||
|
||||
def _is_fp4a16_nvfp4(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
):
|
||||
is_weight_only = weight_quant is not None and input_quant is None
|
||||
is_tensor_group_quant = (
|
||||
weight_quant.strategy == QuantizationStrategy.TENSOR_GROUP.value
|
||||
)
|
||||
is_symmetric = weight_quant.symmetric
|
||||
|
||||
is_group_size_16 = weight_quant.group_size == 16
|
||||
is_float_type = weight_quant.type == QuantizationType.FLOAT
|
||||
is_4_bits = weight_quant.num_bits == 4
|
||||
|
||||
return (
|
||||
is_weight_only
|
||||
and is_tensor_group_quant
|
||||
and is_float_type
|
||||
and is_4_bits
|
||||
and is_group_size_16
|
||||
and is_symmetric
|
||||
)
|
||||
|
||||
def _is_static_tensor_w8a8(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_tensor = (
|
||||
weight_strategy
|
||||
and input_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
)
|
||||
is_static = not weight_quant.dynamic and not input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_tensor and weight_quant.symmetric and is_static
|
||||
|
||||
def _is_dynamic_token_w8a8(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_token = (
|
||||
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
|
||||
)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
|
||||
|
||||
def _is_dynamic_token_w4a8_int(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
is_weight_4_bits = weight_quant.num_bits == 4
|
||||
is_activation_8_bits = input_quant.num_bits == 8
|
||||
weight_strategy = (
|
||||
weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
)
|
||||
is_token = (
|
||||
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
|
||||
)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return (
|
||||
is_weight_4_bits
|
||||
and is_activation_8_bits
|
||||
and is_token
|
||||
and weight_quant.symmetric
|
||||
and is_dynamic
|
||||
)
|
||||
|
||||
def _is_fp8_w8a8(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
# Confirm weights and activations quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_floating_point = (
|
||||
weight_quant.type == QuantizationType.FLOAT
|
||||
and input_quant.type == QuantizationType.FLOAT
|
||||
)
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_tensor_or_channel_or_block_weight = weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR,
|
||||
QuantizationStrategy.CHANNEL,
|
||||
QuantizationStrategy.BLOCK,
|
||||
]
|
||||
if not (
|
||||
is_floating_point
|
||||
and is_symmetric_weight
|
||||
and is_static_weight
|
||||
and is_tensor_or_channel_or_block_weight
|
||||
):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.dynamic:
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_symmetric_activation = input_quant.symmetric
|
||||
is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR
|
||||
return is_symmetric_activation and is_per_tensor_activation
|
||||
|
||||
def _is_fp8_w4a8(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
if not weight_quant or not input_quant:
|
||||
return False
|
||||
is_weight_4_bits = weight_quant.num_bits == 4
|
||||
is_activation_8_bits = input_quant.num_bits == 8
|
||||
weight_strategy = weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
is_token = (
|
||||
weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
|
||||
)
|
||||
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
|
||||
is_symmetric = weight_quant.symmetric and input_quant.symmetric
|
||||
# Only per-group symmetric weight (4bit)
|
||||
# + per-tok symmetric activation (8bit) quantization supported.
|
||||
return (
|
||||
is_weight_4_bits
|
||||
and is_activation_8_bits
|
||||
and is_token
|
||||
and is_symmetric
|
||||
and is_dynamic
|
||||
)
|
||||
|
||||
def _is_fp8_w4a8_sm90(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
return self._check_scheme_supported(
|
||||
90, error=False, match_exact=True
|
||||
) and self._is_fp8_w4a8(weight_quant, input_quant)
|
||||
|
||||
def _is_fp8_w8a8_sm90(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
return self._check_scheme_supported(
|
||||
90, error=False, match_exact=True
|
||||
) and self._is_fp8_w8a8(weight_quant, input_quant)
|
||||
|
||||
def _is_fp8_w8a8_sm100(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
return self._check_scheme_supported(
|
||||
100, error=False, match_exact=True
|
||||
) and self._is_fp8_w8a8(weight_quant, input_quant)
|
||||
|
||||
def _is_fp8_w8a16(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
# Confirm weights quantized.
|
||||
if weight_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm we have floating points.
|
||||
if weight_quant.type != QuantizationType.FLOAT:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported.
|
||||
is_symmetric_weight = weight_quant.symmetric
|
||||
is_static_weight = not weight_quant.dynamic
|
||||
is_tensor_or_channel_or_block_weight = weight_quant.strategy in [
|
||||
QuantizationStrategy.TENSOR,
|
||||
QuantizationStrategy.CHANNEL,
|
||||
QuantizationStrategy.BLOCK,
|
||||
]
|
||||
return (
|
||||
is_symmetric_weight
|
||||
and is_static_weight
|
||||
and is_tensor_or_channel_or_block_weight
|
||||
)
|
||||
|
||||
def _is_wNa16_group_channel(
|
||||
self, weight_quant: QuantizationArgs, input_quant: QuantizationArgs
|
||||
) -> bool:
|
||||
input_quant_none = input_quant is None
|
||||
is_channel_group = (
|
||||
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
or weight_quant.strategy == QuantizationStrategy.GROUP.value
|
||||
)
|
||||
is_static = not weight_quant.dynamic
|
||||
|
||||
return is_channel_group and input_quant_none and is_static
|
||||
|
||||
    def _get_scheme_from_parts(
        self,
        weight_quant: QuantizationArgs,
        input_quant: QuantizationArgs,
        format: str | None = None,
    ) -> "CompressedTensorsScheme":
        """Select the concrete CompressedTensorsScheme for one layer.

        Checks are ordered by priority; the first matching recipe wins.

        :param weight_quant: weight quantization args (not None by the time
            callers reach this method)
        :param input_quant: activation quantization args, possibly None
        :param format: per-layer compression format; falls back to the
            global `quant_format`
        :raises NotImplementedError: when no scheme matches
        """
        # use the per-layer format if defined, otherwise, use global format
        format = format if format is not None else self.quant_format

        # Detect If Mixed Precision
        if self._is_fp4a16_nvfp4(weight_quant, input_quant):
            return CompressedTensorsW4A16Fp4()

        # w4a8-fp8 only runs on SM 9.0 exactly; checked inside the helper.
        if self._is_fp8_w4a8_sm90(weight_quant, input_quant):
            return CompressedTensorsW4A8Fp8(
                num_bits=weight_quant.num_bits,
                strategy=weight_quant.strategy,
                symmetric=weight_quant.symmetric,
                group_size=weight_quant.group_size,
                actorder=weight_quant.actorder,
            )

        # Weight-only int quantization; pick kernel by compression format.
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if (
                format == CompressionFormat.marlin_24.value
                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
            ):
                assert weight_quant.symmetric
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size,
                )
            if (
                format == CompressionFormat.pack_quantized.value
                and weight_quant.num_bits in WNA16_SUPPORTED_BITS
            ):
                return CompressedTensorsWNA16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    symmetric=weight_quant.symmetric,
                    group_size=weight_quant.group_size,
                    actorder=weight_quant.actorder,
                )

        # Remaining recipes all quantize activations; gate on the format.
        act_quant_format = is_activation_quantization_format(format)
        if act_quant_format:
            if self._is_fp4a4_nvfp4(weight_quant, input_quant):
                if cutlass_fp4_supported() or envs.VLLM_USE_NVFP4_CT_EMULATIONS:
                    return CompressedTensorsW4A4Fp4()
                else:
                    # Fall back to weight-only NVFP4 on unsupported HW.
                    logger.warning_once(
                        "Current platform does not support cutlass NVFP4."
                        " Running CompressedTensorsW4A16Fp4."
                    )
                    return CompressedTensorsW4A16Fp4(has_input_global_scale=True)

            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False
                )
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        weight_quant=weight_quant,
                        is_static_input_scheme=(
                            input_quant and not input_quant.dynamic
                        ),
                    )
                else:
                    # note: input_quant will be present for converted models;
                    # will be ignored during inference post loading
                    return CompressedTensorsW8A16Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=not input_quant.dynamic,
                    )

            # note: input_quant can be None
            if self._is_fp8_w8a16(weight_quant, input_quant):
                is_static_input_scheme = input_quant and not input_quant.dynamic
                return CompressedTensorsW8A16Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=is_static_input_scheme,
                )

            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=True,
                    input_symmetric=input_quant.symmetric,
                )

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=False,
                    input_symmetric=input_quant.symmetric,
                )

            if self._is_dynamic_token_w4a8_int(weight_quant, input_quant):
                is_static_input_scheme = input_quant and not input_quant.dynamic
                return CompressedTensorsW4A8Int(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size,
                    is_static_input_scheme=is_static_input_scheme,
                    input_symmetric=input_quant.symmetric,
                )

        raise NotImplementedError("No compressed-tensors compatible scheme was found.")
|
||||
|
||||
    def get_scheme(
        self, layer: torch.nn.Module, layer_name: str | None = None
    ) -> Optional["CompressedTensorsScheme"]:
        """
        compressed-tensors supports non uniform in the following way:

        targets of config_groups: There can be N config_groups which each
        have a quantization scheme. Each config_group has a list of targets
        which can be a full layer_name, a regex for a layer_name, or
        an nn.Module name.

        Detect whether a layer_name is found in any target and
        use the quantization scheme corresponding to the matched target
        to select the CompressedTensorsScheme used for inference.

        Returns None when the layer is ignored or unquantized.
        """

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO (@kylesayrs): support ignore module names with ct matching utils
        if should_ignore_layer(
            layer_name, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return None

        # Will be empty for models with only sparsity
        weight_quant = input_quant = None
        if self.target_scheme_map:
            # Raises if no target matches; models with config_groups are
            # expected to cover every non-ignored layer.
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys(),
                fused_mapping=self.packed_modules_mapping,
            )

            scheme_dict = self.target_scheme_map[matched_target]
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")
            # NOTE: `format` (shadows the builtin) is only bound on this
            # branch; the later `_get_scheme_from_parts` call is only
            # reachable when weight_quant was set here.
            format = scheme_dict.get("format")

        # Find the sparsity scheme of the layer
        # assume that fused layers inherit first component's sparsity scheme
        sparsity_targets = self.sparsity_scheme_map.keys() - set(
            self.sparsity_ignore_list
        )
        sparsity_scheme: SparsityCompressionConfig | None = None
        # find_matched_target raises ValueError when nothing matches;
        # no sparsity match simply leaves sparsity_scheme as None.
        with suppress(ValueError):
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=sparsity_targets,
                fused_mapping=self.packed_modules_mapping,
            )
            sparsity_scheme = self.sparsity_scheme_map[matched_target]

        if self.supports_cutlass_24(
            weight_quant=weight_quant,
            input_quant=input_quant,
            sparsity_scheme=sparsity_scheme,
        ):
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            model_compression_config = (
                None
                if sparsity_scheme is None or sparsity_scheme.format == "dense"
                else self.config
            )

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
        elif weight_quant is None:
            logger.warning_once(
                "Acceleration for non-quantized schemes is "
                "not supported by Compressed Tensors. "
                "Falling back to UnquantizedLinearMethod"
            )
            return None

        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
                weight_quant=weight_quant, input_quant=input_quant, format=format
            )

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
        logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
        return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> str | None:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in compressed-tensors. If this is the case, return its equivalent
|
||||
param name expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if name.endswith(".output_scale") and ".k_proj" in name:
|
||||
return name.replace(".k_proj.output_scale", ".attn.k_scale")
|
||||
if name.endswith(".output_scale") and ".v_proj" in name:
|
||||
return name.replace(".v_proj.output_scale", ".attn.v_scale")
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
def has_blocked_weights(self) -> bool:
|
||||
for scheme in self.target_scheme_map.values():
|
||||
weight_quant = scheme.get("weights")
|
||||
if (
|
||||
weight_quant is not None
|
||||
and weight_quant.strategy == QuantizationStrategy.BLOCK
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
    @staticmethod
    def supports_cutlass_24(
        weight_quant: QuantizationArgs | None,
        input_quant: QuantizationArgs | None,
        sparsity_scheme: SparsityCompressionConfig | None = None,
    ) -> bool:
        """
        Check if the layer is supported by the Cutlass 2:4 Kernel
        Conditions:
            - Overarching condition: Sparsity Structure is 2:4
            - Unquantized cases are supported
            - Weight only quantization is not-supported
            - Supported weight quantization strategies are TENSOR and CHANNEL
            - Supported input quantization strategies are TENSOR and TOKEN
            - Only 8 bit quantization is supported

        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        if sparsity_scheme is None:
            return False

        is_valid_sparsity_structure: bool = (
            sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value
        )

        # Only dense and 2:4-bitmask compressors can feed this kernel.
        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value,
        }

        is_valid_sparsity = (
            is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors
        )

        if not is_valid_sparsity:
            return False

        # Unquantized cases are supported
        if weight_quant is None and input_quant is None:
            return True

        # Weight only quantization is not-supported
        if weight_quant is not None and input_quant is None:
            return False

        supported_weight_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.CHANNEL.value,
        ]

        # Both are non-None from here on (the None cases returned above).
        assert weight_quant is not None
        assert input_quant is not None
        if weight_quant.strategy not in supported_weight_quant_strategies:
            return False

        supported_input_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.TOKEN.value,
        ]

        if input_quant.strategy not in supported_input_quant_strategies:
            return False

        # Only 8-bit weight and activation quantization is supported.
        return weight_quant.num_bits == input_quant.num_bits == 8
|
||||
|
||||
|
||||
class CompressedTensorsLinearMethod(LinearMethodBase):
    """Linear method that delegates all work to the CompressedTensorsScheme
    attached to each layer (as `layer.scheme` by `get_quant_method`)."""

    def __init__(self, quantization_config: CompressedTensorsConfig):
        # Kept for access to global quantization settings.
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Let the per-layer scheme repack/convert its weights post-load.
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details

        :raises ValueError: if no scheme was attached to the layer
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from compressed-tensors
|
||||
checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: CompressedTensorsConfig):
|
||||
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_scheme(kv_cache_scheme: dict[str, Any] | None):
|
||||
"""
|
||||
Validator for the kv cache scheme. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_scheme: the compressed-tensors kv cache scheme
|
||||
"""
|
||||
if kv_cache_scheme is None:
|
||||
return
|
||||
|
||||
type_ = kv_cache_scheme.get("type")
|
||||
num_bits = kv_cache_scheme.get("num_bits")
|
||||
|
||||
if type_ != "float" and num_bits != 8:
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
"num_bits=8, type=float, however "
|
||||
f"received num_bits={num_bits}, type={type_}"
|
||||
)
|
||||
|
||||
strategy = kv_cache_scheme.get("strategy")
|
||||
if strategy != "tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"Expected strategy: tensor, found strategy: {strategy}"
|
||||
)
|
||||
|
||||
is_symmetric = kv_cache_scheme.get("symmetric")
|
||||
if not is_symmetric:
|
||||
raise NotImplementedError(
|
||||
"Only support symmetric scaling factor "
|
||||
"for compressed-tensors KV cache. "
|
||||
f"However found symmetric: {is_symmetric}"
|
||||
)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,35 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_24 import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
)
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
|
||||
|
||||
# This avoids circular import error
|
||||
from .compressed_tensors_24 import CompressedTensors24 # isort: skip
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS",
|
||||
"W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24",
|
||||
"CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A4Fp4",
|
||||
"CompressedTensorsW4A8Int",
|
||||
"CompressedTensorsW4A8Fp8",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,392 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat, ModelCompressor
|
||||
from compressed_tensors.quantization import (
|
||||
QuantizationArgs,
|
||||
QuantizationStrategy,
|
||||
QuantizationType,
|
||||
)
|
||||
from compressed_tensors.utils import combine_shards
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
sparse_cutlass_supported,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensors24"]
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class CompressedTensors24(CompressedTensorsScheme):
|
||||
def __init__(
|
||||
self,
|
||||
quantized: bool = False,
|
||||
weight_quant: QuantizationArgs | None = None,
|
||||
input_quant: QuantizationArgs | None = None,
|
||||
model_compression_config: dict[str, Any] | None = None,
|
||||
):
|
||||
self.quantized = quantized
|
||||
self.weight_quant = weight_quant
|
||||
self.input_quant = input_quant
|
||||
model_compressor = ModelCompressor.from_compression_config(
|
||||
model_compression_config
|
||||
)
|
||||
self.do_sparse_decompress = (
|
||||
model_compressor is not None
|
||||
and model_compressor.sparsity_config.format
|
||||
== CompressionFormat.sparse_24_bitmask.value
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
self.model_compressor = model_compressor
|
||||
|
||||
if (
|
||||
quantized
|
||||
and input_quant is not None
|
||||
and self._get_quant_dtype() == current_platform.fp8_dtype()
|
||||
):
|
||||
static = not input_quant.dynamic
|
||||
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
|
||||
self.quant_fp8 = QuantFP8(static, g_shape)
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# Only cutlass 3.x kernels are implemented so far
|
||||
return 90
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
if not sparse_cutlass_supported():
|
||||
raise ValueError(
|
||||
"Sparse CUTLASS not supported. vLLM must be built with "
|
||||
"CUDA 12.2 or later to use this feature"
|
||||
)
|
||||
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size = input_size
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
|
||||
|
||||
# parameter to store uncompressed weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
if self.do_sparse_decompress:
|
||||
assert all(
|
||||
partition_size % 8 == 0 for partition_size in output_partition_sizes
|
||||
), "All partitions must be divisible by 8 for "
|
||||
"2:4 sparse compressed models"
|
||||
|
||||
shape = BasevLLMParameter(
|
||||
data=torch.empty(2, 1, dtype=torch.int64),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
compressed_weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=self.weights_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
bitmask = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 8,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("shape", shape)
|
||||
layer.register_parameter("compressed", compressed_weight)
|
||||
layer.register_parameter("bitmask", bitmask)
|
||||
|
||||
# Check if quantized, not just 2:4 Sparse
|
||||
if self.quantized:
|
||||
if (
|
||||
self.weight_quant
|
||||
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value
|
||||
):
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
(sum(output_partition_sizes), 1), dtype=torch.float32
|
||||
),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.weight_quant
|
||||
and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
)
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# input quant will be non-none
|
||||
if self.input_quant and not self.input_quant.dynamic:
|
||||
# register input quant scale
|
||||
assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
|
||||
input_scale = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
else:
|
||||
# for sparse-only, pass in 1 for weight/input scales
|
||||
weight_scale = torch.nn.Parameter(
|
||||
data=torch.ones(1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
input_scale = torch.nn.Parameter(
|
||||
data=torch.ones(1, dtype=torch.float32), requires_grad=False
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
"""
|
||||
Compress weights after loading. Store compressed weight and meta
|
||||
tensor
|
||||
|
||||
:post-condition: layer.w_compressed and layer.meta are
|
||||
set to the compressed weight and meta tensor in the
|
||||
format expected by the Cutlass kernels
|
||||
:param layer: The layer with the weights to be processed
|
||||
|
||||
"""
|
||||
if self.do_sparse_decompress:
|
||||
layer.weight.data = self._decompress_bitmask_compressed_weight(
|
||||
compressed=layer.compressed,
|
||||
bitmask=layer.bitmask,
|
||||
layer=layer,
|
||||
)
|
||||
|
||||
# compressed and bitmask tensors
|
||||
# are no longer needed after decompression
|
||||
del layer.compressed
|
||||
del layer.bitmask
|
||||
|
||||
# torch.compile workaround
|
||||
if hasattr(layer, "input_scale"):
|
||||
layer.input_scale = torch.nn.Parameter(
|
||||
layer.input_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
if self.weight_quant:
|
||||
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
convert_to_channelwise(
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
# torch.compile workaround
|
||||
layer.weight_scale = torch.nn.Parameter(
|
||||
layer.weight_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
# Set all negative zero values to 0 prior to compression
|
||||
if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2:
|
||||
layer.weight.data[layer.weight.data == -0.0] = 0.0
|
||||
|
||||
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
|
||||
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
|
||||
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Returns the output tensor for the layer with 2:4
|
||||
sparse compressed weights, given the input tensor
|
||||
and bias
|
||||
|
||||
:param layer: The layer with 2:4 sparse compressed
|
||||
weights to be used for the computation
|
||||
:param x: The input tensor to the layer
|
||||
:param bias: The bias to be added to the output tensor
|
||||
:return: The output tensor of the layer
|
||||
"""
|
||||
if self.quantized:
|
||||
scale = getattr(layer, "input_scale", None)
|
||||
|
||||
if self.weights_dtype == torch.int8:
|
||||
ops_output = ops.scaled_int8_quant(x, scale=scale)
|
||||
q_input = ops_output[0]
|
||||
input_scale = ops_output[1]
|
||||
else:
|
||||
assert self.weights_dtype == torch.float8_e4m3fn
|
||||
q_input, input_scale = self.quant_fp8(x, scale=scale)
|
||||
|
||||
else:
|
||||
# Not quantized, nothing to do with the input_scales, use as is
|
||||
input_scale = layer.input_scale
|
||||
q_input = x
|
||||
|
||||
out = ops.cutlass_scaled_sparse_mm(
|
||||
a=q_input,
|
||||
bt_nzs=layer.weight,
|
||||
bt_meta=layer.meta,
|
||||
scale_a=input_scale,
|
||||
scale_b=layer.weight_scale,
|
||||
out_dtype=x.dtype,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert out.is_contiguous()
|
||||
return out
|
||||
|
||||
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
|
||||
if not self.quantized:
|
||||
return params_dtype
|
||||
return self._get_quant_dtype()
|
||||
|
||||
def _get_quant_dtype(self) -> torch.dtype:
|
||||
assert self.quantized
|
||||
assert self.weight_quant is not None
|
||||
assert self.input_quant is not None
|
||||
|
||||
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
|
||||
|
||||
if not is_8_bits:
|
||||
raise ValueError("Cutlass only supports 8-bit quantization")
|
||||
|
||||
if (
|
||||
self.weight_quant.type == QuantizationType.FLOAT
|
||||
and self.input_quant.type == QuantizationType.FLOAT
|
||||
):
|
||||
return torch.float8_e4m3fn
|
||||
|
||||
if (
|
||||
self.weight_quant.type == QuantizationType.INT
|
||||
and self.input_quant.type == QuantizationType.INT
|
||||
):
|
||||
return torch.int8
|
||||
|
||||
raise ValueError("Quantization type not supported by Cutlass")
|
||||
|
||||
def _decompress_bitmask_compressed_weight(
|
||||
self,
|
||||
compressed: torch.Tensor,
|
||||
bitmask: torch.Tensor,
|
||||
layer: torch.nn.Module,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
|
||||
return the result.
|
||||
|
||||
This function also supports sharded decompression.
|
||||
|
||||
:param compressed: The 2:4 sparse weight tensor compressed using the
|
||||
sparse-24-bitmask compressor. This is different from
|
||||
`cutlass_sparse_compress` which uses a different scheme (2 bits for
|
||||
every nonzero element that represent the coordinate within the block
|
||||
of 4). The bitmask compression here uses a bitmask to indicate the
|
||||
positions of non-zero elements.
|
||||
:param bitmask: The 2:4 bitmask associated with the compressed weights,
|
||||
representing the positions of non-zero elements in the compressed
|
||||
tensor.
|
||||
:param layer: The layer whose weights need to be processed after
|
||||
loading.
|
||||
:return: The decompressed 2:4 sparse weight tensor.
|
||||
"""
|
||||
|
||||
sparsity_compressor = self.model_compressor.sparsity_compressor
|
||||
|
||||
def _process_split(
|
||||
bitmask_compressed_weight: torch.Tensor,
|
||||
shape,
|
||||
bitmask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
weight_data = dict(
|
||||
compressed=bitmask_compressed_weight,
|
||||
shape=shape,
|
||||
bitmask=bitmask,
|
||||
)
|
||||
return sparsity_compressor.decompress_weight(weight_data)
|
||||
|
||||
split_weights: list[torch.Tensor] = []
|
||||
split_bitmask: list[torch.Tensor] = []
|
||||
split_shape: list[tuple[int, int]] = []
|
||||
|
||||
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
split_weights = torch.split(compressed, layer.logical_widths)
|
||||
split_bitmask = torch.split(bitmask, layer.logical_widths)
|
||||
split_shape = [
|
||||
(out, layer.input_size_per_partition) for out in layer.logical_widths
|
||||
]
|
||||
|
||||
if split_weights:
|
||||
decompressed_shards = [
|
||||
_process_split(compressed_weight, shape, bitmask)
|
||||
for compressed_weight, shape, bitmask in zip(
|
||||
split_weights, split_shape, split_bitmask
|
||||
)
|
||||
]
|
||||
decompressed = combine_shards(decompressed_shards)
|
||||
else:
|
||||
decompressed = sparsity_compressor.decompress_weight(
|
||||
dict(
|
||||
compressed=compressed,
|
||||
shape=(
|
||||
layer.logical_widths[0],
|
||||
layer.input_size_per_partition,
|
||||
),
|
||||
bitmask=bitmask,
|
||||
)
|
||||
)
|
||||
return decompressed
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["CompressedTensorsScheme"]
|
||||
|
||||
|
||||
class CompressedTensorsScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by CompressedTensors.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Sparse24"]
|
||||
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
}
|
||||
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, num_bits: int, group_size: int | None = None):
|
||||
self.strategy = strategy
|
||||
self.group_size = group_size
|
||||
self.tile_size = 16
|
||||
|
||||
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}"
|
||||
)
|
||||
|
||||
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
if self.strategy == "group" and self.group_size is None:
|
||||
raise ValueError("group_size must be given when using strategy group")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere + up
|
||||
return 80
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
layer.scale_packed = Parameter(layer.scale_packed.data, requires_grad=False)
|
||||
layer.meta = Parameter(layer.meta.data, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
assert params_dtype == torch.float16, (
|
||||
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
)
|
||||
|
||||
pack_factor = 32 // self.quant_type.size_bits
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.tile_size // 2,
|
||||
output_size_per_partition * self.tile_size // pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=pack_factor,
|
||||
marlin_tile_size=self.tile_size,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
input_groups = (
|
||||
1
|
||||
if self.group_size is None
|
||||
else input_size_per_partition // self.group_size
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
|
||||
if self.group_size is not None:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
else:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
|
||||
weight_shape = BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
meta = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_packed", qweight)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("scale_packed", scales)
|
||||
layer.register_parameter("meta", meta)
|
||||
|
||||
max_workspace_size = (
|
||||
output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
) * GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
workspace = Parameter(
|
||||
torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False
|
||||
)
|
||||
layer.workspace = workspace
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.weight_packed
|
||||
meta = layer.meta
|
||||
scales = layer.scale_packed
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(
|
||||
x_2d,
|
||||
qweight,
|
||||
meta,
|
||||
scales,
|
||||
workspace,
|
||||
self.quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -0,0 +1,124 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
apply_fp4_marlin_linear,
|
||||
prepare_fp4_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Fp4(CompressedTensorsScheme):
|
||||
def __init__(self, has_input_global_scale: bool = False):
|
||||
self.has_input_global_scale = has_input_global_scale
|
||||
self.group_size = 16
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# dont restrict as emulations
|
||||
return 80
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
|
||||
# Weight
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // 2,
|
||||
dtype=torch.uint8,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_packed", weight)
|
||||
|
||||
# Global Weight Scale
|
||||
weight_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_global_scale", weight_global_scale)
|
||||
|
||||
# Per Group Weight Scale
|
||||
weight_scale = GroupQuantScaleParameter(
|
||||
data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition // self.group_size,
|
||||
dtype=torch.float8_e4m3fn,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
if self.has_input_global_scale:
|
||||
input_global_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("input_global_scale", input_global_scale)
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# Process parameters for marlin repacking
|
||||
|
||||
# Rename weight_packed to weight that marlin expects
|
||||
layer.weight = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
del layer.weight_packed
|
||||
# Rename weight_global_scale to weight_scale_2 that marlin expects
|
||||
# Note: ct stores the inverse of what is expected by the marlin kernel
|
||||
layer.weight_scale_2 = Parameter(
|
||||
1 / layer.weight_global_scale.max().to(torch.float32), requires_grad=False
|
||||
)
|
||||
del layer.weight_global_scale
|
||||
|
||||
if self.has_input_global_scale:
|
||||
layer.input_global_scale = torch.nn.Parameter(
|
||||
layer.input_global_scale.data, requires_grad=False
|
||||
)
|
||||
|
||||
prepare_fp4_layer_for_marlin(layer)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_fp4_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
weight_scale_2=layer.weight_scale_2,
|
||||
workspace=layer.workspace,
|
||||
size_n=layer.output_size_per_partition,
|
||||
size_k=layer.input_size_per_partition,
|
||||
bias=bias,
|
||||
)
|
||||
@@ -0,0 +1,218 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import ( # noqa: E501
|
||||
run_nvfp4_emulations,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
cutlass_fp4_supported,
|
||||
swizzle_blockscale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm, has_flashinfer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsW4A4Fp4"]
|
||||
|
||||
|
||||
class CompressedTensorsW4A4Fp4(CompressedTensorsScheme):
    """Linear scheme for NVFP4: FP4 weights and FP4 activations with
    per-group (group_size=16) FP8 block scales plus per-tensor global scales.

    A GEMM backend (FlashInfer, CUTLASS, or fbgemm) is selected at
    construction time from env overrides and platform capability.
    """

    def __init__(self):
        # Backend resolution: explicit env settings first, otherwise
        # auto-detect (FlashInfer preferred over raw CUTLASS).
        self.backend = "none"
        if envs.VLLM_NVFP4_GEMM_BACKEND is None:
            # No explicit backend requested: auto-detect.
            if has_flashinfer():
                self.backend = "flashinfer-cutlass"
            elif cutlass_fp4_supported():
                self.backend = "cutlass"
        elif envs.VLLM_USE_FBGEMM:
            # NOTE(review): VLLM_USE_FBGEMM takes precedence here over an
            # explicitly set VLLM_NVFP4_GEMM_BACKEND — confirm this is intended.
            self.backend = "fbgemm"
            try:
                import fbgemm_gpu  # noqa: F401
            except ImportError as exc:
                raise ImportError(
                    "Backend fbgemm requires fbgemm.f4f4bf16 operator, "
                    "Please install with: pip install fbgemm-gpu-genai"
                ) from exc
        elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"):
            self.backend = envs.VLLM_NVFP4_GEMM_BACKEND
            assert has_flashinfer(), f"FlashInfer is required for {self.backend}"
        elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass":
            self.backend = "cutlass"
            assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}"

        if self.backend == "none":
            raise ValueError(
                "No valid NVFP4 GEMM backend found. "
                "Please check your platform capability."
            )

        logger.info_once(f"Using {self.backend} for NVFP4 GEMM")
        # NVFP4 uses one FP8 scale per group of 16 input elements.
        self.group_size = 16

    @classmethod
    def get_min_capability(cls) -> int:
        # Emulation can run from SM80 (Ampere); native NVFP4 needs SM100.
        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
            return 80
        return 100

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate packed FP4 weight, group scales, and global scales.

        Two FP4 values are packed per uint8 byte, hence the // 2 on the
        input dimension of the weight tensor.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition

        # Weight
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_packed", weight)

        # Global Weight Scale (one per logical partition; collapsed to a
        # single scalar after loading).
        weight_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_global_scale", weight_global_scale)

        # Per Group Weight Scale (FP8, one per group_size input elements).
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // self.group_size,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_scale", weight_scale)

        # Global activation scale, also collapsed to a scalar after loading.
        input_global_scale = PerTensorScaleParameter(
            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
            weight_loader=weight_loader,
        )
        layer.register_parameter("input_global_scale", input_global_scale)

    def process_weights_after_loading(self, layer) -> None:
        """Collapse per-partition global scales and convert weight/scale
        tensors into the layout the selected backend expects."""
        # Fused modules load one global scale per shard; take the max so a
        # single scalar is safe for all shards.
        global_input_scale = layer.input_global_scale.max().to(torch.float32)
        layer.input_global_scale = Parameter(global_input_scale, requires_grad=False)

        layer.weight_global_scale = Parameter(
            layer.weight_global_scale.max().to(torch.float32), requires_grad=False
        )

        if self.backend == "flashinfer-trtllm":
            # FlashInfer TRTLLM FP4 GEMM requires a different weight layout.
            # FlashInfer provides nvfp4_quantize to quantize + shuffle the
            # layout but we use our own quantization so we have to call
            # shuffles ourselves.
            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a

            weight = layer.weight_packed.data
            weight_scale = layer.weight_scale.data

            epilogue_tile_m = 128
            weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m)
            weight_scale = (
                shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m)
                .reshape(weight_scale.shape)
                .view(torch.float8_e4m3fn)
            )

            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.weight_packed = Parameter(weight, requires_grad=False)
        else:
            # Non-TRTLLM paths consume the swizzled block-scale layout.
            swizzled_weight_scale = swizzle_blockscale(layer.weight_scale)
            if self.backend == "fbgemm":
                # fbgemm expects a flat uint8 view of the scales.
                swizzled_weight_scale = swizzled_weight_scale.view(-1).view(torch.uint8)
            layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False)
            layer.weight_packed = Parameter(
                layer.weight_packed.data, requires_grad=False
            )

        # Dequantization factor applied by the GEMM epilogue.
        layer.alpha = Parameter(
            1 / (layer.input_global_scale * layer.weight_global_scale),
            requires_grad=False,
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the NVFP4 GEMM (or its emulation) for input x."""
        if envs.VLLM_USE_NVFP4_CT_EMULATIONS:
            # Pure-PyTorch emulation path for pre-SM100 hardware.
            out = run_nvfp4_emulations(
                x=x,
                input_global_scale=layer.input_global_scale,
                weight=layer.weight_packed,
                weight_scale_swizzled=layer.weight_scale,
                weight_global_scale=layer.weight_global_scale,
            )
            if bias is not None:
                out = out + bias
            return out

        output_dtype = x.dtype
        output_shape = [x.shape[0], layer.weight_packed.shape[0]]

        # quantize BF16 or FP16 to (FP4 and interleaved block scale)
        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_global_scale)

        mm_args = (
            x_fp4,
            layer.weight_packed,
            x_blockscale,
            layer.weight_scale,
            layer.alpha,
            output_dtype,
        )
        if self.backend.startswith("flashinfer-"):
            # "flashinfer-<name>" -> pass <name> through to FlashInfer.
            backend_name = self.backend[len("flashinfer-") :]
            out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name)
        elif self.backend == "fbgemm":
            out = torch.ops.fbgemm.f4f4bf16(
                x_fp4,
                layer.weight_packed,
                x_blockscale.view(-1).view(torch.uint8),
                layer.weight_scale,
                layer.alpha,
                use_mx=False,
            ).to(output_dtype)
        else:
            assert self.backend == "cutlass"
            out = cutlass_scaled_fp4_mm(*mm_args)

        if bias is not None:
            out = out + bias
        return out.view(*output_shape)
|
||||
@@ -0,0 +1,183 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)

# Public API of this module.
__all__ = ["CompressedTensorsW4A8Fp8"]
# Mapping from supported weight bit-width to the vLLM scalar type used
# for packing.
W4A8_SUPPORTED_TYPES_MAP = {
    4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
||||
|
||||
|
||||
class CompressedTensorsW4A8Fp8(CompressedTensorsScheme):
    """W4A8 scheme: 4-bit integer weights with FP8 (e4m3) activations.

    Weight repacking and the actual matmul are delegated to a
    mixed-precision linear kernel chosen via choose_mp_linear_kernel.
    """

    # Kernel class names already announced, so each backend is logged only
    # once per process (class-level, shared across layers).
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        strategy: str,
        num_bits: int,
        group_size: int | None = None,
        symmetric: bool | None = True,
        actorder: ActivationOrdering | None = None,
    ):
        # Number of quantized values packed into one 32-bit word.
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        # -1 encodes "no grouping" (channelwise).
        self.group_size = -1 if group_size is None else group_size
        # True when activation reordering groups columns (g_idx needed).
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        # Current W4A8 kernels only support group quantization with
        # group size 128.
        if self.group_size != 128 or self.strategy != "group":
            raise ValueError(
                "W4A8 kernels require group quantization with group size 128"
            )

        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}"
            )

        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        # hopper
        return 90

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate packed int4 weight plus scale parameters and pick the
        mixed-precision kernel for this layer's shape."""
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=torch.float8_e4m3fn,  # always use fp8(e4m3)
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx,
            out_type=params_dtype,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW4A8Fp8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = input_size != input_size_per_partition
        # Scales are sharded per rank unless the marlin convention requires
        # every rank to keep a full copy.
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel
        )

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Packed weight: pack_factor int4 values per int32 along input dim.
        weight = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            packed_factor=self.pack_factor,
            packed_dim=1,
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
        )

        # TODO(czhu): allocate the packed fp8 scales memory here?
        # the scales will be expanded by 8x via `cutlass_pack_scale_fp8`
        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=torch.float8_e4m3fn,
            ),
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)
        else:
            weight_scale = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        # per-channel scales
        weight_chan_scale = ChannelQuantScaleParameter(
            data=torch.empty((output_size_per_partition, 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)
        layer.register_parameter("weight_chan_scale", weight_chan_scale)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="weight_packed",
            w_s_param_name="weight_scale",
            w_zp_param_name="weight_zero_point",
            w_gidx_param_name="weight_g_idx",
        )

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        # All heavy lifting (quantize activation + GEMM) is in the kernel.
        return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,153 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)

# Public API of this module.
__all__ = ["CompressedTensorsW4A8Int"]
# Mapping from supported weight bit-width to the vLLM scalar type used
# for packing.
W4A8_SUPPORTED_TYPES_MAP = {
    4: scalar_types.int4,
}
W4A8_SUPPORTED_BITS = list(W4A8_SUPPORTED_TYPES_MAP.keys())
||||
|
||||
|
||||
class CompressedTensorsW4A8Int(CompressedTensorsScheme):
    """W4A8 scheme: 4-bit integer weights with 8-bit integer activations.

    The matmul is delegated to a mixed-precision linear kernel chosen via
    choose_mp_linear_kernel.
    """

    # Kernel class names already announced, so each backend is logged only
    # once per process (class-level, shared across layers).
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        strategy: str,
        num_bits: int,
        group_size: int | None = None,
        is_static_input_scheme: bool = False,
        input_symmetric: bool = True,
    ):
        self.strategy = strategy
        # -1 encodes "no grouping" (channelwise).
        self.group_size = -1 if group_size is None else group_size
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

        if num_bits not in W4A8_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}."
                f"Supported num_bits = {W4A8_SUPPORTED_TYPES_MAP.keys()}"
            )
        self.quant_type = W4A8_SUPPORTED_TYPES_MAP[num_bits]

    @classmethod
    def get_min_capability(cls) -> int:
        # No hardware floor here; kernel selection handles capability.
        return 1

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate int8-stored weight and scale parameters and pick the
        mixed-precision kernel for this layer's shape."""
        output_size_per_partition = sum(output_partition_sizes)
        row_parallel = input_size != input_size_per_partition

        # Compute effective group_size
        if self.group_size == -1:
            # Channelwise: one group spanning the whole (local) input dim.
            effective_group_size = (
                input_size_per_partition if row_parallel else input_size
            )
        else:
            effective_group_size = self.group_size

        # Ensure group_size divides input_size_per_partition
        assert input_size_per_partition % effective_group_size == 0, (
            f"input_size_per_partition {input_size_per_partition}"
            f" not divisible by group_size {effective_group_size}"
        )

        # Determine scale partitioning
        is_channelwise = self.group_size == -1
        repeat_scales = is_channelwise and row_parallel
        partition_scales = not repeat_scales

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=effective_group_size,
            zero_points=False,
            has_g_idx=False,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW4A8Int", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        scales_and_zp_size = input_size_per_partition // effective_group_size

        # Unpacked int4 values stored one per int8 element.
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition, input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition, scales_and_zp_size, dtype=params_dtype
            ),
        }

        if partition_scales:
            weight_scale = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )
        else:
            weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)

        # NOTE(review): the same parameter object is registered both as
        # "weight" (above) and "weight_packed" (here); the kernel only reads
        # "weight_packed" (w_q_param_name below). Confirm whether the extra
        # "weight" registration is required for checkpoint loading or is a
        # leftover that duplicates the tensor in the state dict.
        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)

        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="weight_packed",
            w_s_param_name="weight_scale",
            w_zp_param_name=None,
            w_gidx_param_name=None,
        )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Checkpoints are serialized in compressed-tensors format; let the
        # kernel repack into whatever layout it needs.
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
apply_fp8_marlin_linear,
|
||||
prepare_fp8_layer_for_marlin,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW8A16Fp8"]

# Weight-scale granularities this scheme accepts.
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
||||
|
||||
|
||||
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
    """W8A16 scheme: FP8 (e4m3) weights with 16-bit activations, executed
    through the Marlin FP8 kernel.

    Marlin consumes only per-channel scales, so per-tensor checkpoints have
    each shard's scale expanded to channelwise before repacking.
    """

    def __init__(self, strategy: str, is_static_input_scheme: bool):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme

    @classmethod
    def get_min_capability(cls) -> int:
        # Ampere (SM80) and newer.
        return 80

    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
    # So if we have a fused module (QKV, MLP) with per tensor scales,
    # we expand each scale to its shard's channels.
    def process_weights_after_loading(self, layer) -> None:
        """Convert loaded tensors into the layout Marlin expects."""
        if self.strategy == QuantizationStrategy.TENSOR:
            expanded_scales = convert_to_channelwise(
                layer.weight_scale, layer.logical_widths
            )
            scale_param = torch.nn.Parameter(expanded_scales, requires_grad=False)
        else:
            # required by torch.compile to be torch.nn.Parameter
            scale_param = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )
        layer.weight_scale = scale_param

        # Weights must be transposed for marlin
        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)

        if self.is_static_input_scheme:
            # required by torch.compile to be torch.nn.Parameter
            layer.input_scale = torch.nn.Parameter(
                layer.input_scale.data, requires_grad=False
            )
        prepare_fp8_layer_for_marlin(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate the FP8 weight, its scale, and (optionally) a static
        input scale on the layer."""
        total_output_size = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
        layer.register_parameter(
            "weight",
            ModelWeightParameter(
                data=torch.empty(
                    total_output_size,
                    input_size_per_partition,
                    dtype=torch.float8_e4m3fn,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            ),
        )

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((total_output_size, 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        elif self.strategy == QuantizationStrategy.TENSOR:
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
        else:
            raise ValueError(
                f"Unsupported weight strategy={self.strategy}, "
                f"supported strategies are {SUPPORTED_STRATEGIES}"
            )

        # Sentinel fill so max-style reductions only see loaded values.
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE (to deal with converted checkpoints)
        if self.is_static_input_scheme:
            layer.register_parameter(
                "input_scale",
                PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                    weight_loader=weight_loader,
                ),
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the Marlin FP8 matmul for input x."""
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )
|
||||
@@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
W8A8BlockFp8LinearOp,
|
||||
create_fp8_input_scale,
|
||||
create_fp8_scale_parameter,
|
||||
create_fp8_weight_parameter,
|
||||
maybe_post_process_fp8_weight_block,
|
||||
process_fp8_weight_block_strategy,
|
||||
process_fp8_weight_channel_strategy,
|
||||
process_fp8_weight_tensor_strategy,
|
||||
validate_fp8_block_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
cutlass_block_fp8_supported,
|
||||
maybe_create_device_identity,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
__all__ = ["CompressedTensorsW8A8Fp8"]

# Maps each weight quantization strategy to the vLLM parameter class that
# carries its weight scales.
strategy_to_parameter_type = {
    QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
    QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
    QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
||||
|
||||
|
||||
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
    """W8A8 scheme: FP8 weights and FP8 activations.

    Supports per-tensor, per-channel, and block-quantized weight scales;
    block quantization routes through W8A8BlockFp8LinearOp, everything else
    through Fp8LinearOp.
    """

    def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
        self.weight_quant = weight_quant
        self.strategy = weight_quant.strategy
        self.out_dtype = torch.get_default_dtype()
        self.is_static_input_scheme = is_static_input_scheme

        # Non-None only for block-quantized checkpoints.
        self.weight_block_size = self.weight_quant.block_structure
        if self.weight_block_size is not None:
            # Activations are quantized per (1, block_k) groups to match the
            # weight block's input dimension.
            self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
        else:
            self.act_q_group_shape = (
                GroupShape.PER_TENSOR
                if is_static_input_scheme
                else GroupShape.PER_TOKEN
            )

        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
        # NOTE(review): "is_linear_fp8_enaled" looks like a typo for
        # "is_linear_fp8_enabled" — verify against rocm_aiter_ops.
        self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enaled()

        if self.weight_block_size is not None:
            # Block quantization only supports dynamic activation scales.
            assert not self.is_static_input_scheme
            self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
                weight_group_shape=GroupShape(*self.weight_block_size),
                act_quant_group_shape=self.act_q_group_shape,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
                use_aiter_and_is_supported=self.use_aiter_and_is_supported,
            )
        else:
            self.fp8_linear = Fp8LinearOp(
                act_quant_static=self.is_static_input_scheme,
                act_quant_group_shape=self.act_q_group_shape,
            )

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Allocate the FP8 weight, strategy-appropriate weight scale, and
        optional static input scale on the layer."""
        maybe_create_device_identity()

        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes
        layer.weight_block_size = None
        layer.orig_dtype = params_dtype

        if self.strategy == QuantizationStrategy.BLOCK:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            # Validate block quantization shapes
            validate_fp8_block_shape(
                layer,
                input_size,
                output_size,
                input_size_per_partition,
                output_partition_sizes,
                self.weight_block_size,
            )

        # WEIGHT
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE (parameter class depends on the strategy).
        weight_scale = create_fp8_scale_parameter(
            strategy_to_parameter_type[self.strategy],
            output_partition_sizes,
            input_size_per_partition,
            layer.weight_block_size,
            weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            layer.register_parameter("input_scale", input_scale)

    def process_weights_after_loading(self, layer) -> None:
        """Repack weights/scales per strategy and wrap them as Parameters."""
        if self.strategy == QuantizationStrategy.TENSOR:
            weight, weight_scale, input_scale = process_fp8_weight_tensor_strategy(
                layer.weight,
                layer.weight_scale,
                layer.logical_widths,
                getattr(layer, "input_scale", None),
            )
            weight = weight.t()

        elif self.strategy == QuantizationStrategy.CHANNEL:
            weight, weight_scale, input_scale = process_fp8_weight_channel_strategy(
                layer.weight, layer.weight_scale, getattr(layer, "input_scale", None)
            )
            weight = weight.t()

        elif self.strategy == QuantizationStrategy.BLOCK:
            assert self.is_static_input_scheme is False
            weight, weight_scale = process_fp8_weight_block_strategy(
                layer.weight, layer.weight_scale
            )
            input_scale = None

        else:
            raise ValueError(f"Unknown quantization strategy {self.strategy}")

        # required by torch.compile to be torch.nn.Parameter
        layer.weight = Parameter(weight.data, requires_grad=False)
        layer.weight_scale = Parameter(weight_scale.data, requires_grad=False)
        if input_scale is not None:
            layer.input_scale = Parameter(input_scale.data, requires_grad=False)

        # INPUT SCALE
        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
            # Collapse per-shard static scales to a single scalar (max is safe
            # for all shards).
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None

        if self.strategy == QuantizationStrategy.BLOCK:
            maybe_post_process_fp8_weight_block(layer)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Dispatch to the block-wise or standard FP8 linear op."""
        if self.weight_block_size is not None:
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )

        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
|
||||
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
    """Linear scheme for compressed-tensors W8A8 int8 quantization.

    Weights are stored as int8 with a per-tensor or per-channel float32
    scale; activations may be statically or dynamically quantized,
    symmetric or asymmetric. The GEMM itself is delegated to a scaled-mm
    kernel selected in :meth:`create_weights`.
    """

    # Class-level registry of kernel backend names already logged, so the
    # "Using <kernel>" message is emitted at most once per backend.
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
    ):
        import vllm.envs as env

        # Allow the checkpoint's weight-quantization strategy to be
        # overridden at runtime via the VLLM_MIX_QUANTIZATION_TYPE env var.
        if env.VLLM_MIX_QUANTIZATION_TYPE == "TENSOR":
            self.strategy = QuantizationStrategy.TENSOR
        elif env.VLLM_MIX_QUANTIZATION_TYPE == "CHANNEL":
            self.strategy = QuantizationStrategy.CHANNEL
        else:
            self.strategy = strategy
        # True when the activation scale is loaded from the checkpoint
        # (static) rather than computed per-batch (dynamic).
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register int8 weight / scale / zero-point parameters on ``layer``
        and instantiate the scaled-mm kernel that will run the GEMM."""
        layer.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.strategy == QuantizationStrategy.CHANNEL),
            is_static_input_scheme=self.is_static_input_scheme,
            input_symmetric=self.input_symmetric,
        )

        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)

        # Log each kernel backend only once per process.
        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # Pad the input (reduction) dimension up to a multiple of 64.
        # NOTE(review): presumably a kernel alignment requirement; checkpoint
        # weights are narrower than the padded parameter — confirm the weight
        # loader fills/handles the extra columns correctly.
        remainder = input_size_per_partition % 64
        if remainder != 0:
            input_size_per_partition_padded = input_size_per_partition + (64 - remainder)
        else:
            input_size_per_partition_padded = input_size_per_partition

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition_padded,
                dtype=torch.int8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.strategy == QuantizationStrategy.CHANNEL:
            # One float32 scale per output channel.
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.strategy == QuantizationStrategy.TENSOR
            # One float32 scale per logical partition of the fused weight.
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)

            if not self.input_symmetric:
                # Note: compressed-tensors stores the zp using the same dtype
                # as the weights
                # AZP loaded as int8 but used as int32
                input_zero_point = BasevLLMParameter(
                    data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
                )
                layer.register_parameter("input_zero_point", input_zero_point)

        # Bind the kernel to the parameter names registered above.
        self.kernel = kernel_type(
            c=scaled_mm_linear_kernel_config,
            w_q_param_name="weight",
            w_s_param_name="weight_scale",
            i_s_param_name="input_scale",
            i_zp_param_name="input_zero_point",
            azp_adj_param_name="azp_adj",
        )

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the quantized linear transform via the selected kernel."""
        return self.kernel.apply_weights(layer, x, bias)
|
||||
@@ -0,0 +1,219 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import ActivationOrdering
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
marlin_repeat_scales_on_all_ranks,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedColumnParameter,
|
||||
PackedvLLMParameter,
|
||||
RowvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["CompressedTensorsWNA16"]

# Bit-width -> packed scalar type for symmetric quantization (no zero points).
WNA16_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4b8, 8: scalar_types.uint8b128}
# Bit-width -> packed scalar type when zero points are present (asymmetric).
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
# Bit-widths this scheme accepts.
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsWNA16(CompressedTensorsScheme):
    """Weight-only int4/int8 (WNA16) scheme for compressed-tensors.

    Weights are packed into int32 words; scales are per-channel or
    per-group, with optional zero points and optional activation
    reordering (g_idx). The GEMM is delegated to a mixed-precision
    linear kernel chosen in :meth:`create_weights`.
    """

    # Class-level registry of kernel backends already logged, so each
    # backend is logged only once per process.
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        strategy: str,
        num_bits: int,
        group_size: int | None = None,
        symmetric: bool | None = True,
        actorder: ActivationOrdering | None = None,
    ):
        # Number of quantized values packed into each int32 word.
        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        # group_size of -1 denotes channelwise quantization.
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError(
                "Marlin kernels require group quantization or "
                "channelwise quantization, but found no group "
                "size and strategy is not channelwise."
            )

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}"
            )

        # Asymmetric checkpoints use the zero-point scalar types.
        self.quant_type = (
            WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
            if not self.symmetric
            else WNA16_SUPPORTED_TYPES_MAP[num_bits]
        )

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_size: int,
        input_size: int,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register packed weight / scale / zero-point / g_idx parameters on
        ``layer`` and select the mixed-precision kernel implementation."""
        output_size_per_partition = sum(output_partition_sizes)

        mp_linear_kernel_config = MPLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx,
        )

        kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for CompressedTensorsWNA16", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = input_size != input_size_per_partition
        # When scales are replicated on all ranks, size them by the full
        # input_size; otherwise by this rank's partition.
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel
        )

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        # Packed weight: pack_factor quantized values per int32 along the
        # input dimension.
        weight = PackedvLLMParameter(
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            packed_factor=self.pack_factor,
            packed_dim=1,
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                dtype=torch.int32,
            ),
        )

        weight_scale_args = {
            "weight_loader": weight_loader,
            "data": torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            ),
        }

        zeros_args = {
            "weight_loader": weight_loader,
            "data": torch.zeros(
                output_size_per_partition // self.pack_factor,
                scales_and_zp_size,
                dtype=torch.int32,
            ),
        }

        if not partition_scales:
            # Replicated scales: channel-quantized parameter shape.
            weight_scale = ChannelQuantScaleParameter(output_dim=0, **weight_scale_args)

            if not self.symmetric:
                qzeros = PackedColumnParameter(
                    output_dim=0,
                    packed_dim=0,
                    packed_factor=self.pack_factor,
                    **zeros_args,
                )
        else:
            # Per-rank (group) scales partitioned along the input dim.
            weight_scale = GroupQuantScaleParameter(
                output_dim=0, input_dim=1, **weight_scale_args
            )
            if not self.symmetric:
                qzeros = PackedvLLMParameter(
                    input_dim=1,
                    output_dim=0,
                    packed_dim=0,
                    packed_factor=self.pack_factor,
                    **zeros_args,
                )

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(
            data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
        )

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        if not self.symmetric:
            layer.register_parameter("weight_zero_point", qzeros)

        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    dtype=torch.int32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_g_idx", weight_g_idx)

        # Bind the kernel to the parameter names registered above.
        self.kernel = kernel_type(
            mp_linear_kernel_config,
            w_q_param_name="weight_packed",
            w_s_param_name="weight_scale",
            w_zp_param_name="weight_zero_point",
            w_gidx_param_name="weight_g_idx",
        )

    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Run the weight-only-quantized linear transform via the kernel."""
        return self.kernel.apply_weights(layer, x, bias)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,260 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Callable, Generator
|
||||
from itertools import accumulate
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (
|
||||
TransformArgs,
|
||||
TransformConfig,
|
||||
TransformLocation,
|
||||
TransformScheme,
|
||||
)
|
||||
from compressed_tensors.utils import is_match
|
||||
|
||||
from vllm.model_executor.layers.linear import (
|
||||
WEIGHT_LOADER_V2_SUPPORTED,
|
||||
LinearMethodBase,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.module import ( # noqa: E501
|
||||
HadamardTransform,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
|
||||
TransformTuple,
|
||||
)
|
||||
|
||||
|
||||
class CompressedTensorsLinearTransformMethod(LinearMethodBase):
    """
    Wraps `CompressedTensorsLinearMethod` or `UnquantizedLinearMethod` and adds
    input and output transforms to either side of the original apply method
    """

    @classmethod
    def from_schemes(
        cls,
        quant_method: LinearMethodBase,
        quant_scheme: CompressedTensorsScheme | None,
        input_tfms: dict[int, TransformTuple],
        output_tfms: dict[int, TransformTuple],
    ) -> "CompressedTensorsLinearTransformMethod":
        """Factory: choose the concrete transform method for the given
        quant scheme + transform set (fused qutlass NvFP4 when it matches,
        otherwise this generic wrapper)."""
        # Local import, presumably to break an import cycle — confirm.
        from vllm.model_executor.layers.quantization.compressed_tensors.transform.schemes.linear_qutlass_nvfp4 import (  # noqa: E501
            QutlassNvFP4LinearMethod,
            is_qutlass_fp4_scheme,
        )

        # Only wrap when at least one transform actually applies.
        assert input_tfms or output_tfms

        if is_qutlass_fp4_scheme(quant_scheme, input_tfms):
            return QutlassNvFP4LinearMethod(quant_method, input_tfms, output_tfms)

        # hadacore or dense gemm is selected by Transform module

        return cls(quant_method, input_tfms, output_tfms)

    def __init__(
        self,
        quant_method: LinearMethodBase,
        input_tfms: dict[int, TransformTuple],
        output_tfms: dict[int, TransformTuple],
    ):
        # Wrapped (un)quantized linear method that does the actual GEMM.
        self.quant_method = quant_method
        # Transform schemes keyed by partition index of the (fused) layer.
        self.input_tfms = input_tfms
        self.output_tfms = output_tfms

        # Populated by create_weights when corresponding transforms exist.
        self.input_transform: HadamardTransform | None = None
        self.output_transform: HadamardTransform | None = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create the wrapped method's weights, then attach
        `HadamardTransform` submodules for the configured transforms."""
        # get weight loader for transforms
        weight_loader: Callable = extra_weight_attrs.get("weight_loader")  # type: ignore[assignment]

        # HACK: UnquantizedLinearMethod does not support weight loader v2, but
        # transforms (specifically SharedWeightParameter) requires
        # weight loader v2. Until UnquantizedLinearMethod supports v2, we must
        # hack around this by getting weight loader v1 so ULM can load correctly
        quant_method_name = self.quant_method.__class__.__name__
        if quant_method_name not in WEIGHT_LOADER_V2_SUPPORTED:
            weight_loader_v1 = layer.weight_loader
            extra_weight_attrs["weight_loader"] = weight_loader_v1

        self.quant_method.create_weights(
            layer=layer,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            input_size=input_size,
            output_size=output_size,
            params_dtype=params_dtype,
            **extra_weight_attrs,
        )

        # validate schemes
        num_partitions = len(output_partition_sizes)
        self._validate_tfm_schemes(num_partitions)

        # create submodules for weight loading
        if len(self.input_tfms) > 0:
            # all input tfms share one scheme (validated above), so naming
            # from the first entry is safe
            scheme_name = list(self.input_tfms.values())[0].scheme_name
            location = list(self.input_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"

            transform = HadamardTransform(
                self.input_tfms,
                layer,
                weight_loader,
                input_size_per_partition,
                output_partition_sizes,
            )
            layer.register_module(transform_name, transform)
            self.input_transform = transform

        if len(self.output_tfms) > 0:
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location
            transform_name = f"{scheme_name}_{location}"

            transform = HadamardTransform(
                self.output_tfms,
                layer,
                weight_loader,
                input_size_per_partition,
                output_partition_sizes,
            )
            layer.register_module(transform_name, transform)
            self.output_transform = transform

        # compute partition ranges for slicing activations
        starts = [0] + list(accumulate(output_partition_sizes))[:-1]
        self.partition_ranges = list(zip(starts, output_partition_sizes))

    def process_weights_after_loading(self, layer):
        # let the wrapped quant method repack its weights first
        self.quant_method.process_weights_after_loading(layer)

        for submodule in layer.children():
            if isinstance(submodule, HadamardTransform):
                submodule.process_weights_after_loading()

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply input transform -> wrapped linear -> per-partition output
        transform."""
        if self.input_transform is not None:
            x = self.input_transform(x)

        # bias is not supported when transforms are attached
        assert bias is None
        x = self.quant_method.apply(layer, x, bias)

        # In most cases, input transforms are preferred over output transforms
        # (@ksayers): confirm that this is done concurrently
        if self.output_transform is not None:
            # each partition of the fused output gets its own transform;
            # .clone() avoids in-place aliasing while writing back
            for part_id, (start, length) in enumerate(self.partition_ranges):
                x[:, start : start + length] = self.output_transform(
                    x[:, start : start + length].clone(), part_id=part_id
                )

        return x

    def _validate_tfm_schemes(self, num_partitions: int):
        """Validate transform consistency: input transforms must be identical
        across all partitions (the fused input is shared); output transforms
        must at least agree on scheme name and location."""
        if len(self.input_tfms) > 0:
            if 0 not in self.input_tfms:
                raise ValueError("Must have same input")

            for part_index in range(num_partitions):
                if self.input_tfms[part_index] != self.input_tfms[0]:
                    raise ValueError("Must have same input")

        if len(self.output_tfms) > 0:
            scheme_name = list(self.output_tfms.values())[0].scheme_name
            location = list(self.output_tfms.values())[0].args.location

            for tfm in self.output_tfms.values():
                if tfm.scheme_name != scheme_name:
                    raise ValueError("Must have same scheme name")
                if tfm.args.location != location:
                    raise ValueError("Must have same location")

        # return value is ignored by create_weights; kept for convenience
        return self.input_tfms, self.output_tfms
|
||||
|
||||
|
||||
def get_linear_transform_schemes(
    layer: torch.nn.Module,
    layer_name: str,
    transform_config: TransformConfig | None,
    packed_modules_mapping: dict[str, list[str]],
) -> tuple[
    dict[int, TransformTuple], dict[int, TransformTuple]
]:  # [input_transform, [output_transform, ...]]
    """Collect the online input/output transform schemes applying to each
    partition of the (possibly fused) layer named ``layer_name``.

    Returns two dicts keyed by partition index: input transforms and
    output transforms. Note: only one input transform scheme can exist
    per (fused) module.
    """
    input_tfms: dict[int, TransformTuple] = {}
    output_tfms: dict[int, TransformTuple] = {}

    names = get_layer_partition_names(layer_name, packed_modules_mapping)

    for scheme_name, scheme, args in get_schemes_args(transform_config):
        for index, name in enumerate(names):
            # skip schemes that do not target this partition or are
            # applied offline (folded into weights ahead of time)
            if not is_match(name, layer, args.targets, args.ignore):
                continue
            if not args.is_online():
                continue

            entry = TransformTuple(scheme_name, scheme, args)
            if args.location == TransformLocation.INPUT:
                input_tfms[index] = entry
            elif args.location == TransformLocation.OUTPUT:
                output_tfms[index] = entry
            else:
                raise ValueError(
                    f"Cannot apply `{args.location}` transform to `{layer_name}`"
                )

    return (input_tfms, output_tfms)
|
||||
|
||||
|
||||
def get_schemes_args(
    transform_config: TransformConfig | None,
) -> Generator[tuple[str, TransformScheme, TransformArgs]]:
    """Yield one ``(scheme_name, scheme, args)`` triple for every apply-entry
    of every config group; yields nothing when no config is given."""
    if transform_config is None:
        return

    for name, scheme in transform_config.config_groups.items():
        yield from ((name, scheme, args) for args in scheme.apply)
|
||||
|
||||
|
||||
def get_layer_partition_names(
    layer_name: str, packed_modules_mapping: dict[str, list[str]]
) -> list[str]:
    """
    Get all partition names associated with this layer.
    Names are returned in order of their partition indices.

    ```python
    mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}

    assert get_layer_partition_names("mlp.gate_up_proj", mapping) == [
        "mlp.gate_proj",
        "mlp.up_proj",
    ]
    assert get_layer_partition_names("mlp.down_proj", mapping) == ["mlp.down_proj"]
    ```
    """
    for fused, partitions in packed_modules_mapping.items():
        if not layer_name.endswith(fused):
            continue
        # replace the fused suffix with each unfused partition suffix
        prefix = layer_name.removesuffix(fused)
        return [prefix + part for part in partitions]

    # not a fused module: the layer is its own single partition
    return [layer_name]
|
||||
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Callable, Hashable
|
||||
|
||||
import torch
|
||||
from compressed_tensors.transform import (
|
||||
TransformArgs,
|
||||
TransformLocation,
|
||||
TransformScheme,
|
||||
)
|
||||
from torch import Tensor
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import LinearBase
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.utils import ( # noqa: E501
|
||||
TransformTuple,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import dispatch_unquantized_gemm
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.parameter import SharedWeightParameter
|
||||
|
||||
|
||||
class HadamardTransform(torch.nn.Module):
    """
    Class which handles weight loading, postprocessing, and application of
    transforms. Meant to be used with `CompressedTensorsLinearTransformMethod`
    and attention transforms method (not implemented yet)
    """

    transforms: dict[int, TransformTuple]  # info parsed from transforms config
    weight: SharedWeightParameter  # container for shared tensors

    scales: dict[int, float]  # hadamard scale, usually sqrt(matrix.size(0))

    def __init__(
        self,
        transforms: dict[int, TransformTuple],
        layer: torch.nn.Module,
        weight_loader: Callable,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
    ):
        super().__init__()
        self.transforms = transforms
        self.scales = {}

        if get_tensor_model_parallel_world_size() > 1:
            raise NotImplementedError(
                "Online transforms with tensor parallelism is not supported"
            )

        # Similar to row/col parallel params, but tensors are separate
        # to allow for loading with shared memory
        self.weight = SharedWeightParameter(weight_loader=weight_loader)

        # create shared partition data for each partition of the original weight
        input_size = input_size_per_partition
        for part_index, (_scheme_name, scheme, args) in self.transforms.items():
            output_size = output_partition_sizes[part_index]
            # square transform matrix; side depends on location / layer type
            weight_size = self._get_weight_size(
                layer, scheme, args, input_size, output_size
            )

            # partitions with the same key share one underlying tensor
            data_key = self._get_data_key(scheme, weight_size)
            self.weight.add_partition(
                part_index,
                data_key,
                size=(weight_size, weight_size),
                dtype=scheme.precision,
            )

        # validate that shared tensors and schemes are correct
        self._validate_input_transforms()

    def process_weights_after_loading(self):
        for part_id in self.weight.partitions:
            data = self.weight.partitions[part_id].data

            # required by torch.compile
            # NOTE(review): invoked once per partition inside this loop —
            # presumably idempotent; confirm it doesn't need hoisting.
            self.weight.process_weights_after_loading()

            # precompute scale as a runtime multiply, not division
            # do not fold into weight in order to utilize FWHT
            self.scales[part_id] = 1 / math.sqrt(data.size(0))

            # FUTURE: avoid runtime transpose by processing weights
            # prior to apply

    def forward(self, value: Tensor, part_id: int = 0) -> Tensor:
        """Apply the transform for `part_id` to `value`; identity for
        partitions without a transform."""
        if part_id not in self.weight.partitions:
            return value

        # use hadacore if possible
        if self.transforms[part_id].scheme.type == "hadamard":
            if self.transforms[part_id].scheme.head_dim is not None:
                # apply hadamard per head_dim-sized group of the last dim
                weight_size = self.transforms[part_id].scheme.head_dim
                value = value.unflatten(-1, (-1, weight_size))
                value = ops.hadacore_transform(value)
                value = value.flatten(-2, -1)

                return value

            # sylvester transforms are symmetric, inv => transpose => original
            return ops.hadacore_transform(value)

        # fall back to dense
        else:
            weight = self.weight.partitions[part_id]
            weight = (
                weight if self.transforms[part_id].args.inverse else weight.T
            )  # linear := x(W.T)
            scale = self.scales[part_id]

            if self.transforms[part_id].scheme.head_dim is not None:
                # transform each head_dim-sized group via dense GEMM
                value = value.unflatten(-1, (-1, weight.size(0)))
                value = (
                    dispatch_unquantized_gemm()(
                        self, value.to(weight.dtype), weight, None
                    ).to(value.dtype)
                    * scale
                )
                value = value.flatten(-2, -1)

                return value

            return (
                dispatch_unquantized_gemm()(
                    self, value.to(weight.dtype), weight, None
                ).to(value.dtype)
                * scale
            )

    def _get_data_key(self, scheme: TransformScheme, weight_size: int) -> Hashable:
        # scheme identity + matrix size uniquely identify a shared tensor
        return (id(scheme), weight_size)

    def _get_weight_size(
        self,
        layer: torch.nn.Module,
        scheme: TransformScheme,
        args: TransformArgs,
        input_size: int,
        output_size: int,
    ) -> int:
        """Return the side length of the (square) transform matrix for this
        layer type and transform location."""
        if scheme.head_dim is not None:
            return scheme.head_dim

        if isinstance(layer, LinearBase):
            if args.location == TransformLocation.INPUT:
                return input_size

            elif args.location == TransformLocation.OUTPUT:
                return output_size

        elif isinstance(layer, VocabParallelEmbedding):
            # embeddings: sizes are swapped relative to linear layers
            if args.location == TransformLocation.INPUT:
                return output_size

            elif args.location == TransformLocation.OUTPUT:
                return input_size

        # unsupported layer/location combination
        raise ValueError()

    def _validate_input_transforms(self):
        """Input transforms must share one underlying tensor across all
        partitions, since the fused input is common to them."""
        assert len(self.transforms) > 0
        location = list(self.transforms.values())[0].args.location

        if location == TransformLocation.INPUT:
            first_data = self.weight.partitions[0].data
            for partition in self.weight.partitions.values():
                if partition.data.data_ptr() != first_data.data_ptr():
                    raise ValueError("")
|
||||
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,64 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
|
||||
CompressedTensorsScheme,
|
||||
CompressedTensorsW4A4Fp4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.transform.linear import ( # noqa: E501
|
||||
CompressedTensorsLinearTransformMethod,
|
||||
TransformTuple,
|
||||
)
|
||||
|
||||
__all__ = ["is_qutlass_fp4_scheme", "QutlassNvFP4LinearMethod"]
|
||||
|
||||
|
||||
def is_qutlass_fp4_scheme(
    quant_scheme: CompressedTensorsScheme | None,
    input_tfms: dict[int, TransformTuple],
) -> bool:
    """Return True when the quant scheme is W4A4 FP4 and exactly one input
    transform exists whose head_dim matches the scheme's group size."""
    if not isinstance(quant_scheme, (CompressedTensorsW4A4Fp4,)):
        return False
    if len(input_tfms) != 1:
        return False
    return input_tfms[0].scheme.head_dim == quant_scheme.group_size
|
||||
|
||||
|
||||
class QutlassNvFP4LinearMethod(CompressedTensorsLinearTransformMethod):
    """Linear transform method specialized for W4A4 FP4 schemes whose input
    hadamard matches the FP4 group size (qutlass fused path)."""

    def create_weights(
        self,
        layer,
        input_size_per_partition,
        output_partition_sizes,
        input_size,
        output_size,
        params_dtype,
        **extra_weight_attrs,
    ):
        """Create FP4 qparams via the parent, then verify the single input
        hadamard is sized to the scheme's group size."""
        # initializes fp4 qparams
        assert isinstance(layer.scheme, (CompressedTensorsW4A4Fp4,))
        result = super().create_weights(
            layer,
            input_size_per_partition,
            output_partition_sizes,
            input_size,
            output_size,
            params_dtype,
            **extra_weight_attrs,
        )

        transform = self.input_transform
        assert transform is not None
        assert len(transform.weight) == 1
        assert transform.weight[0].size(0) == layer.scheme.group_size

        return result

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # fused qutlass kernel path not implemented yet
        raise NotImplementedError()
|
||||
@@ -0,0 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import NamedTuple
|
||||
|
||||
from compressed_tensors.transform import TransformArgs, TransformScheme
|
||||
|
||||
__all__ = ["TransformTuple"]
|
||||
|
||||
|
||||
class TransformTuple(NamedTuple):
    """One transform application parsed from the transform config."""

    # Name of the transform scheme (config-group key).
    scheme_name: str
    # The scheme definition itself (type, precision, head_dim, ...).
    scheme: TransformScheme
    # Where/how the scheme is applied (location, targets, inverse, ...).
    args: TransformArgs
|
||||
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
def is_weak_contiguous(x: torch.Tensor):
    """Return True iff the 2-D tensor is densely laid out in either
    row-major or column-major order (i.e. contiguous up to a transpose)."""
    stride0, stride1 = x.stride()
    rows, cols = x.shape
    # stride 1 along dim 0 => column-major layout
    col_major = stride0 == 1 and stride1 >= max(1, rows)
    # stride 1 along dim 1 => row-major layout
    row_major = stride1 == 1 and stride0 >= max(1, cols)
    return row_major or col_major
|
||||
|
||||
|
||||
@triton.jit
def scaled_mm_kernel(
    a_ptr,
    b_ptr,
    scale_a_ptr,
    scale_b_ptr,
    c_ptr,
    bias_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    ACCUMULATOR_DTYPE: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_SCALE_A: tl.constexpr,
    BLOCK_SIZE_SCALE_B: tl.constexpr,
):
    # Each program computes one BLOCK_SIZE_M x BLOCK_SIZE_N output tile of
    # C = scale_a * scale_b * (A @ B) (+ bias) over a 1-D launch grid.
    pid = tl.program_id(axis=0)

    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)

    # de-linearize the program id into (tile row, tile column)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n

    accumulator_dtype = ACCUMULATOR_DTYPE
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)

    # NOTE: Some tensor inputs are so large, they will cause int32 overflow
    # so it is necessary to use tl.int64 for all the offsets, else SEGV will
    # eventually occur.

    # Offsets and masks.
    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    masks_am = offsets_am < M

    offsets_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    masks_bn = offsets_bn < N

    offsets_k = tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
    offsets_a = stride_am * offsets_am[:, None] + stride_ak * offsets_k[None, :]
    offsets_b = stride_bk * offsets_k[:, None] + stride_bn * offsets_bn[None, :]

    # NOTE: BLOCK_SIZE_SCALE_A could be 1 or BLOCK_SIZE_M, so need to create
    # appropriate offsets and masks for each case. Same goes for
    # BLOCK_SIZE_SCALE_B.
    offsets_scale_am = (
        tl.arange(0, BLOCK_SIZE_SCALE_A)
        + (BLOCK_SIZE_SCALE_A > 1) * pid_m * BLOCK_SIZE_M
    )
    masks_scale_am = offsets_scale_am < M

    offsets_scale_bn = (
        tl.arange(0, BLOCK_SIZE_SCALE_B)
        + (BLOCK_SIZE_SCALE_B > 1) * pid_n * BLOCK_SIZE_N
    )
    masks_scale_bn = offsets_scale_bn < N

    a_ptrs = a_ptr + offsets_a
    b_ptrs = b_ptr + offsets_b

    scale_a_ptrs = scale_a_ptr + offsets_scale_am
    scale_b_ptrs = scale_b_ptr + offsets_scale_bn

    # Main reduction loop over the K dimension, one BLOCK_SIZE_K slab at a
    # time; out-of-range elements are masked (loaded as zero).
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        masks_k = offsets_k < K
        masks_a = masks_am[:, None] & masks_k[None, :]
        a = tl.load(a_ptrs, mask=masks_a)

        masks_b = masks_k[:, None] & masks_bn[None, :]
        b = tl.load(b_ptrs, mask=masks_b)

        # Accumulate results.
        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)

        # advance pointers to the next K slab
        offsets_k += BLOCK_SIZE_K
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Apply scale at end.
    masks_scale_a = masks_scale_am[:, None] & (tl.arange(0, 1) < 1)[:, None]
    scale_a = tl.load(scale_a_ptrs[:, None], masks_scale_a)
    # Need to broadcast to the appropriate size, if scale_a is already
    # (BLOCK_SIZE_M, 1) then it will broadcast to its own shape. Same goes
    # for scale_b below.
    scale_a = scale_a.broadcast_to((BLOCK_SIZE_M, 1))
    accumulator = scale_a * accumulator.to(tl.float32)

    masks_scale_b = masks_scale_bn[:, None] & (tl.arange(0, 1) < 1)[None, :]
    scale_b = tl.load(scale_b_ptrs[:, None], masks_scale_b)
    scale_b = scale_b.broadcast_to((BLOCK_SIZE_N, 1))
    accumulator = scale_b.T * accumulator.to(tl.float32)

    # Convert to output format.
    c = accumulator.to(c_ptr.type.element_ty)

    # Add bias, it's already in output format, so add it after conversion.
    if bias_ptr:
        offsets_bias = offsets_bn
        bias_ptrs = bias_ptr + offsets_bias
        bias_mask = offsets_bias < N
        bias = tl.load(bias_ptrs, bias_mask)
        c += bias

    # Save output
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)
    offs_cm = offs_cm.to(tl.int64)
    offs_cn = offs_cn.to(tl.int64)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    # mask off rows/columns past the matrix edge
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)

    tl.store(c_ptrs, c, mask=c_mask)
|
||||
|
||||
|
||||
# input - [M, K]
# weight - [K, N]
def triton_scaled_mm(
    input: torch.Tensor,
    weight: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: torch.Tensor | None = None,
    block_size_m: int = 32,
    block_size_n: int = 32,
    block_size_k: int = 32,
    use_heuristic: bool = True,
) -> torch.Tensor:
    """Scaled matmul: ``C = (scale_a * A) @ (scale_b * B) [+ bias]``.

    :param input: A, shape [M, K] (integer or floating point)
    :param weight: B, shape [K, N], same dtype as ``input``
    :param scale_a: per-tensor (1 element) or per-row (M elements) scale for A
    :param scale_b: per-tensor (1 element) or per-column (N elements) scale
        for B
    :param out_dtype: floating-point dtype of the result (a ``torch.dtype``
        instance, e.g. ``torch.float16``)
    :param bias: optional bias of shape [N], already in the output dtype
    :param block_size_m: tile size in M, used when ``use_heuristic`` is False
    :param block_size_n: tile size in N, used when ``use_heuristic`` is False
    :param block_size_k: tile size in K, used when ``use_heuristic`` is False
    :param use_heuristic: pick tile sizes from the problem shape instead of
        the explicit ``block_size_*`` arguments
    :return: C, shape [M, N], dtype ``out_dtype``
    """
    M, K = input.shape
    N = weight.shape[1]

    assert N > 0 and K > 0 and M > 0
    assert weight.shape[0] == K
    assert input.dtype == weight.dtype

    # Normalize scales to 2-D column vectors so the kernel handles
    # per-tensor and per-row/column scaling uniformly.
    scale_a = scale_a.reshape(-1, 1) if scale_a.dim() <= 1 else scale_a
    scale_b = scale_b.reshape(-1, 1) if scale_b.dim() <= 1 else scale_b

    assert scale_a.dtype == scale_b.dtype and scale_a.is_floating_point()
    assert scale_a.shape[1] == 1 and (scale_a.shape[0] == 1 or scale_a.shape[0] == M)
    assert scale_b.shape[1] == 1 and (scale_b.shape[0] == 1 or scale_b.shape[0] == N)
    # out_dtype is a torch.dtype instance; is_floating_point is its attribute.
    assert out_dtype.is_floating_point
    assert bias is None or bias.is_floating_point()
    assert is_weak_contiguous(input)
    assert is_weak_contiguous(weight)

    # 1-D launch grid covering all (M, N) output tiles.
    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )

    result = torch.empty((M, N), dtype=out_dtype, device=input.device)

    has_scalar = lambda x: x.shape[0] == 1 and x.shape[1] == 1

    if use_heuristic:
        is_small_N = N < 8192
        next_power_of_2_M = max(32, triton.next_power_of_2(M))
        if next_power_of_2_M <= 32:
            tile_shape = (64, 64, 256) if is_small_N else (64, 128, 256)
        elif next_power_of_2_M <= 64:
            tile_shape = (64, 64, 256)
        elif next_power_of_2_M <= 128:
            tile_shape = (64, 128, 128)
        else:
            tile_shape = (128, 128, 128)

        block_size_m, block_size_n, block_size_k = tile_shape

    # Per-tensor scales only need a single element loaded per tile.
    block_size_sa = 1 if has_scalar(scale_a) else block_size_m
    block_size_sb = 1 if has_scalar(scale_b) else block_size_n

    # Accumulate in int32 for integer inputs, fp32 for floating point.
    accumulator_dtype = tl.float32 if input.is_floating_point() else tl.int32

    # A = input, B = weight, C = result
    # A = M x K, B = K x N, C = M x N
    scaled_mm_kernel[grid](
        input,
        weight,
        scale_a,
        scale_b,
        result,
        bias,
        M,
        N,
        K,
        input.stride(0),
        input.stride(1),
        weight.stride(0),
        weight.stride(1),
        result.stride(0),
        result.stride(1),
        accumulator_dtype,
        BLOCK_SIZE_M=block_size_m,
        BLOCK_SIZE_N=block_size_n,
        BLOCK_SIZE_K=block_size_k,
        BLOCK_SIZE_SCALE_A=block_size_sa,
        BLOCK_SIZE_SCALE_B=block_size_sb,
    )

    # result was allocated with out_dtype; no conversion is needed.
    return result
|
||||
216
model_executor/layers/quantization/compressed_tensors/utils.py
Normal file
216
model_executor/layers/quantization/compressed_tensors/utils.py
Normal file
@@ -0,0 +1,216 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from types import MappingProxyType
|
||||
|
||||
import regex as re
|
||||
from compressed_tensors import CompressionFormat
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
def is_activation_quantization_format(format: str) -> bool:
    """Return True if ``format`` quantizes activations, not just weights."""
    return format in (
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value,
        CompressionFormat.nvfp4_pack_quantized.value,
    )
|
||||
|
||||
|
||||
def should_ignore_layer(
    layer_name: str | None,
    ignore: Iterable[str] = tuple(),
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> bool:
    """Check whether a layer matches the compressed-tensors ``ignore`` list.

    Fused vLLM layers (e.g. ``qkv_proj``) do not exist in the safetensors
    checkpoint; each of their shards (``q_proj``/``k_proj``/``v_proj``) is
    checked instead, and all shards must agree.

    :param layer_name: full dotted layer name, or None
    :param ignore: exact layer names or ``re:``-prefixed regex patterns
    :param fused_mapping: map from fused proj names to their shard names
    :return: True if the layer should be ignored (left unquantized)
    :raises ValueError: if shards of a fused layer disagree on ignoring
    """
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping and layer_name not in ignore:
        shard_proj_names = fused_mapping[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore
            )

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard

            # If shard_idx=1+ confirm scheme matches prior shards.
            elif should_ignore_shard != should_ignore_layer:
                raise ValueError(
                    f"Found different quantization schemes for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme."
                )

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(
            layer_name=layer_name, targets=ignore
        )

    assert should_ignore_layer is not None
    return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str, targets: Iterable[str]) -> bool:
    """Return True if ``layer_name`` matches any target in ``targets``,
    either by exact equality or as a regex when the target is prefixed
    with 're:'.
    """
    for candidate in targets:
        if _is_equal_or_regex_match(layer_name, candidate):
            return True
    return False
|
||||
|
||||
|
||||
def find_matched_target(
    layer_name: str | None,
    module: Module,
    targets: Iterable[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
) -> str:
    """Look up which "target" in the compressed-tensors config a layer
    corresponds to.

    A compressed-tensors config has config_groups, each of which may
    quantize its layers with a different scheme; the ``targets`` of a
    config_group are layer names, regexes over layer names, or torch
    Module class names.

    Matching is attempted in order:
    1. the layer_name against each target;
    2. the module's class name against each target;
    3. the layer_name expanded to its fused components — *all* component
       names must match, and the first component's target is returned.

    :param layer_name: layer name
    :param module: torch.nn.Module
    :param targets: list of targets to match the layer against
    :param fused_mapping: map from fused layer names to its components
    :raises ValueError: if no target matches
    """
    if layer_name is None:
        layer_name = ""

    # Fall through the three strategies, keeping the first truthy result.
    matched_target = _find_first_match(layer_name, targets)
    if not matched_target:
        matched_target = _find_first_match(module.__class__.__name__, targets, True)
    if not matched_target:
        matched_target = _match_fused_layer(layer_name, targets, fused_mapping)

    if matched_target is None:
        raise ValueError(
            f"Unable to find matching target for {layer_name} in the "
            "compressed-tensors config."
        )

    return matched_target
|
||||
|
||||
|
||||
def _find_first_match(
    value: str, targets: Iterable[str], check_contains: bool = False
) -> str | None:
    """Return the first target matching ``value``, or None.

    A target matches if it equals ``value`` or, when prefixed with 're:',
    matches it as a regex. With ``check_contains`` True, a case-insensitive
    substring match also counts.

    :param value: string to compare the list of targets against
    :param targets: list of targets to match the layer against
    :param check_contains: whether or not to do a substring match
    """
    return next(
        (
            target
            for target in targets
            if _is_equal_or_regex_match(value, target, check_contains=check_contains)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(
|
||||
value: str, target: str, check_contains: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _match_fused_layer(
    layer_name: str,
    target_layers: Iterable[str],
    fused_mapping: Mapping[str, list[str]],
) -> str | None:
    """Match a fused layer name to its corresponding individual layers in
    ``target_layers``.

    Implements an "all" strategy: the fused layer matches only if *every*
    one of its unfused components matches some target, in which case the
    target matched by the first component is returned.

    :param layer_name: layer name
    :param target_layers: list of targets to match the layer against
    :param fused_mapping: map from fused layer names to its components

    Examples:
        layer_name = "model.layers.0.self_attn.qkv_proj"
        target_layers = ["model.layers.0.self_attn.q_proj",
                         "model.layers.0.self_attn.k_proj",
                         "model.layers.0.self_attn.v_proj"]
    """
    # Which fused suffix (if any) does this layer name carry?
    fused_suffix = next(
        (key for key in fused_mapping if layer_name.endswith(key)), None
    )
    if fused_suffix is None:
        return None

    # For every unfused component path, record the first matching target
    # (or None when nothing matches).
    component_matches: list[str | None] = []
    for component in fused_mapping[fused_suffix]:
        component_path = layer_name.replace(fused_suffix, component)
        component_matches.append(
            next(
                (
                    target
                    for target in target_layers
                    if _is_equal_or_regex_match(component_path, target)
                ),
                None,
            )
        )

    # "all" strategy: succeed only when every component matched.
    return component_matches[0] if all(component_matches) else None
|
||||
Reference in New Issue
Block a user