# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import (
    QuantizationConfig,
    QuantizationMethods,
)
from vllm.model_executor.parameter import ModelWeightParameter

ACTIVATION_SCHEMES = ["none", "dynamic"]


class Int8TpuConfig(QuantizationConfig):
    """Int8 quantization config class for the TPU backend."""

    def __init__(
        self,
        activation_scheme: str = "none",
    ) -> None:
        super().__init__()
        if activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {activation_scheme}")
        self.activation_scheme = activation_scheme

    def get_name(self) -> QuantizationMethods:
        return "tpu_int8"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # Compute capability is a CUDA concept and does not apply to TPUs.
        raise NotImplementedError("This function should not be called with TPU Backend")

    @staticmethod
    def get_config_filenames() -> list[str]:
        return []

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
        activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
        return cls(activation_scheme=activation_scheme)

    def get_quant_method(
        self, layer: Module, prefix: str
    ) -> Optional["TPUInt8LinearMethod"]:
        if isinstance(layer, LinearBase):
            return TPUInt8LinearMethod(self)
        return None


class TPUInt8LinearMethod(LinearMethodBase):
    """Int8 linear method for TPU quantization."""

    def __init__(self, quant_config: Int8TpuConfig):
        self.quant_config = quant_config
        # With the "dynamic" scheme, activations are quantized on the fly
        # inside the fused matmul kernel; with "none", only weights are int8.
        self.quantize_activation = False
        if self.quant_config.activation_scheme == "dynamic":
            self.quantize_activation = True

    def create_weights(
        self,
        layer: Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # The weight is created in the original floating-point dtype; it is
        # quantized to int8 in process_weights_after_loading.
        weight_loader = extra_weight_attrs.get("weight_loader")
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

    def _quantize_weight(
        self, weight: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Symmetric per-output-channel int8 quantization: each output row
        # gets its own scale derived from that row's absolute maximum.
        weight_dtype = weight.dtype
        weight = weight.cpu().to(torch.float32)
        n_bit = 8
        eps = 1e-5
        max_int = 2 ** (n_bit - 1) - 1
        min_int = -(2 ** (n_bit - 1))
        max_val = weight.abs().amax(dim=-1, keepdim=True)
        # Clamp to eps to avoid division by zero for all-zero rows.
        max_val = max_val.clamp(min=eps)
        qscale = max_val / max_int
        qweight = torch.clamp(
            torch.round(weight * (1.0 / qscale)), min_int, max_int
        ).to(torch.int8)
        qscale = qscale.squeeze().to(weight_dtype)
        return qweight, qscale

    def process_weights_after_loading(self, layer: Module) -> None:
        layer.weight = Parameter(layer.weight.data, requires_grad=False)
        device = layer.weight.device
        # Quantization runs on CPU; move the int8 weight and per-channel
        # scale back to the layer's original device afterwards.
        qweight, qscale = self._quantize_weight(layer.weight)
        qweight = qweight.to(device)
        qscale = qscale.to(device)
        layer.weight = Parameter(qweight, requires_grad=False)
        layer.scale = Parameter(qscale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # Importing this module registers the XLA custom kernels used below.
        try:
            import torch_xla.experimental.custom_kernel  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "Please install torch_xla by following the instructions at "
                "https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html "  # noqa: E501
                "to run vLLM on TPU."
            ) from err
        weight = layer.weight
        scale = layer.scale
        out = torch.ops.xla.quantized_matmul_int8(
            x, weight, scale, quantize_activation=self.quantize_activation
        )
        if bias is not None:
            out = out + bias
        return out