First commit

This commit is contained in:
2025-08-05 19:02:46 +08:00
parent 9efe891f99
commit 99fb9f5cb0
1412 changed files with 203615 additions and 0 deletions

View File

@@ -0,0 +1,412 @@
from typing import Any, Dict, List, Optional, cast
import torch
from pydantic import BaseModel
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
CompressedTensorsMoEMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsW8A16Fp8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
__all__ = ["CompressedTensorsLinearMethod"]
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None):
self.ignore = ignore
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
return "compressed_tensors"
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(prefix, ignore=self.ignore):
return UnquantizedLinearMethod()
if isinstance(layer, LinearBase):
scheme = self.get_scheme(layer=layer, layer_name=prefix)
layer.scheme = scheme
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
if isinstance(layer, FusedMoE):
return CompressedTensorsMoEMethod.get_moe_method(self)
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
target_scheme_map: Dict[str, Any] = dict()
ignore = cast(List[str], config.get("ignore"))
quant_format = cast(str, config.get("format"))
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for _, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights"))
try:
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
except Exception:
target_scheme_map[target]["input_activations"] = None
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
def _check_scheme_supported(self,
min_capability: int,
error: bool = True) -> bool:
capability_tuple = current_platform.get_device_capability()
if capability_tuple is not None:
capability = capability_tuple.to_int()
supported = capability >= min_capability
if error and not supported:
raise RuntimeError(
"Quantization scheme is not supported for ",
f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
return supported
else:
return False
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_tensor = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TENSOR.value)
is_static = not weight_quant.dynamic and not input_quant.dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_tensor and weight_quant.symmetric and is_static
def _is_dynamic_token_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
weight_strategy = (
weight_quant.strategy == QuantizationStrategy.TENSOR.value
or weight_quant.strategy == QuantizationStrategy.CHANNEL.value)
is_token = (weight_strategy and input_quant.strategy
== QuantizationStrategy.TOKEN.value)
is_dynamic = not weight_quant.dynamic and input_quant.dynamic
# Both symmetric and asymmetric input quantization supported.
# Only symmetric weight quantization supported.
return is_8_bits and is_token and weight_quant.symmetric and is_dynamic
def _is_fp8_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights and activations quantized.
if weight_quant is None or input_quant is None:
return False
# Confirm weight scheme is supported.
is_floating_point = (weight_quant.type == QuantizationType.FLOAT
and input_quant.type == QuantizationType.FLOAT)
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_floating_point and is_symmetric_weight and is_static_weight
and is_per_tensor_or_channel_weight):
return False
# Dynamic quantization is always supported if weights supported.
if input_quant.dynamic:
return True
# Confirm activation scheme is supported.
is_symmetric_activation = input_quant.symmetric
is_per_tensor_activation = (
input_quant.strategy == QuantizationStrategy.TENSOR)
return is_symmetric_activation and is_per_tensor_activation
def _is_fp8_w8a16(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
# Confirm weights quantized.
if weight_quant is None:
return False
# Confirm we have floating points.
if weight_quant.type != QuantizationType.FLOAT:
return False
# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight # noqa: SIM103
and is_per_tensor_or_channel_weight):
return False
# All conditions satisfied.
return True
def _is_wNa16_group_channel(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
input_quant_none = input_quant is None
is_symmetric = weight_quant.symmetric
is_channel_group = (
weight_quant.strategy == QuantizationStrategy.CHANNEL.value
or weight_quant.strategy == QuantizationStrategy.GROUP.value)
is_static = not weight_quant.dynamic
return (is_channel_group and input_quant_none and is_symmetric
and is_static)
def _get_scheme_from_parts(
self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":
# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
return CompressedTensorsW4A16Sparse24(
strategy=weight_quant.strategy,
num_bits=weight_quant.num_bits,
group_size=weight_quant.group_size)
if (self.quant_format == CompressionFormat.pack_quantized.value
and weight_quant.num_bits in WNA16_SUPPORTED_BITS):
return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
# Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
is_fp8_w8a8_supported = self._check_scheme_supported(
CompressedTensorsW8A8Fp8.get_min_capability(), error=False)
if is_fp8_w8a8_supported:
return CompressedTensorsW8A8Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
else:
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
if self._is_fp8_w8a16(weight_quant, input_quant):
return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy,
is_static_input_scheme=(input_quant
and not input_quant.dynamic))
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=True,
input_symmetric=input_quant.symmetric)
if self._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
strategy=weight_quant.strategy,
is_static_input_scheme=False,
input_symmetric=input_quant.symmetric)
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
def get_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
"""
compressed-tensors supports non uniform in the following way:
ignore: List of layer_names or nn.Module names to be ignored.
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.
We first check whether a layer is in the ignore group and use
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
We then detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for infernece.
"""
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
# need to make accelerate optional in ct to do this
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.target_scheme_map.keys())
# Find the quant_scheme
scheme_dict = self.target_scheme_map[matched_target]
scheme = self._get_scheme_from_parts(
weight_quant=scheme_dict["weights"],
input_quant=scheme_dict["input_activations"])
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return scheme
class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer.scheme.create_weights(
layer=layer,
input_size=input_size,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input. See LinearMethodBase for param details
"""
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""
def __init__(self, quant_config: CompressedTensorsConfig):
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
"""
Validator for the kv cache scheme. Useful for controlling the
kv cache quantization schemes, that are being supported in vLLM
:param kv_cache_scheme: the compressed-tensors kv cache scheme
"""
if kv_cache_scheme is None:
return
type_ = kv_cache_scheme.get("type")
num_bits = kv_cache_scheme.get("num_bits")
if type_ != "float" and num_bits != 8:
raise NotImplementedError(
"Currently supported kv cache quantization is "
"num_bits=8, type=float, however "
f"received num_bits={num_bits}, type={type_}")
strategy = kv_cache_scheme.get("strategy")
if strategy != "tensor":
raise NotImplementedError(
"Only support per-tensor scaling factor "
"for compressed-tensors KV cache. "
f"Expected strategy: tensor, found strategy: {strategy}")
is_symmetric = kv_cache_scheme.get("symmetric")
if not is_symmetric:
raise NotImplementedError(
"Only support symmetric scaling factor "
"for compressed-tensors KV cache. "
f"However found symmetric: {is_symmetric}")

View File

@@ -0,0 +1,511 @@
import enum
from enum import Enum
from typing import Callable, List, Optional
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
WNA16_SUPPORTED_BITS)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip, print_warning_once
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
__all__ = [
"CompressedTensorsMoEMethod", "CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsWNA16MoEMethod"
]
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
@staticmethod
def get_moe_method(
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
) -> "CompressedTensorsMoEMethod":
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
input_quant = quant_config.target_scheme_map["Linear"].get(
"input_activations")
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
return CompressedTensorsWNA16MoEMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
and self.input_quant.strategy == QuantizationStrategy.TENSOR):
raise ValueError(
"For FP8 Fused MoE layers, only per-tensor scales"
"for weights and activations are supported. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.float8_e4m3fn
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
2 * intermediate_size,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size,
intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
2,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add the quantization method used (per tensor/grouped/channel)
# to ensure the weight scales are loaded in properly
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
if self.static_input_scales:
w13_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_input_scale", w13_input_scale)
set_weight_attrs(w13_input_scale, extra_weight_attrs)
w2_input_scale = torch.nn.Parameter(torch.ones(
num_experts, dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_input_scale", w2_input_scale)
set_weight_attrs(w2_input_scale, extra_weight_attrs)
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ.
if self.static_input_scales:
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
raise ValueError(
"QuantConfig has static quantization, but found "
"activation scales are None.")
if (not all_close_1d(layer.w13_input_scale)
or not all_close_1d(layer.w2_input_scale)):
print_warning_once(
"Found input_scales that are not equal for "
"fp8 MoE layer. Using the maximum across experts "
"for each layer. ")
layer.w13_input_scale = torch.nn.Parameter(
layer.w13_input_scale.max(), requires_grad=False)
layer.w2_input_scale = torch.nn.Parameter(
layer.w2_input_scale.max(), requires_grad=False)
# If rocm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_weight_scale,
layer.w13_input_scale)
w2_weight, w2_weight_scale, w2_input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_weight_scale,
layer.w2_input_scale)
# Reset the parameter
layer.w13_weight = torch.nn.Parameter(w13_weight,
requires_grad=False)
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
requires_grad=False)
if w13_input_scale is not None:
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight,
requires_grad=False)
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
requires_grad=False)
if w2_input_scale is not None:
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_weight_scale is not None
shard_size = layer.intermediate_size_per_partition
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
for expert_id in range(layer.num_experts):
start = 0
for shard_id in range(2):
dq_weight = per_tensor_dequantize(
layer.w13_weight[expert_id][start:start + shard_size, :],
layer.w13_weight_scale[expert_id][shard_id])
layer.w13_weight[expert_id][
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
dq_weight, max_w13_scales[expert_id])
start += shard_size
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_experts(x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
# TODO: @dsikka: refactor this to use schemes as other kernels
# are supported + check if the layer is being ignored.
config = self.quant_config.target_scheme_map["Linear"].get("weights")
self.num_bits = config.num_bits
self.packed_factor = 32 // config.num_bits
self.strategy = config.strategy.value
self.group_size = config.group_size
assert config.symmetric, (
"Only symmetric quantization is supported for MoE")
if not (self.quant_config.quant_format
== CompressionFormat.pack_quantized.value
and self.num_bits in WNA16_SUPPORTED_BITS):
raise ValueError("For Fused MoE layers, only ",
f"{CompressionFormat.pack_quantized.value} ",
"is supported for the following bits: ",
f"{WNA16_SUPPORTED_BITS}")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size: int,
params_dtype: torch.dtype, **extra_weight_attrs):
# Will transpose the loaded weight along the
# intermediate and hidden dim sizes. Will
# shard for TP along the transposed dims
extra_weight_attrs.update({
"is_transposed": True,
"quant_method": self.strategy
})
w13_weight = torch.nn.Parameter(torch.empty(num_experts,
hidden_size //
self.packed_factor,
2 * intermediate_size,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w13_weight_packed", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(num_experts,
intermediate_size //
self.packed_factor,
hidden_size,
dtype=torch.int32),
requires_grad=False)
layer.register_parameter("w2_weight_packed", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.strategy == "channel":
num_groups_w2 = num_groups_w13 = 1
self.group_size = -1
else:
num_groups_w2 = intermediate_size // self.group_size
num_groups_w13 = hidden_size // self.group_size
w13_scale = torch.nn.Parameter(torch.ones(num_experts,
num_groups_w13,
2 * intermediate_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_scale)
set_weight_attrs(w13_scale, extra_weight_attrs)
w2_scale = torch.nn.Parameter(torch.ones(num_experts,
num_groups_w2,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_scale)
set_weight_attrs(w2_scale, extra_weight_attrs)
w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
layer.register_parameter("w2_weight_shape", w2_weight_shape)
set_weight_attrs(w2_weight_shape, extra_weight_attrs)
w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2),
requires_grad=False)
layer.register_parameter("w13_weight_shape", w13_weight_shape)
set_weight_attrs(w13_weight_shape, extra_weight_attrs)
w13_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx", w13_g_idx)
set_weight_attrs(w13_g_idx, extra_weight_attrs)
w2_g_idx = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx", w2_g_idx)
set_weight_attrs(w2_g_idx, extra_weight_attrs)
w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
hidden_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty(
num_experts,
intermediate_size,
dtype=torch.int32,
),
requires_grad=False,
)
layer.register_parameter("w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
layer.a13_scale = None
layer.a2_scale = None
layer.marlin_state = GPTQMarlinState.REPACK
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
def get_scale_perms(num_bits: int):
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:,
scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
def marlin_moe_permute_scales(s: torch.Tensor, size_k: int,
size_n: int, group_size: int,
num_bits: int):
num_experts = s.shape[0]
output = torch.empty((num_experts, s.shape[1], s.shape[2]),
device=s.device,
dtype=s.dtype)
for e in range(num_experts):
output[e] = marlin_permute_scales(s[e], size_k, size_n,
group_size, num_bits)
return output
size_k2 = layer.w2_weight_packed.shape[2]
size_k13 = layer.w13_weight_packed.shape[2]
num_experts = layer.w13_g_idx.shape[0]
device = layer.w13_g_idx.device
layer.w13_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w13_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
layer.w2_g_idx_sort_indices = torch.nn.Parameter(
torch.empty((num_experts, 0), dtype=torch.int32, device=device),
requires_grad=False,
)
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
layer.w13_weight_packed,
layer.w13_g_idx_sort_indices,
layer.w13_weight_packed.shape[1] * self.packed_factor,
layer.w13_weight_packed.shape[2],
self.num_bits,
)
replace_tensor("w13_weight_packed", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_weight_packed,
layer.w2_g_idx_sort_indices,
layer.w2_weight_packed.shape[1] * self.packed_factor,
layer.w2_weight_packed.shape[2],
self.num_bits,
)
replace_tensor("w2_weight_packed", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
layer.w13_weight_scale,
size_k13,
layer.w13_weight_scale.shape[2],
self.group_size,
self.num_bits,
)
replace_tensor("w13_weight_scale", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
layer.w2_weight_scale,
layer.w2_weight_scale.shape[1] * self.packed_factor,
size_k2,
self.group_size,
self.num_bits,
)
replace_tensor("w2_weight_scale", marlin_w2_scales)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
fused_marlin_moe)
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function)
return fused_marlin_moe(
x,
layer.w13_weight_packed,
layer.w2_weight_packed,
layer.w13_weight_scale,
layer.w2_weight_scale,
router_logits,
topk_weights,
topk_ids,
g_idx1=layer.w13_g_idx,
g_idx2=layer.w2_g_idx,
sort_indices1=layer.w13_g_idx_sort_indices,
sort_indices2=layer.w2_g_idx_sort_indices,
num_bits=self.num_bits,
)

View File

@@ -0,0 +1,19 @@
from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS,
CompressedTensorsW4A16Sparse24)
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import (WNA16_SUPPORTED_BITS,
CompressedTensorsWNA16)
__all__ = [
"CompressedTensorsScheme",
"CompressedTensorsWNA16",
"CompressedTensorsW8A16Fp8",
"CompressedTensorsW4A16Sparse24",
"CompressedTensorsW8A8Int8",
"CompressedTensorsW8A8Fp8",
"WNA16_SUPPORTED_BITS",
"W4A16SPARSE24_SUPPORTED_BITS",
]

View File

@@ -0,0 +1,52 @@
from abc import ABC, abstractmethod
from typing import Optional
import torch
__all__ = ["CompressedTensorsScheme"]
class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
"""
Get minimum device capability.
"""
raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: torch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
:param bias: bias parameter
"""
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module):
"""
Called after weight loading is complete for any cleanup that
needs to occur.
"""
raise NotImplementedError

View File

@@ -0,0 +1,153 @@
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter)
from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsW4A16Sparse24"]
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
}
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None):
self.strategy = strategy
self.group_size = group_size
self.tile_size = 16
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}")
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
if self.strategy == "group" and self.group_size is None:
raise ValueError(
"group_size must be given when using strategy group")
@classmethod
def get_min_capability(cls) -> int:
# ampere + up
return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile to be torch.nn.Parameter
layer.weight_packed = Parameter(layer.weight_packed.data,
requires_grad=False)
layer.scale_packed = Parameter(layer.scale_packed.data,
requires_grad=False)
layer.meta = Parameter(layer.meta.data, requires_grad=False)
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
pack_factor = 32 // self.quant_type.size_bits
output_size_per_partition = sum(output_partition_sizes)
qweight = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // self.tile_size // 2,
output_size_per_partition * self.tile_size // pack_factor,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=pack_factor,
marlin_tile_size=self.tile_size,
weight_loader=weight_loader)
input_groups = (1 if self.group_size is None else
input_size_per_partition // self.group_size)
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
dtype=params_dtype,
),
"weight_loader":
weight_loader
}
if self.group_size is not None:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)
else:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
meta = PackedvLLMParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)
layer.register_parameter("weight_packed", qweight)
layer.register_parameter("weight_shape", weight_shape)
layer.register_parameter("scale_packed", scales)
layer.register_parameter("meta", meta)
max_workspace_size = (
output_size_per_partition //
GPTQ_MARLIN_24_MIN_THREAD_N) * GPTQ_MARLIN_24_MAX_PARALLEL
workspace = Parameter(torch.zeros(max_workspace_size, dtype=torch.int),
requires_grad=False)
layer.workspace = workspace
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
qweight = layer.weight_packed
meta = layer.meta
scales = layer.scale_packed
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace, self.quant_type, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output

View File

@@ -0,0 +1,118 @@
from typing import Callable, List, Optional
import torch
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
__all__ = ["CompressedTensorsW8A16Fp8"]
SUPPORTED_STRATEGIES = [
QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR
]
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
layer.logical_widths)
layer.weight_scale = torch.nn.Parameter(ws_channelwise,
requires_grad=False)
else:
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
requires_grad=False)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(),
requires_grad=False)
if self.is_static_input_scheme:
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(layer.input_scale.data,
requires_grad=False)
prepare_fp8_layer_for_marlin(layer, strategy="channel")
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.orig_dtype = params_dtype
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}")
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints)
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp8_marlin_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias)

View File

@@ -0,0 +1,143 @@
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
from vllm.utils import is_hip
__all__ = ["CompressedTensorsW8A8Fp8"]
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
@classmethod
def get_min_capability(cls) -> int:
# lovelace and up
return 89
def process_weights_after_loading(self, layer) -> None:
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.strategy == QuantizationStrategy.TENSOR:
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
if is_hip():
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=max_w_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
if is_hip():
weight, weight_scale, input_scale = \
normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale)
if input_scale is not None:
layer.input_scale = Parameter(input_scale,
requires_grad=False)
else:
weight_scale = layer.weight_scale.data
layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
# WEIGHT
weight = ModelWeightParameter(data=torch.empty(
output_size_per_partition,
input_size_per_partition,
dtype=torch.float8_e4m3fn),
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
# TODO: update create_xxx_parameter functions to return
# the newly added parameters
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
# min requirement for fp8 kernels
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return apply_fp8_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)

View File

@@ -0,0 +1,155 @@
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_int8_linear, convert_to_channelwise)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter)
logger = init_logger(__name__)
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool,
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
@classmethod
def get_min_capability(cls) -> int:
# turing and up
return 75
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(self.logical_widths) > 1
if is_fused_module and self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(layer.weight_scale,
self.logical_widths)
layer.weight_scale = Parameter(ws_channelwise, requires_grad=False)
else:
layer.weight_scale = Parameter(layer.weight_scale.data,
requires_grad=False)
# INPUT SCALE
if self.is_static_input_scheme:
if self.input_symmetric:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
layer.input_zero_point = None
else:
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = layer.input_zero_point.to(dtype=torch.int32)
range_max = (layer.input_scale *
(int8_traits.max - azps)).max()
range_min = (layer.input_scale *
(int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max -
int8_traits.min)
layer.input_scale = Parameter(scale, requires_grad=False)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min -
range_min / scale).to(dtype=torch.int32)
layer.input_zero_point = Parameter(azp, requires_grad=False)
else:
layer.input_scale = None
layer.input_zero_point = None
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
if not self.input_symmetric:
layer.azp_adj = layer.weight.sum(dim=0,
keepdim=True,
dtype=torch.int32)
else:
layer.azp_adj = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
self.logical_widths = output_partition_sizes
# WEIGHT
if input_size_per_partition % 64 != 0:
pad_input_size_per_partition = (input_size_per_partition // 64 + 1) * 64
else:
pad_input_size_per_partition = input_size_per_partition
w_pad = torch.zeros(
sum(output_partition_sizes),
pad_input_size_per_partition,
dtype=torch.int8)
w = w_pad[:, :input_size_per_partition]
weight = ModelWeightParameter(data=w,
input_dim=1,
output_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight", weight)
# WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = ChannelQuantScaleParameter(
data=torch.empty((sum(output_partition_sizes), 1),
dtype=torch.float32),
output_dim=0,
weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = PerTensorScaleParameter(data=torch.empty(
len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if self.is_static_input_scheme:
input_scale = BasevLLMParameter(data=torch.empty(
1, dtype=torch.float32),
weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
if not self.input_symmetric:
# Note: compressed-tensors stores the zp using the same dtype
# as the weights
# AZP loaded as int8 but used as int32
input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8),
weight_loader=weight_loader)
layer.register_parameter("input_zero_point", input_zero_point)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_int8_linear(input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
input_zero_point=layer.input_zero_point,
azp_adj=layer.azp_adj,
bias=bias)

View File

@@ -0,0 +1,163 @@
from typing import Callable, List, Optional, Set
import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks)
from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128
}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
def __init__(self,
strategy: str,
num_bits: int,
group_size: Optional[int] = None,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits
self.strategy = strategy
self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or "
"channelwise quantization, but found no group "
"size and strategy is not channelwise.")
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
raise ValueError(
f"Unsupported num_bits = {num_bits}. "
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
@classmethod
def get_min_capability(cls) -> int:
# ampere and up
return 80
def create_weights(self, layer: torch.nn.Module, output_size: int,
input_size: int, output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=self.group_size,
zero_points=False,
has_g_idx=self.has_g_idx
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsWNA16",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel)
scales_and_zp_size = input_size // group_size
if partition_scales:
assert input_size_per_partition % group_size == 0
scales_and_zp_size = input_size_per_partition // group_size
weight = PackedvLLMParameter(input_dim=1,
output_dim=0,
weight_loader=weight_loader,
packed_factor=self.pack_factor,
packed_dim=1,
data=torch.empty(
output_size_per_partition,
input_size_per_partition //
self.pack_factor,
dtype=torch.int32,
))
weight_scale_args = {
"weight_loader":
weight_loader,
"data":
torch.empty(
output_size_per_partition,
scales_and_zp_size,
dtype=params_dtype,
)
}
if not partition_scales:
weight_scale = ChannelQuantScaleParameter(output_dim=0,
**weight_scale_args)
else:
weight_scale = GroupQuantScaleParameter(output_dim=0,
input_dim=1,
**weight_scale_args)
# A 2D array defining the original shape of the weights
# before packing
weight_shape = BasevLLMParameter(data=torch.empty(2,
dtype=torch.int64),
weight_loader=weight_loader)
layer.register_parameter("weight_packed", weight)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape)
# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="weight_packed",
w_s_param_name="weight_scale",
w_zp_param_name=None,
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is
# different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)

View File

@@ -0,0 +1,269 @@
import re
from enum import Enum
from typing import Any, Dict, Iterable, Optional, Union
from pydantic import BaseModel, Field, field_validator
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
FUSED_LAYER_NAME_MAPPING)
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
naive_quantized = "naive-quantized"
float_quantized = "float-quantized"
int_quantized = "int-quantized"
pack_quantized = "pack-quantized"
marlin_24 = "marlin-24"
class QuantizationType(str, Enum):
"""
Enum storing quantization type options
"""
INT = "int"
FLOAT = "float"
class QuantizationStrategy(str, Enum):
"""
Enum storing quantization strategy options
"""
TENSOR = "tensor"
CHANNEL = "channel"
GROUP = "group"
BLOCK = "block"
TOKEN = "token"
class ActivationOrdering(str, Enum):
"""
Enum storing strategies for activation ordering
Group: reorder groups and weight\n
Weight: only reorder weight, not groups. Slightly lower latency and
accuracy compared to group actorder\n
"""
GROUP = "group"
WEIGHT = "weight"
class QuantizationArgs(BaseModel):
"""
User facing arguments used to define a quantization config
for weights or activations
:param num_bits: quantization bit depth
:param type: dtype to quantized to, either int or float
:param symmetric: whether or not quantization scale is symmetric
:param strategy: string determining the scope of scale/zero-point to apply
:param group_size: group length to use for the group strategy
:param block_structure: 2d block structure to use for the block
strategy, must be of the format "2x4", "8x16", etc.
:param dynamic: set True to perform dynamic quantization -
values will not be calibrated during calibration phase,
instead during inference new quantization ranges will be
observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one
:param actorder: whether to apply group quantization in decreasing order of
activation. Defaults to None for arbitrary ordering
"""
num_bits: int = 8
type: QuantizationType = QuantizationType.INT
symmetric: bool = True
group_size: Optional[int] = None
strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None
dynamic: bool = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: str = Field(
default="minmax",
description=("The class to use to compute the quantization param - "
"scale and zero-point'"),
)
observer_kwargs: Dict[str, Any] = Field(
default_factory=dict,
description=
("optional dict of kwargs to be passed directly to torch quantization "
"Observers constructor excluding quantization range or symmetry"),
)
@field_validator("actorder", mode="before")
def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
if isinstance(value, bool):
return ActivationOrdering.GROUP if value else None
if isinstance(value, str):
return ActivationOrdering(value.lower())
return value
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
CompressionFormat.naive_quantized.value,
CompressionFormat.int_quantized.value,
CompressionFormat.float_quantized.value
]
return format in _ACTIVATION_QUANTIZATION_FORMATS
def should_ignore_layer(layer_name: Optional[str],
ignore: Iterable[str]) -> bool:
if layer_name is None:
return False
# layer_name = model.layers.0.self_attn.qkv_proj
# proj_name = qkv_proj
proj_name = layer_name.split(".")[-1]
# Fused layers like gate_up_proj or qkv_proj will not be fused
# in the safetensors checkpoint. So, we convert the name
# from the fused version to unfused + check to make sure that
# each shard of the fused layer has the same scheme.
if proj_name in FUSED_LAYER_NAME_MAPPING:
shard_proj_names = FUSED_LAYER_NAME_MAPPING[proj_name]
# Convert fused_name --> [shard_names]
shard_names = [
layer_name.replace(proj_name, shard_proj_name)
for shard_proj_name in shard_proj_names
]
# Layer should be ignored if shards are ignored.
should_ignore_layer = None
for shard_name in shard_names:
should_ignore_shard = check_equal_or_regex_match(
layer_name=shard_name, targets=ignore)
# If shard_idx=0, set layer ignore to match shard.
if should_ignore_layer is None:
should_ignore_layer = should_ignore_shard
# If shard_idx=1+ confirm scheme matches prior shards.
elif should_ignore_shard != should_ignore_layer:
raise ValueError(f"Found a different quantization schemes for "
f"{shard_proj_names} in {layer_name}. vLLM "
"requires all to use the same scheme.")
# Unfused layers like down_proj and o_proj will match
# the safetensors checkpoint already.
else:
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
targets=ignore)
assert should_ignore_layer is not None
return should_ignore_layer
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
"""
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors configs has a concept of
config_groups, where each layer can be quantized with with a different
scheme.
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
"""
if layer_name is None:
layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True))
if matched_target is None:
raise ValueError(f"Unable to find matching target for {module} in the "
"compressed-tensors config.")
return matched_target
def _find_first_match(value: str,
targets: Iterable[str],
check_contains: bool = False) -> Optional[str]:
"""
Returns first element of target that matches value either
exactly or as a regex after 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
:param value: string to compare the list of targets against
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
for target in targets:
if _is_equal_or_regex_match(value,
target,
check_contains=check_contains):
return target
return None
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
"""
Check whether the param name matches the format for k/v cache scales
in compressed-tensors. If this is the case, return its equivalent
param name expected by vLLM
:param name: param name
:return: matching param name for KV cache scale in vLLM
"""
if name.endswith(".output_scale") and ".k_proj" in name:
return name.replace(".k_proj.output_scale", ".attn.k_scale")
if name.endswith(".output_scale") and ".v_proj" in name:
return name.replace(".v_proj.output_scale", ".attn.v_scale")
# If no matches, return None
return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False