Sync from v0.13

This commit is contained in:
2026-01-19 10:38:50 +08:00
parent b2ef04d792
commit 5aef6c175a
3714 changed files with 854317 additions and 89342 deletions

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
@dataclass
class ScaledMMLinearLayerConfig:
is_channelwise: bool
is_static_input_scheme: bool
input_symmetric: bool
class ScaledMMLinearKernel(ABC):
@classmethod
@abstractmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(
self,
c: ScaledMMLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
i_s_param_name: str,
i_zp_param_name: str,
azp_adj_param_name: str,
) -> None:
assert self.can_implement(c)
assert self.is_supported()
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.i_s_name = i_s_param_name
self.i_zp_name = i_zp_param_name
self.azp_adj_name = azp_adj_param_name
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
def _get_weight_params(
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.i_s_name),
getattr(layer, self.i_zp_name),
getattr(layer, self.azp_adj_name),
)

View File

@@ -0,0 +1,81 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from vllm.model_executor.layers.quantization.kernels.scaled_mm.aiter import (
AiterScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cpu import (
CPUScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import (
CutlassScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
}
def choose_scaled_mm_linear_kernel(
config: ScaledMMLinearLayerConfig, compute_capability: int | None = None
) -> type[ScaledMMLinearKernel]:
"""
Choose an ScaledMMLinearKernel that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (ScaledMMLinearLayerConfig): Description of the linear layer
to be implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
type[ScaledMMLinearKernel]: Chosen kernel.
"""
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
failure_reasons.append(f"{kernel.__name__}: disabled by env var")
continue
# If the current platform uses compute_capability,
# make sure the kernel supports the compute capability.
is_supported, reason = kernel.is_supported(compute_capability)
if not is_supported:
failure_reasons.append(f"{kernel.__name__}: {reason}")
continue
can_implement, reason = kernel.can_implement(config)
if not can_implement:
failure_reasons.append(f"{kernel.__name__}: {reason}")
continue
return kernel
raise ValueError(
"Failed to find a kernel that can implement the "
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reasons)
)

View File

@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from .cutlass import CutlassScaledMMLinearKernel
from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig
class AiterScaledMMLinearKernel(CutlassScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "currently supported on non-ROCm platform.",
)
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 90:
return False, f"requires capability 90, got {compute_capability}"
try:
import aiter # noqa: F401 # deliberately attempt to import aiter
except Exception:
return (
False,
"AiterScaledMMLinearKernel requires `aiter` which is not "
+ "installed on ROCm.",
)
if not rocm_aiter_ops.is_linear_enabled():
return (
False,
"AiterScaledMMLinearKernel is disabled. "
+ "Enable by setting `VLLM_ROCM_USE_AITER=1` "
+ "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
+ "`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return (
False,
"AiterScaledMMLinearKernel only supports symmetric " + "quantization.",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
super().process_weights_after_loading(layer)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
`AiterScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM.
"""
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
assert symmetric, (
"AiterScaledMMLinearKernel only supports symmetric quantization."
)
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
assert x_zp is None, (
"AiterScaledMMLinearKernel only supports symmetric quantization."
)
out_dtype = x.dtype
assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
m = x_q.shape[0] # a
n = w_q.shape[1] # b
per_tensor_scale_a = x_s.numel() == 1
per_tensor_scale_b = w_s.numel() == 1
per_token_scale_a = x_s.numel() == m
per_channel_scale_b = w_s.numel() == n
# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
# if one of the scale is a per-channel-scale.
# For now, it only supports:
# - per-tensor-per-tensor a8w8 scaled GEMM, and
# - per-token-per-channel a8w8 scaled GEMM
assert (per_tensor_scale_a and per_tensor_scale_b) or (
per_token_scale_a and per_channel_scale_b
), (
"Currently only support per-tensor-per-tensor GEMM "
+ " and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterScaledMMLinearKernel` "
+ "does not support AITER block scaled GEMM."
)
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@@ -0,0 +1,220 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class CPUScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "Requires CPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name)
dtype = weight.dtype
N, K = weight.size()
if (
current_platform.get_cpu_architecture() == CpuArchEnum.X86
and envs.VLLM_CPU_SGL_KERNEL
and self.config.input_symmetric
and check_cpu_sgl_kernel(N, K, dtype)
):
self.linear_method = self._apply_weights_sgl
self.process_weights_for_sgl(layer)
else:
self.linear_method = self._apply_weights_onednn
self.process_weights_for_onednn(layer)
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Transpose to [K, N] for convenience
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# oneDNN kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
azp = (
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the
# term. Such as:
# s_a * s_b * [(A - zp_a)B] + bias =
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
weight = getattr(layer, self.w_q_name)
weight_scale = getattr(layer, self.w_s_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
azp_adj = azp_adj * weight_scale.squeeze()
setattr(
layer,
self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
weight = getattr(layer, self.w_q_name)
self.dnnl_handler = ops.create_onednn_scaled_mm(
weight,
getattr(layer, self.w_s_name),
torch.get_default_dtype(),
getattr(layer, self.i_s_name) is None,
not self.config.input_symmetric,
32,
)
# weight is prepacked and maintained by the dnnl_handler,
# release the original weight
setattr(layer, self.w_q_name, None)
del weight
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
# WEIGHT
weight = getattr(layer, self.w_q_name)
packed_weight = torch.ops._C.convert_weight_packed(weight)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
)
if layer.bias is not None:
bias = layer.bias
layer.register_parameter(
"bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
)
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
weight_scale = getattr(layer, self.w_s_name)
if not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.linear_method(
layer,
x,
bias,
)
def _apply_weights_onednn(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
x, i_s, i_zp, self.config.input_symmetric
)
m = x.size(0)
n = self.dnnl_handler.n
out = torch.empty((m, n), dtype=x.dtype)
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
return out
def _apply_weights_sgl(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
return torch.ops._C.int8_scaled_mm_with_quant(
x,
w_q,
w_s,
layer.bias_fp32 if bias is not None else None,
x.dtype,
True,
)

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "Requires CUDA."
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc.major * 10 + _cc.minor
if compute_capability is not None and compute_capability < 75:
return False, f"requires capability 75, got {compute_capability}"
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
input_zero_point = getattr(layer, self.i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, self.i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, self.i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not self.config.input_symmetric:
weight = getattr(layer, self.w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if self.config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, self.i_zp_name) * azp_adj
setattr(
layer,
self.azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
else:
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=symmetric
)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
static = i_zp is not None
azp = None if static else x_zp
return ops.cutlass_scaled_mm_azp(
x_q,
w_q,
scale_a=x_s,
scale_b=w_s,
out_dtype=x.dtype,
azp_adj=azp_adj,
azp=azp,
bias=bias,
)
return ops.cutlass_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
triton_scaled_mm,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if current_platform.is_cuda_alike():
return True, None
return False, "Requires ROCm or CUDA."
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return False, "Only symmetric input is supported."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer,
self.w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name)
replace_parameter(
layer,
self.i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, self.i_zp_name, None)
else:
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer)
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=True
)
assert x_zp is None, "Triton kernel only supports symmetric quantization"
return triton_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)

View File

@@ -0,0 +1,106 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
import torch
from functorch.experimental.control_flow import cond # noqa: F401
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "Requires TPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."
if c.is_static_input_scheme:
return False, "ScaledMMXLA requires dynamic activation scales."
if not c.input_symmetric:
return False, "ScaledMMXLA requires symmetric activation scales."
if not c.is_channelwise:
return False, "ScaledMMXLA requires channelwise weight scales"
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# [out, in] (different than cutlass_scaled_mm)
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
)
# WEIGHT SCALE
# XLA kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
# [out_channel,] (different than cutlass_scaled_mm)
weight_scale = weight_scale.squeeze(-1)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# Only support symmetric dynamic activation quantization.
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
# Filter warning for cond usage in apply_weights. It is okay
# to specialize the graph since bias is not dynamic.
warnings.filterwarnings(
"ignore",
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
)
def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
return x
def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
return x + bias
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
out = torch.ops.xla.quantized_matmul_int8(
x,
w_q,
w_s,
quantize_activation=True,
)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])