Add minimal vLLM 0.16.1 build repo for BI-V150
0    vllm/model_executor/kernels/__init__.py    Normal file
397  vllm/model_executor/kernels/linear/__init__.py    Normal file
@@ -0,0 +1,397 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
This module re-exports linear kernel implementations to provide a
stable import interface during an ongoing reorganization. Upcoming
PRs will remove the scaled_mm and mixed_precision subdirectories
and reorganize kernels by provider (aiter, cutlass, flashinfer, etc.)
rather than by precision type. By centralizing exports here, we
minimize the need to update imports across other modules when the
internal structure changes. If you are adding a new kernel selector
or kernel implementation, add it to this __init__.py to maintain
import stability.
"""

import os
from typing import TypeVar

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.mixed_precision import (
    MPLinearKernel,
    MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
    AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
    ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
    CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
    CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
    Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
    ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
    MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
    MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
    XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
    Int8ScaledMMLinearKernel,
    Int8ScaledMMLinearLayerConfig,
    ScaledMMLinearKernel,
    ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
    AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
    CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
    CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
    TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
    XPUFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform

logger = init_logger(__name__)

# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
    PlatformEnum.CUDA: [
        CutlassInt8ScaledMMLinearKernel,
        TritonInt8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel],
}

# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
    PlatformEnum.CUDA: [
        FlashInferFP8ScaledMMLinearKernel,
        CutlassFP8ScaledMMLinearKernel,
        PerTensorTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [
        ROCmFP8ScaledMMLinearKernel,
        PerTensorTorchFP8ScaledMMLinearKernel,
        RowWiseTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.CPU: [
        PerTensorTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.XPU: [
        XPUFP8ScaledMMLinearKernel,
    ],
}

# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
    PlatformEnum.CUDA: [
        CutlassW4A8LinearKernel,
        MacheteLinearKernel,
        AllSparkLinearKernel,
        MarlinLinearKernel,
        ConchLinearKernel,
        ExllamaLinearKernel,
    ],
    PlatformEnum.ROCM: [
        ConchLinearKernel,
        ExllamaLinearKernel,
    ],
    PlatformEnum.XPU: [
        XPUwNa16LinearKernel,
    ],
    PlatformEnum.CPU: [
        Dynamic4bitLinearKernel,
        CPUWNA16LinearKernel,
    ],
}

_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)


def is_supported_and_can_implement_kernel(
    kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
) -> tuple[bool, str]:
    # TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
    if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
        return False, f" {kernel.__name__} is disabled by environment variable"

    if compute_capability is None:
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    is_supported, failure_reason = kernel.is_supported(compute_capability)
    if not is_supported:
        return False, f"{kernel.__name__} {failure_reason}."

    can_implement, failure_reason = kernel.can_implement(config)
    if not can_implement:
        return (
            False,
            f"{kernel.__name__} {failure_reason}.",
        )

    return True, ""


def choose_scaled_mm_linear_kernel(
    config: _KernelConfigT,
    possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
    compute_capability: int | None = None,
    force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
    """
    Choose a _KernelT that can implement the given config for the
    given compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (_KernelConfigT): Description of the linear layer
            to be implemented.
        possible_kernels (dict[PlatformEnum, list[type[_KernelT]]]): A
            dictionary of platforms and their list of possible kernels.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get the
            compute capability. Defaults to None.
        force_kernel (Optional[type[_KernelT]]): An optional kernel that
            overrides `possible_kernels` if it can be implemented. If None,
            only the possible kernels are tried.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        type[_KernelT]: Chosen kernel.
    """

    failure_reason_list = []

    if force_kernel is not None:
        can_implement, failure_reason = is_supported_and_can_implement_kernel(
            force_kernel, config, compute_capability
        )
        if can_implement:
            return force_kernel

        logger.info_once(
            "Tried to force %s, but the kernel couldn't be implemented",
            force_kernel.__name__,
            scope="global",
        )

    for kernel in possible_kernels[current_platform._enum]:
        is_supported_and_can_implement, failure_reason = (
            is_supported_and_can_implement_kernel(kernel, config, compute_capability)
        )
        if is_supported_and_can_implement:
            return kernel
        failure_reason_list.append(failure_reason)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
    )


def init_fp8_linear_kernel(
    activation_quant_key: QuantKey,
    weight_quant_key: QuantKey,
    out_dtype: torch.dtype,
    force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
    module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
    scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
        weight_quant_key=weight_quant_key,
        activation_quant_key=activation_quant_key,
        out_dtype=out_dtype,
    )

    kernel_type = choose_scaled_mm_linear_kernel(
        scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
    )

    if module_name:
        logger.info_once(
            "Selected %s for %s",
            kernel_type.__name__,
            module_name,
            scope="global",
        )

    return kernel_type(
        scaled_mm_linear_kernel_config,
        layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
    )


def init_int8_linear_kernel(
    is_channelwise: bool,
    is_static_input_scheme: bool,
    input_symmetric: bool,
    module_name: str,
) -> Int8ScaledMMLinearKernel:
    config = Int8ScaledMMLinearLayerConfig(
        is_channelwise=is_channelwise,
        is_static_input_scheme=is_static_input_scheme,
        input_symmetric=input_symmetric,
    )

    kernel_type = choose_scaled_mm_linear_kernel(
        config,
        _POSSIBLE_INT8_KERNELS,
    )

    logger.info_once(
        "Selected %s for %s",
        kernel_type.__name__,
        module_name,
        scope="global",
    )

    return kernel_type(
        config,
        layer_param_names=[
            "weight",
            "weight_scale",
            "input_scale",
            "input_zero_point",
            "azp_adj",
        ],
    )


def choose_mp_linear_kernel(
    config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
            implemented.
        compute_capability (Optional[int], optional): The compute capability of
            the target device, if None uses `current_platform` to get
            the compute capability. Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel.__name__} disabled by environment variable"
            )
            continue
        if (
            compute_capability is not None
            and kernel.get_min_capability() > compute_capability
        ):
            failure_reasons.append(
                f"{kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute "
                f" capability is {compute_capability}"
            )
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        else:
            failure_reasons.append(
                f" {kernel.__name__} cannot implement due to: {failure_reason}"
            )

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )


__all__ = [
    "init_fp8_linear_kernel",
    "init_int8_linear_kernel",
    "choose_mp_linear_kernel",
    "FP8ScaledMMLinearKernel",
    "Int8ScaledMMLinearKernel",
    "ScaledMMLinearKernel",
    "FP8ScaledMMLinearLayerConfig",
    "Int8ScaledMMLinearLayerConfig",
    "ScaledMMLinearLayerConfig",
    "AiterInt8ScaledMMLinearKernel",
    "CPUInt8ScaledMMLinearKernel",
    "CutlassFP8ScaledMMLinearKernel",
    "CutlassInt8ScaledMMLinearKernel",
    "FlashInferFP8ScaledMMLinearKernel",
    "ChannelWiseTorchFP8ScaledMMLinearKernel",
    "PerTensorTorchFP8ScaledMMLinearKernel",
    "RowWiseTorchFP8ScaledMMLinearKernel",
    "ROCmFP8ScaledMMLinearKernel",
    "TritonInt8ScaledMMLinearKernel",
    "MPLinearKernel",
    "MPLinearLayerConfig",
    "AllSparkLinearKernel",
    "ConchLinearKernel",
    "CPUWNA16LinearKernel",
    "CutlassW4A8LinearKernel",
    "Dynamic4bitLinearKernel",
    "ExllamaLinearKernel",
    "MacheteLinearKernel",
    "MarlinLinearKernel",
    "XPUwNa16LinearKernel",
]
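[Note, not part of the diff: a minimal sketch of how the selector exported above could be driven for a 4-bit GPTQ-style layer. The shapes, dtypes, and group size below are hypothetical; only the API surface comes from this file.]

import torch

from vllm.model_executor.kernels.linear import (
    MPLinearLayerConfig,
    choose_mp_linear_kernel,
)
from vllm.scalar_type import scalar_types

cfg = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),       # hypothetical [in, out]
    partition_weight_shape=(4096, 4096),  # assumes no tensor-parallel split
    weight_type=scalar_types.uint4b8,     # GPTQ-style 4-bit weights
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)
# Walks the platform's _POSSIBLE_KERNELS priority list and returns the first
# kernel class whose can_implement(cfg) passes, or raises ValueError.
kernel_cls = choose_mp_linear_kernel(cfg)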
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType


@dataclass
class MPLinearLayerConfig:
    full_weight_shape: tuple[int, int]  # [in, out]
    partition_weight_shape: tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool
    out_type: torch.dtype | None = None


class MPLinearKernel(ABC):
    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        raise NotImplementedError

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: str | None = None,
        w_gidx_param_name: str | None = None,
    ) -> None:
        assert self.can_implement(c)
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        if c.zero_points:
            assert w_zp_param_name is not None
        if c.has_g_idx:
            assert w_gidx_param_name is not None
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def _transform_param(
        self, layer: torch.nn.Module, name: str | None, fn: Callable
    ) -> None:
        if name is not None and getattr(layer, name, None) is not None:
            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
            )

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> tuple[
        torch.Tensor,  # w_q
        torch.Tensor,  # w_s
        torch.Tensor | None,  # w_zp
        torch.Tensor | None,  # w_gidx
    ]:
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
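[Note, not part of the diff: a sketch of what a new backend must fill in under the MPLinearKernel contract above. The class name is hypothetical, and apply_weights assumes an already-unpacked int8 [K, N] weight with a per-output-channel scale, which is not how the packed formats in this directory actually store weights.]

import torch

from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
    MPLinearKernel,
    MPLinearLayerConfig,
)


class ReferenceLinearKernel(MPLinearKernel):  # hypothetical backend
    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # no hardware requirement

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.zero_points or c.has_g_idx:
            return False, "reference kernel only handles symmetric weights"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # A real backend repacks weights here via self._transform_param(...).
        pass

    def apply_weights(self, layer, x, bias=None):
        w_q, w_s, _, _ = self._get_weight_params(layer)
        # Dequantize (assumed unpacked int8 [K, N] weight, [N] scale) and matmul.
        w = w_q.to(x.dtype) * w_s.to(x.dtype)
        out = x @ w
        return out if bias is None else out + bias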
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
    AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
    ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
    CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
    CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
    Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
    ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
    MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
    MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
    MPLinearKernel,
    MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
    XPUwNa16LinearKernel,
)

__all__ = [
    "MPLinearKernel",
    "MPLinearLayerConfig",
    "AllSparkLinearKernel",
    "ConchLinearKernel",
    "CPUWNA16LinearKernel",
    "CutlassW4A8LinearKernel",
    "Dynamic4bitLinearKernel",
    "ExllamaLinearKernel",
    "MacheteLinearKernel",
    "MarlinLinearKernel",
    "XPUwNa16LinearKernel",
]
116  vllm/model_executor/kernels/linear/mixed_precision/allspark.py    Normal file
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.utils.platform_utils import num_compute_units

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class AllSparkLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.has_g_idx:
            return False, "Act reordering currently not supported by AllSpark"

        if c.zero_points:
            return False, "Zero points currently not supported by AllSpark"

        return check_allspark_supported_dtype_shape(
            c.partition_weight_shape[0],  # in_features
            c.partition_weight_shape[1],  # out_features
            c.group_size,
            c.weight_type,
            c.act_type,
        )

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        # prepare the parameters required for the kernel
        properties = torch.cuda.get_device_properties(device.index)
        sm_count = num_compute_units(device.index)
        sm_version = properties.major * 10 + properties.minor
        gemm_args = {}
        gemm_args["sm_count"] = sm_count
        gemm_args["sm_version"] = sm_version

        self.gemm_args = gemm_args

        # transform param weight, scale
        old_weight_param = getattr(layer, self.w_q_name)
        old_scale_param = getattr(layer, self.w_s_name)

        assert isinstance(old_weight_param, BasevLLMParameter)
        permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)

        assert isinstance(old_scale_param, BasevLLMParameter)
        permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)

        # unpack weight from K / 4 x N int32 to K x N uint8
        new_weight_param = torch.nn.Parameter(
            old_weight_param.data, requires_grad=False
        )
        new_weight_param.data = (
            new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
        )
        new_weight_param.data = new_weight_param.data.t().contiguous()

        new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)

        # reorder K x N weight as N32K16 format for Ampere W8A16
        new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
            new_weight_param.data, new_scale_param.data, None, c.zero_points
        )

        replace_parameter(layer, self.w_q_name, new_weight_param.data)
        replace_parameter(layer, self.w_s_name, new_scale_param.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        gemm_args = self.gemm_args
        w_q, w_s, _, _ = self._get_weight_params(layer)

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        output = ops.allspark_w8a16_gemm(
            a=reshaped_x,
            b_qweight=w_q,
            b_scales=w_s,
            b_qzeros=None,
            n=c.partition_weight_shape[1],
            group_size=c.group_size,
            sm_count=gemm_args["sm_count"],
            sm_version=gemm_args["sm_version"],
            CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp=c.zero_points,
            n32k16_reorder=True,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
140  vllm/model_executor/kernels/linear/mixed_precision/conch.py    Normal file
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from importlib.util import find_spec
from typing import Final

import torch

from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    scalar_types.uint4,
    scalar_types.uint8,
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]


class ConchLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
            error_msg = (
                f"Weight type ({c.weight_type}) not supported by "
                "ConchLinearKernel, supported types are: "
                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
            )
            return False, error_msg

        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
            error_msg = (
                f"Group size ({c.group_size}) not supported by "
                "ConchLinearKernel, supported group sizes are: "
                f"{_CONCH_SUPPORTED_GROUP_SIZES}"
            )
            return False, error_msg

        if find_spec("conch") is None:
            error_msg = (
                "conch-triton-kernels is not installed, please "
                "install it via `pip install conch-triton-kernels` "
                "and try again!"
            )
            return False, error_msg

        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    #  `weight_zero_point` is: {input_dim = 1, output_dim = 0, packed_dim = 0}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = x.data.contiguous()
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        def transform_w_zp(x):
            # Zero points are stored PACKED as [N//pack_factor, K//G]
            # The Conch kernel expects UNPACKED zeros: [K//G, N]
            # We need to unpack and reorder
            assert isinstance(x, BasevLLMParameter)
            packed = x.data  # shape: [N//pack_factor, K//G], dtype: int32

            # Determine packing based on weight bit width
            size_bits = self.config.weight_type.size_bits
            pack_factor = 32 // size_bits  # 8 for 4-bit, 4 for 8-bit
            mask = (1 << size_bits) - 1  # 0xF for 4-bit, 0xFF for 8-bit

            n_packed, k_groups = packed.shape
            n_full = n_packed * pack_factor

            # Unpack using vectorized bitwise ops
            # shifts = [0, size_bits, 2*size_bits, ...] for each packed position
            shifts = torch.arange(
                0, 32, size_bits, dtype=torch.int32, device=packed.device
            )
            # packed: [N//pack_factor, K//G] -> [N//pack_factor, K//G, 1]
            # shifts: [pack_factor] -> [1, 1, pack_factor]
            # Result: [N//pack_factor, K//G, pack_factor]
            unpacked = (packed.unsqueeze(-1) >> shifts) & mask

            # Permute to [K//G, N//pack_factor, pack_factor] then reshape to [K//G, N]
            unpacked = unpacked.permute(1, 0, 2).reshape(k_groups, n_full)

            x.data = unpacked.to(torch.uint8).contiguous()

            # Update metadata - zeros are no longer packed
            if hasattr(x, "_input_dim"):
                x._input_dim = 0
            if hasattr(x, "_output_dim"):
                x._output_dim = 1
            if hasattr(x, "_packed_factor"):
                x._packed_factor = 1
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if self.config.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from conch.ops.quantization.gemm import mixed_precision_gemm

        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        output = mixed_precision_gemm(
            x=x,
            w_q_packed=w_q.data,
            w_s=w_s.data,
            w_zp=w_zp.data if w_zp is not None else None,
            weight_size_bits=self.config.weight_type.size_bits,
            weight_bias=self.config.weight_type.bias,
            group_size=self.config.group_size,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
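[Note, not part of the diff: the zero-point unpack in transform_w_zp above is plain shift-and-mask arithmetic. A standalone toy check with invented values, using the same shifts/mask construction as the kernel code:]

import torch

size_bits = 4                  # 4-bit weights
pack_factor = 32 // size_bits  # 8 nibbles per int32
mask = (1 << size_bits) - 1    # 0xF

# one int32 holding nibbles 0..7, least-significant nibble first
packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # [N//8, K//G] = [1, 1]

shifts = torch.arange(0, 32, size_bits, dtype=torch.int32)
unpacked = (packed.unsqueeze(-1) >> shifts) & mask   # [1, 1, 8]
unpacked = unpacked.permute(1, 0, 2).reshape(1, 8)   # [K//G, N] = [1, 8]
print(unpacked)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)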
126  vllm/model_executor/kernels/linear/mixed_precision/cpu.py    Normal file
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)


class CPUWNA16LinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "CPUWNA16 only supported on CPU"

        if c.weight_type not in _CPUWNA16_SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CPUWNA16, supported types are: "
                f"{_CPUWNA16_SUPPORTED_QUANT_TYPES}",
            )

        if c.group_size != -1 and c.group_size % 2 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "CPUWNA16, supported group sizes are multiples of 2",
            )

        if c.partition_weight_shape[0] % 32 != 0:
            return (
                False,
                f"Input size ({c.partition_weight_shape[0]}) not supported by "
                "CPUWNA16, supported sizes are multiples of 32",
            )

        if c.partition_weight_shape[1] % 32 != 0:
            return (
                False,
                f"Output size ({c.partition_weight_shape[1]}) not supported by "
                "CPUWNA16, supported sizes are multiples of 32",
            )

        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    #  `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def _process_gptq_weights(self, layer: torch.nn.Module):
        packed_weight = layer.qweight.data
        bits = self.config.weight_type.mantissa
        pack_factor = 32 // bits
        p_w_k, p_w_n = packed_weight.size()
        input_size = p_w_k * pack_factor
        output_size = p_w_n
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        layer.qzeros = None
        if not self.config.has_g_idx:
            layer.g_idx = None

        # convert input dim packed to output dim packed
        weight = unpack_quantized_values_into_int32(
            packed_weight, self.config.weight_type, 1
        ).view(p_w_k, p_w_n, pack_factor)
        weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
        weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1)
        # treat every 16 output channels as a block and transpose to make
        # each block contiguous
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight

    def process_weights_after_loading(self, layer: torch.nn.Module):
        if not self.config.zero_points:
            # GPTQ
            self._process_gptq_weights(layer)
        else:
            # AWQ
            raise NotImplementedError("AWQ is not supported in CPUWNA16LinearKernel")

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = ops.cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=layer.g_idx,
            bias=bias,
            pack_factor=8,  # 32 // 4
            isa_hint=layer.isa_hint,
        )
        return x


def _get_isa_hint(dtype: torch.dtype) -> str:
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    if supports_amx and dtype in (torch.bfloat16,):
        return "amx"
    else:
        return "vec"
131  vllm/model_executor/kernels/linear/mixed_precision/cutlass.py    Normal file
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    convert_bf16_scales_to_fp8,
    convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class CutlassW4A8LinearKernel(MPLinearKernel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dynamic per-tok fp8 activation quantization
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "CUTLASS only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"

        if c.act_type != torch.float8_e4m3fn:
            return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

        if c.has_g_idx:
            return False, "Act reordering not supported by CUTLASS W4A8"

        if c.zero_points:
            return False, "Zero points not supported by CUTLASS W4A8"

        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CUTLASS W4A8, only int4 is supported",
            )

        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

        in_features, out_features = c.partition_weight_shape
        if in_features % 128 or out_features % 128:
            return (
                False,
                f"K and N must be divisible by 128, got {c.partition_weight_shape}",
            )

        if c.out_type != torch.bfloat16:
            return (
                False,
                f"Only bfloat16 output type currently supported, got {c.out_type=}",
            )

        return True, None

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            convert_packed_uint4b8_to_signed_int4_inplace(x.data)
            torch.cuda.synchronize()
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous().to(torch.float8_e4m3fn)
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        w_s = getattr(layer, self.w_s_name)
        fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
        w_s.data = fp8_scales

        # register per-channel scales
        layer.register_parameter(
            "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
        )

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)
        w_ch_s = layer.weight_chan_scale

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        x_2d, act_scales = self.quant_fp8(x_2d)
        output = ops.cutlass_w4a8_mm(
            a=x_2d,
            b_q=w_q,
            b_group_scales=w_s,
            b_group_size=c.group_size,
            a_token_scales=act_scales,
            b_channel_scales=w_ch_s,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


# This implementation is for the KleidiAI-accelerated w4a8int quantization
# scheme on Arm CPUs:
# torch.ops.aten._dyn_quant_matmul_4bit performs dynamic quantized matmul
# it takes:
#   - int4 weights packed along with bias/scales by
#     torch.ops.aten._dyn_quant_pack_4bit_weight
#   - float32/bfloat16 activations
# then it leverages KleidiAI ukernels that:
#   - dynamically quantize the activations to int8
#   - unpack the int4 weights to int8
#   - perform int8 x int8 -> int32 matmul
#   - dequantize the int32 output to float32/bfloat16 outputs
class Dynamic4bitLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.int4]

    @classmethod
    def get_min_capability(cls) -> int:
        return 1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "Only CPU is supported"
        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return False, f"Unsupported quant type {c.weight_type}"
        if (
            current_platform.get_cpu_architecture() == CpuArchEnum.ARM
            and c.act_type
            not in [
                torch.float32,
                torch.bfloat16,
                torch.float16,
            ]
        ):
            return (
                False,
                "Dynamic4bitLinearKernel on Arm requires Float32 or"
                " BFloat16 or Float16 activations",
            )
        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )
        if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
            try:
                # Attempt to retrieve the operation
                _ = torch.ops.aten._dyn_quant_matmul_4bit
            except AttributeError:
                return (
                    False,
                    f"PyTorch {torch.__version__} does not support"
                    " _dyn_quant_matmul_4bit. Install a newer version",
                )
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        packed_weight = getattr(layer, self.w_q_name)
        packed_weight = packed_weight.add(8)
        uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
            torch.uint8
        )

        scales = getattr(layer, self.w_s_name)
        block_size = c.group_size

        # Handle scaling factors for partitioned weights
        if block_size == c.partition_weight_shape[0]:
            scales = scales.to(
                torch.float32
            )  # Float32 & Bfloat16 variants require float32 scales
            scales = scales.view(-1, 1)  # Channel-wise scales
            if layer.bias is not None:
                # Float32 & Bfloat16 variants require float32 bias
                replace_parameter(
                    layer,
                    "bias",
                    torch.nn.Parameter(
                        layer.bias.to(torch.float32), requires_grad=False
                    ),
                )
        else:
            # KleidiAI kernel requires bfloat16 scales with groupwise scheme
            scales = scales.to(torch.bfloat16)

        # Repack weights as per kernel requirement
        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
            uint8_packed,
            scales,
            layer.bias,
            block_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        )
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
        )
        setattr(layer, self.w_s_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # PyTorch / KleidiAI kernels natively support the following configs:
        #   - channelwise with bfloat16 / float32 activations
        #   - groupwise with float32 activations
        # To support:
        #   - groupwise with bfloat16/float16 activations: we need to upcast
        #     activations to float32 before matmul and downcast back to
        #     bfloat16/float16
        #   - channelwise with float16 activations: we need to upcast activations
        #     to float32 before matmul and downcast back to float16
        # Note: these activations will be dynamically quantized to int8 by the kernel.

        c = self.config
        is_groupwise = c.group_size != c.partition_weight_shape[0]
        # dtype of activations before they get dynamically quantized to int8
        original_pre_quant_act_dtype = x.dtype
        pre_quant_act_dtype = original_pre_quant_act_dtype
        if (
            is_groupwise and pre_quant_act_dtype == torch.bfloat16
        ) or pre_quant_act_dtype == torch.float16:
            pre_quant_act_dtype = torch.float32

        x_2d = x.reshape(-1, x.shape[-1])
        if pre_quant_act_dtype != original_pre_quant_act_dtype:
            x_2d = x_2d.to(pre_quant_act_dtype)

        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q = getattr(layer, self.w_q_name)
        output = torch.ops.aten._dyn_quant_matmul_4bit(
            x_2d,
            w_q,
            c.group_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        ).reshape(out_shape)

        if pre_quant_act_dtype != original_pre_quant_act_dtype:
            output = output.to(original_pre_quant_act_dtype)
        return output
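[Note, not part of the diff: the nibble-packing expression in process_weights_after_loading above folds two adjacent 4-bit values (already shifted by +8 into 0..15) into one byte. A toy check with invented values:]

import torch

w = torch.tensor([[1, 2, 3, 4]])  # columns k0, k1, k2, k3, already non-negative
packed = (w[:, 1::2] << 4 | w[:, ::2]).to(torch.uint8)
print(packed)  # tensor([[33, 67]], dtype=torch.uint8): 0x21 = (2<<4)|1, 0x43 = (4<<4)|3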
168  vllm/model_executor/kernels/linear/mixed_precision/exllama.py    Normal file
@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class ExllamaLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
    # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
    # currently untested so not added to the list

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cuda_alike():
            return (
                False,
                "Exllama is only supported on CUDA and ROCm",
            )

        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering currently not supported by Exllama "
                "when the input features are partitioned across "
                "devices",
            )

        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
            return (
                False,
                "Output features must be a multiple of the pack "
                "factor (32 / num_bits) so that we can correctly "
                "pack the zero points",
            )

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Exllama, supported types are: "
                f"{cls.SUPPORTED_QUANT_TYPES}",
            )

        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            device = getattr(layer, self.w_q_name).device
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # if the type has a bias we have to create a zeros tensor that
                # contains the bias values repeated for each group (-1 due to
                # a bug in the original GPTQ checkpoint format leading to
                # exllama kernel adding 1 to the zero points during inference)
                # Documentation of the bug can be found here:
                #  https://garden.danieldk.eu/GPTQ-Checkpoint-Format
                zeros = torch.full(
                    (groups, out_features),
                    c.weight_type.bias - 1,
                    dtype=torch.int32,
                    device=device,
                )
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "exllama kernel adding 1 to the zero points during "
                    "inference"
                )
            zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
            setattr(
                layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
            )

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)  # type: ignore
        else:
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(
                torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
            )
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Repack weights and scales for Exllama
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        # gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
        # However, the MPLinearLayerConfig doesn't contain format info.
        # So hardcode GPTQv1 format here, to keep its behavior unchanged.
        use_v2_format = False

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
        output = ops.gptq_gemm(
            x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
        )

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
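[Note, not part of the diff: a small numeric illustration of the `bias - 1` fill used above. For the biased types this file supports, the stored zero point equals the type's bias, and the exllama kernel adds 1 back at inference:]

from vllm.scalar_type import scalar_types

print(scalar_types.uint4b8.bias - 1)    # 7   -> kernel reads it back as 8
print(scalar_types.uint8b128.bias - 1)  # 127 -> kernel reads it back as 128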
159  vllm/model_executor/kernels/linear/mixed_precision/machete.py    Normal file
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import partial

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
    check_machete_supports_shape,
    query_machete_supported_group_sizes,
    query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MacheteLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # Machete uses CUTLASS, so it can only be compatible with Nvidia
        if not current_platform.is_cuda():
            return False, "Machete only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "Machete requires compute capability of 90 (Hopper)"

        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering currently not supported by Machete "
                "when the input features are partitioned across "
                "devices",
            )

        if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Machete, supported types are: "
                f"{query_machete_supported_quant_types(c.zero_points)}",
            )

        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Machete, supported group sizes are: "
                f"{query_machete_supported_group_sizes(c.act_type)}",
            )

        return check_machete_supports_shape(
            c.partition_weight_shape[0], c.partition_weight_shape[1]
        )

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale` is: {input_dim = 0, output_dim = 1}
    #  `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if (
                c.act_type in [torch.float16, torch.bfloat16]
                and c.partition_weight_shape[0] % 8 == 0
            ):
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_quantized_values_into_int32(
                    x.data, c.weight_type, packed_dim=0
                )
                x_perm = x_unpacked[perm, :]
                x.data = pack_quantized_values_into_int32(
                    x_perm, c.weight_type, packed_dim=0
                )
            x.data = ops.machete_prepack_B(
                x.data.t().contiguous().t(),
                a_type=c.act_type,
                b_type=c.weight_type,
                group_scales_type=c.act_type,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        def transform_w_zp(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
            x_unpacked = unpack_quantized_values_into_int32(
                x.data, c.weight_type, packed_dim=1
            )
            w_s = getattr(layer, self.w_s_name).data
            # pre-apply scales to zero-points
            x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if c.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)

        if c.zero_points:
            assert w_zp is not None
        else:
            w_zp = None

        output = ops.machete_mm(
            a=x_2d,
            b_q=w_q,
            b_type=c.weight_type,
            b_group_zeros=w_zp,
            b_group_scales=w_s,
            b_group_size=c.group_size,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
200  vllm/model_executor/kernels/linear/mixed_precision/marlin.py    Normal file
@@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_is_k_full,
|
||||
marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new,
|
||||
marlin_permute_bias,
|
||||
marlin_permute_scales,
|
||||
marlin_sort_g_idx,
|
||||
marlin_zero_points,
|
||||
query_marlin_supported_quant_types,
|
||||
unpack_cols,
|
||||
)
|
||||
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
class MarlinLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 75
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
# Marlin uses inline PTX, so it is only compatible with NVIDIA GPUs
|
||||
if not current_platform.is_cuda():
|
||||
return False, "Marlin only supported on CUDA"
|
||||
|
||||
quant_types = query_marlin_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
return (
|
||||
False,
|
||||
f"Quant type ({c.weight_type}) not supported by"
|
||||
f" Marlin, supported types are: {quant_types}",
|
||||
)
|
||||
|
||||
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
return (
|
||||
False,
|
||||
f"Group size ({c.group_size}) not supported by "
|
||||
"Marlin, supported group sizes are: "
|
||||
f"{MARLIN_SUPPORTED_GROUP_SIZES}",
|
||||
)
|
||||
|
||||
return check_marlin_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size,
|
||||
)
|
||||
|
||||
# note assumes that
|
||||
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
|
||||
# `weight_scale` is: {input_dim = 0, output_dim = 1}
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, self.w_q_name).device
|
||||
c = self.config
|
||||
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
|
||||
|
||||
if is_a_8bit:
|
||||
assert c.weight_type == scalar_types.uint4b8, (
|
||||
"W8A8 is not supported by marlin kernel."
|
||||
)
|
||||
|
||||
if c.act_type == torch.float8_e4m3fn:
|
||||
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
|
||||
getattr(layer, self.w_s_name).data = (
|
||||
getattr(layer, self.w_s_name).data * 512
|
||||
)
|
||||
|
||||
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
if self.w_gidx_name is None:
|
||||
self.w_gidx_name = "g_idx"
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "w_zp"
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = ops.gptq_marlin_repack(
|
||||
x.data.contiguous(),
|
||||
perm=layer.g_idx_sort_indices,
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = marlin_permute_scales(
|
||||
x.data.contiguous(),
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
|
||||
if c.group_size == -1:
|
||||
num_groups = 1
|
||||
else:
|
||||
num_groups = c.partition_weight_shape[0] // c.group_size
|
||||
|
||||
if c.act_type == torch.int8 and num_groups > 1:
|
||||
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
|
||||
layer.register_parameter(
|
||||
"input_global_scale",
|
||||
torch.nn.Parameter(input_global_scale, requires_grad=False),
|
||||
)
|
||||
else:
|
||||
layer.input_global_scale = None
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name)
|
||||
)
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (
|
||||
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
|
||||
)
|
||||
self._transform_param(
|
||||
layer,
|
||||
self.w_zp_name,
|
||||
lambda x: marlin_zero_points(
|
||||
unpack_cols(
|
||||
x.t(),
|
||||
c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1],
|
||||
),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
),
|
||||
)
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.config
|
||||
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
|
||||
|
||||
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||
# None for marlin
|
||||
|
||||
return apply_gptq_marlin_linear(
|
||||
input=x,
|
||||
weight=w_q,
|
||||
weight_scale=w_s,
|
||||
weight_zp=w_zp, # type: ignore
|
||||
g_idx=w_gidx, # type: ignore
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=self.workspace,
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
is_k_full=self.is_k_full,
|
||||
input_global_scale=getattr(layer, "input_global_scale", None),
|
||||
bias=bias,
|
||||
input_dtype=c.act_type,
|
||||
)
|
||||
88
vllm/model_executor/kernels/linear/mixed_precision/xpu.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
|
||||
|
||||
|
||||
class XPUwNa16LinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return -1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_xpu():
|
||||
return False, "XPUwNa16 only supported on XPU"
|
||||
|
||||
if c.act_type != torch.bfloat16 and c.act_type != torch.float16:
|
||||
return False, "XPUwNa16 only supports BF16/FP16 activations"
|
||||
|
||||
if c.weight_type not in _XPUWNA16_SUPPORTED_QUANT_TYPES:
|
||||
return (
|
||||
False,
|
||||
f"Quant type ({c.weight_type}) not supported by "
|
||||
"XPUwNa16, supported types are: "
|
||||
f"{_XPUWNA16_SUPPORTED_QUANT_TYPES}",
|
||||
)
|
||||
if c.group_size != -1 and c.group_size % 32 != 0:
|
||||
return (
|
||||
False,
|
||||
f"Group size ({c.group_size}) not supported by "
|
||||
"XPUwNa16, supported group sizes are multiples of 32",
|
||||
)
|
||||
|
||||
if c.partition_weight_shape[0] % 32 != 0:
|
||||
return (
|
||||
False,
|
||||
f"Input size ({c.partition_weight_shape[0]}) not supported by "
|
||||
"XPUwNa16, supported sizes are multiples of 32",
|
||||
)
|
||||
|
||||
if c.partition_weight_shape[1] % 32 != 0:
|
||||
return (
|
||||
False,
|
||||
f"Output size ({c.partition_weight_shape[1]}) not supported by "
|
||||
"XPUWNA16, supported sizes are multiples of 32",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
layer.weight_scale.data = layer.weight_scale.t().contiguous()
|
||||
|
||||
if self.config.zero_points:
|
||||
layer.weight_zero_point.data = layer.weight_zero_point.t().contiguous()
|
||||
else:
|
||||
weight_zero_point = torch.tensor([8], dtype=torch.int8, device="xpu")
|
||||
layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
|
||||
if self.config.has_g_idx:
|
||||
layer.g_idx.data = layer.g_idx.t().contiguous()
|
||||
else:
|
||||
layer.g_idx = None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
out = torch.ops._xpu_C.int4_gemm_w4a16(
|
||||
reshaped_x,
|
||||
layer.weight_packed.t(),
|
||||
bias,
|
||||
layer.weight_scale,
|
||||
layer.weight_zero_point,
|
||||
self.config.group_size,
|
||||
layer.g_idx,
|
||||
)
|
||||
return out
|
||||
187
vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py
Normal file
@@ -0,0 +1,187 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
QuantKey,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScaledMMLinearLayerConfig:
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
|
||||
is_static_input_scheme: bool
|
||||
is_channelwise: bool
|
||||
input_symmetric: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
|
||||
weight_quant_key: QuantKey
|
||||
activation_quant_key: QuantKey
|
||||
out_dtype: torch.dtype | None
|
||||
|
||||
|
||||
_FP8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_scale_ub,
|
||||
]
|
||||
_Int8ParamsT = tuple[
|
||||
torch.Tensor, # weight
|
||||
torch.Tensor, # weight_scale
|
||||
torch.Tensor | None, # input_scale,
|
||||
torch.Tensor | None, # input_zp
|
||||
torch.Tensor | None, # azp_adj
|
||||
]
|
||||
|
||||
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
|
||||
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
|
||||
|
||||
|
||||
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
|
||||
assert self.can_implement(c)[0]
|
||||
assert self.is_supported()[0]
|
||||
self.config = c
|
||||
self.layer_param_names = layer_param_names
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
# return a covariant type in the subclass
|
||||
@abstractmethod
|
||||
def _get_layer_params(self, layer) -> _ParamsT:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FP8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
|
||||
):
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
act_scale_descriptor = c.activation_quant_key.scale
|
||||
self.quant_fp8 = QuantFP8(
|
||||
static=act_scale_descriptor.static,
|
||||
group_shape=act_scale_descriptor.group_shape,
|
||||
num_token_padding=self.get_output_padding(),
|
||||
)
|
||||
self.fp8_dtype = current_platform.fp8_dtype()
|
||||
super().__init__(c, layer_param_names)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
pass
|
||||
|
||||
def _get_layer_params(self, layer) -> _FP8ParamsT:
|
||||
w, w_s, x_s, x_s_ub = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, x_s, None),
|
||||
getattr(layer, x_s_ub, None),
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
fp8_dtype = self.fp8_dtype
|
||||
maybe_out_dtype = self.config.out_dtype
|
||||
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_fp8_quant supports both dynamic and static quant.
|
||||
# If dynamic, layer.input_scale is None and x_s computed from x.
|
||||
# If static, layer.input_scale is scalar and x_s is input_scale.
|
||||
# View input as 2D matrix for fp8 methods
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
output_shape = [*x.shape[:-1], w.shape[1]]
|
||||
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
|
||||
|
||||
# If input not quantized
|
||||
# TODO(luka) remove this path if not used anymore
|
||||
x_2d_q = x_2d
|
||||
if x.dtype != fp8_dtype:
|
||||
x_2d_q, x_s = self.quant_fp8(
|
||||
x_2d,
|
||||
x_s,
|
||||
x_s_ub,
|
||||
)
|
||||
return self.apply_scaled_mm(
|
||||
A=x_2d_q,
|
||||
B=w,
|
||||
out_dtype=out_dtype,
|
||||
As=x_s,
|
||||
Bs=w_s,
|
||||
bias=bias,
|
||||
output_shape=output_shape,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_output_padding(self) -> int | None:
|
||||
return None
|
||||
|
||||
|
||||
class Int8ScaledMMLinearKernel(
|
||||
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
|
||||
):
|
||||
def _get_layer_params(self, layer) -> _Int8ParamsT:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
|
||||
return (
|
||||
getattr(layer, w_q),
|
||||
getattr(layer, w_s),
|
||||
getattr(layer, i_s, None),
|
||||
getattr(layer, i_zp, None),
|
||||
getattr(layer, azp_adj, None),
|
||||
)
|
||||
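A minimal sketch (not part of this commit) of how a caller could use the
classmethod checks defined above to pick a kernel. The helper name
`choose_scaled_mm_kernel`, the candidate list, and the example config values
are illustrative assumptions only:

def choose_scaled_mm_kernel(candidates, config, compute_capability=None):
    # Walk a preference-ordered list and return the first kernel class
    # whose platform check and config check both pass.
    failures = []
    for kernel_cls in candidates:
        ok, reason = kernel_cls.is_supported(compute_capability)
        if not ok:
            failures.append(f"{kernel_cls.__name__}: {reason}")
            continue
        ok, reason = kernel_cls.can_implement(config)
        if not ok:
            failures.append(f"{kernel_cls.__name__}: {reason}")
            continue
        return kernel_cls
    raise ValueError("no usable ScaledMM kernel: " + "; ".join(failures))

# Example (hypothetical values):
#   cfg = Int8ScaledMMLinearLayerConfig(
#       is_static_input_scheme=True, is_channelwise=False, input_symmetric=True
#   )
#   kernel_cls = choose_scaled_mm_kernel([CutlassInt8ScaledMMLinearKernel], cfg)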
54
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
Normal file
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
    AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
    CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
    CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
    Int8ScaledMMLinearKernel,
    Int8ScaledMMLinearLayerConfig,
    ScaledMMLinearKernel,
    ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
    TritonInt8ScaledMMLinearKernel,
)

__all__ = [
    "FP8ScaledMMLinearKernel",
    "FP8ScaledMMLinearLayerConfig",
    "Int8ScaledMMLinearKernel",
    "Int8ScaledMMLinearLayerConfig",
    "ScaledMMLinearKernel",
    "ScaledMMLinearLayerConfig",
    "AiterInt8ScaledMMLinearKernel",
    "CPUInt8ScaledMMLinearKernel",
    "CutlassFP8ScaledMMLinearKernel",
    "CutlassInt8ScaledMMLinearKernel",
    "FlashInferFP8ScaledMMLinearKernel",
    "ChannelWiseTorchFP8ScaledMMLinearKernel",
    "PerTensorTorchFP8ScaledMMLinearKernel",
    "RowWiseTorchFP8ScaledMMLinearKernel",
    "ROCmFP8ScaledMMLinearKernel",
    "TritonInt8ScaledMMLinearKernel",
]
109
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
|
||||
|
||||
|
||||
class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return False, "Requires ROCm."
|
||||
|
||||
if compute_capability is not None and compute_capability < 90:
|
||||
return False, "requires compute capability 90 and above."
|
||||
|
||||
try:
|
||||
import aiter # noqa: F401 # deliberately attempt to import aiter
|
||||
except Exception:
|
||||
return False, "requires `aiter` to be installed."
|
||||
|
||||
if not rocm_aiter_ops.is_linear_enabled():
|
||||
return (
|
||||
False,
|
||||
"requires setting `VLLM_ROCM_USE_AITER=1` "
|
||||
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
|
||||
"`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
|
||||
)
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return False, "supports symmetric quantization only."
|
||||
return True, None
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
`AiterInt8ScaledMMLinearKernel` implements a fused version of
|
||||
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
|
||||
where scale_a * a and scale_b * b are implemented using numpy-style
|
||||
broadcasting.
|
||||
Currently only supports per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through the AITER
w8a8 scaled GEMM. `AiterInt8ScaledMMLinearKernel` also does not support
AITER block-scaled GEMM and mixed-precision GEMM.
|
||||
"""
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
assert symmetric, (
|
||||
"AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
|
||||
)
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
|
||||
|
||||
assert x_zp is None, (
|
||||
"AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
|
||||
)
|
||||
out_dtype = x.dtype
|
||||
|
||||
assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
|
||||
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
|
||||
assert bias is None or (bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype)
|
||||
|
||||
m = x_q.shape[0] # a
|
||||
n = w_q.shape[1] # b
|
||||
|
||||
per_tensor_scale_a = x_s.numel() == 1
|
||||
per_tensor_scale_b = w_s.numel() == 1
|
||||
per_token_scale_a = x_s.numel() == m
|
||||
per_channel_scale_b = w_s.numel() == n
|
||||
|
||||
# @TODO:
|
||||
# Maybe broadcast the per-tensor-scale into per-channel-scale
|
||||
# if one of the scale is a per-channel-scale.
|
||||
# For now, it only supports:
|
||||
# - per-tensor-per-tensor a8w8 scaled GEMM, and
|
||||
# - per-token-per-channel a8w8 scaled GEMM
|
||||
assert (per_tensor_scale_a and per_tensor_scale_b) or (
|
||||
per_token_scale_a and per_channel_scale_b
|
||||
), (
|
||||
"Currently only support per-tensor-per-tensor GEMM "
|
||||
" and per-token-per-channel GEMM through AITER"
|
||||
" w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` "
|
||||
"does not support AITER block scaled GEMM."
|
||||
)
|
||||
|
||||
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
|
||||
# a to be [M, K]
|
||||
# b to be [N, K]
|
||||
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
|
||||
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
|
||||
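A small illustration (assumed shapes, not from the commit) of the two scale
layouts the docstring above allows; it spells out, in plain PyTorch, the math
the fused AITER GEMM computes:

import torch

M, K, N = 4, 16, 8
x_q = torch.randint(-128, 127, (M, K), dtype=torch.int8)
w_q = torch.randint(-128, 127, (K, N), dtype=torch.int8)

# per-token activation scales (M,) and per-channel weight scales (N,)
x_s = torch.rand(M)
w_s = torch.rand(N)

# Reference result: broadcast the scales over rows / columns of the
# int8 product, matching `(scale_a * a) @ (scale_b * b)`.
ref = x_s[:, None] * w_s[None, :] * (x_q.float() @ w_q.float())

# The per-tensor-per-tensor case is the same with x_s and w_s as scalars.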
217
vllm/model_executor/kernels/linear/scaled_mm/cpu.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.platforms.interface import CpuArchEnum
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "requires CPU."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, _, _, _, _ = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
dtype = weight.dtype
|
||||
N, K = weight.size()
|
||||
if (
|
||||
current_platform.get_cpu_architecture() == CpuArchEnum.X86
|
||||
and envs.VLLM_CPU_SGL_KERNEL
|
||||
and self.config.input_symmetric
|
||||
and check_cpu_sgl_kernel(N, K, dtype)
|
||||
):
|
||||
self.linear_method = self._apply_weights_sgl
|
||||
self.process_weights_for_sgl(layer)
|
||||
else:
|
||||
self.linear_method = self._apply_weights_onednn
|
||||
self.process_weights_for_onednn(layer)
|
||||
|
||||
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
|
||||
# WEIGHT
|
||||
# Transpose to [K, N] for convenience
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(weight.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# oneDNN kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if self.config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
else:
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
azp = (
|
||||
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
|
||||
)
|
||||
replace_parameter(
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
# Different from cutlass, oneDNN kernels only need the AZP adjustment
|
||||
# term for dynamic quantization. And s_b should be folded into the
|
||||
# term. Such as:
|
||||
# s_a * s_b * [(A - zp_a)B] + bias =
|
||||
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
|
||||
# s_a * GEMM_output - s_a * zp_a * adj + bias
|
||||
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
|
||||
weight = getattr(layer, w_q_name)
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
|
||||
azp_adj = azp_adj * weight_scale.squeeze()
|
||||
setattr(
|
||||
layer,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
|
||||
weight = getattr(layer, w_q_name)
|
||||
self.dnnl_handler = ops.create_onednn_scaled_mm(
|
||||
weight,
|
||||
getattr(layer, w_s_name),
|
||||
torch.get_default_dtype(),
|
||||
getattr(layer, i_s_name) is None,
|
||||
not self.config.input_symmetric,
|
||||
32,
|
||||
)
|
||||
# weight is prepacked and maintained by the dnnl_handler,
|
||||
# release the original weight
|
||||
setattr(layer, w_q_name, None)
|
||||
del weight
|
||||
|
||||
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, w_s_name, _, _, _ = self.layer_param_names
|
||||
# WEIGHT
|
||||
weight = getattr(layer, w_q_name)
|
||||
packed_weight = torch.ops._C.convert_weight_packed(weight)
|
||||
replace_parameter(
|
||||
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
|
||||
)
|
||||
|
||||
if layer.bias is not None:
|
||||
bias = layer.bias
|
||||
layer.register_parameter(
|
||||
"bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# CPU SGL kernels only support per-channel.
|
||||
# For per-tensor quant, convert to the per-channel case.
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return self.linear_method(
|
||||
layer,
|
||||
x,
|
||||
bias,
|
||||
)
|
||||
|
||||
def _apply_weights_onednn(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
x_shape = x.shape
|
||||
x = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
|
||||
x, i_s, i_zp, self.config.input_symmetric
|
||||
)
|
||||
|
||||
m = x.size(0)
|
||||
n = self.dnnl_handler.n
|
||||
out = torch.empty((m, n), dtype=x.dtype)
|
||||
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
|
||||
out = out.reshape(x_shape[:-1] + (n,)) if len(x_shape) > 2 else out
|
||||
return out
|
||||
|
||||
def _apply_weights_sgl(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, _, _, _ = self._get_layer_params(layer)
|
||||
return torch.ops._C.int8_scaled_mm_with_quant(
|
||||
x,
|
||||
w_q,
|
||||
w_s,
|
||||
layer.bias_fp32 if bias is not None else None,
|
||||
x.dtype,
|
||||
True,
|
||||
)
|
||||
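Two quick numeric checks (illustrative values only, not part of the commit)
for the asymmetric-quantization handling in `process_weights_for_onednn`
above. First, the range reconstruction: with input_scale = [0.1, 0.2] and
azps = [-10, 5], range_max = max(0.1*137, 0.2*122) = 24.4 and
range_min = min(0.1*(-118), 0.2*(-133)) = -26.6, so the merged scale is
(24.4 + 26.6) / 255 = 0.2 and the new azp is round(-128 + 26.6 / 0.2) = 5.
Second, the AZP-adjustment algebra, written in plain PyTorch with `adj`
standing in for azp_adj (weight scale already folded in):

import torch

A = torch.randint(-128, 127, (4, 8)).float()    # quantized activations
B = torch.randint(-128, 127, (8, 6)).float()    # quantized weight, [K, N]
s_a, s_b, zp_a = 0.02, 0.05, 3.0

adj = B.sum(dim=0) * s_b                        # s_b folded into the term
lhs = s_a * s_b * ((A - zp_a) @ B)              # exact asymmetric GEMM
rhs = s_a * (s_b * (A @ B)) - s_a * zp_a * adj  # GEMM output + correction
assert torch.allclose(lhs, rhs, rtol=1e-4)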
176
vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
Int8ScaledMMLinearKernel,
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "requires CUDA."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
config = self.config
|
||||
# WEIGHT
|
||||
# Cutlass kernels need transposed weight.
|
||||
weight = getattr(layer, w_q_name)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
# torch.nn.Parameter(weight.t().data, requires_grad=False),
torch.nn.Parameter(
    weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data,
    requires_grad=False,
),
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Cutlass kernels support only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if config.is_static_input_scheme:
|
||||
input_scale = getattr(layer, i_s_name)
|
||||
|
||||
if config.input_symmetric:
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(input_scale.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
input_zero_point = getattr(layer, i_zp_name)
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
replace_parameter(
|
||||
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
|
||||
)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
replace_parameter(
|
||||
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
|
||||
)
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
|
||||
if not config.input_symmetric:
|
||||
weight = getattr(layer, w_q_name)
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if config.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = getattr(layer, i_zp_name) * azp_adj
|
||||
setattr(
|
||||
layer,
|
||||
azp_adj_name,
|
||||
torch.nn.Parameter(azp_adj, requires_grad=False),
|
||||
)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
|
||||
|
||||
# ops.scaled_int8_quant supports both dynamic and static quant:
|
||||
# * dynamic, i_s is None and x_s computed from x.
|
||||
# * static, i_s is scalar and x_s is i_s.
|
||||
symmetric = azp_adj is None
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=symmetric
|
||||
)
|
||||
|
||||
if x_zp is not None:
|
||||
# Currently, static is always per-tensor and dynamic is per-token
|
||||
static = i_zp is not None
|
||||
azp = None if static else x_zp
|
||||
return ops.cutlass_scaled_mm_azp(
|
||||
x_q,
|
||||
w_q,
|
||||
scale_a=x_s,
|
||||
scale_b=w_s,
|
||||
out_dtype=x.dtype,
|
||||
azp_adj=azp_adj,
|
||||
azp=azp,
|
||||
bias=bias,
|
||||
)
|
||||
return ops.cutlass_scaled_mm(
    # x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
    x_q,
    w_q,
    scale_a=x_s,
    scale_b=w_s,
    out_dtype=x.dtype,
    bias=bias,
    format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN",
)
|
||||
|
||||
|
||||
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "requires CUDA."
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
return True, None
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Fused GEMM_DQ
|
||||
output = ops.cutlass_scaled_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
return output.view(*output_shape)
|
||||
57
vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cuda():
|
||||
return False, "requires CUDA."
|
||||
|
||||
if not has_flashinfer():
|
||||
return False, "requires FlashInfer to be installed."
|
||||
|
||||
if compute_capability is not None and compute_capability < 100:
|
||||
return False, "requires compute capability 100 and above."
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return False, "requires per tensor activation and weight scales."
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
return flashinfer_scaled_fp8_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
217
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
"""
|
||||
Base class for FP8 linear kernels using Torch.
|
||||
Each subclass represents a kernel variant for
|
||||
specific device capabilities and torch versions.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
|
||||
return False, "requires ROCm, CUDA or CPU."
|
||||
|
||||
if compute_capability is not None and compute_capability < 89:
|
||||
return False, "requires compute capability 89 and above."
|
||||
|
||||
return True, None
|
||||
|
||||
def get_output_padding(self) -> int | None:
|
||||
# Note: we pad the input because torch._scaled_mm is more performant
|
||||
# for matrices with batch dimension > 16.
|
||||
# This could change in the future.
|
||||
# We also don't pad when using torch.compile,
|
||||
# as it breaks with dynamic shapes.
|
||||
#
|
||||
# The perf gain is still relevant as of 16/1/2026
|
||||
# torch version == 2.9.0. More details in the link below:
|
||||
# https://github.com/vllm-project/vllm/issues/32269
|
||||
vllm_config = get_current_vllm_config().compilation_config
|
||||
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
|
||||
return 17 if pad_output else None
|
||||
|
||||
|
||||
class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return False, "requires per tensor activation and weight scales."
|
||||
return True, None
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
output = torch._scaled_mm(
|
||||
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return False, "requires ROCm."
|
||||
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
if not on_mi3xx():
|
||||
return False, "requires MI3xx."
|
||||
|
||||
if compute_capability is not None and compute_capability < 94:
|
||||
return False, "requires compute capability 94 and above."
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if c.out_dtype == torch.float16:
|
||||
# hipblaslt rowwise _scaled_mm only supports BFloat16
|
||||
return False, "supports BFloat16 output data type only."
|
||||
|
||||
if per_tensor_activation_scales or per_tensor_weight_scales:
|
||||
return False, "cannot be used with per tensor activation and weight scales."
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Note:
|
||||
# For now it has only been validated on ROCm platform.
|
||||
# fp8 rowwise scaling in torch._scaled_mm is introduced in
|
||||
# https://github.com/pytorch/pytorch/pull/144432 using
|
||||
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
|
||||
#
|
||||
# For CUDA platform please validate if the torch._scaled_mm supports
|
||||
# rowwise scaled GEMM before using it
|
||||
|
||||
# Fused GEMM_DQ Rowwise GEMM
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs.t(),
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
|
||||
|
||||
|
||||
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if per_tensor_activation_scales and per_tensor_weight_scales:
|
||||
return False, "cannot be used with per tensor activation and weight scales."
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
# Use unfused DQ due to limitations with scaled_mm
|
||||
|
||||
# Symmetric quantized GEMM by definition computes the following:
|
||||
# C = (s_x * X) (s_w * W) + bias
|
||||
# This is equivalent to dequantizing the weights and activations
|
||||
# before applying a GEMM.
|
||||
#
|
||||
# In order to compute quantized operands, a quantized kernel
|
||||
# will rewrite the above like so:
|
||||
# C = s_w * s_x * (X * W) + bias
|
||||
#
|
||||
# For the scaled_mm fallback case, we break this down, since it
|
||||
# does not support s_w being a vector.
|
||||
|
||||
# Input scaling factors are no longer optional in _scaled_mm starting
|
||||
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
|
||||
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
|
||||
|
||||
# GEMM
|
||||
# This computes C = (X * W).
|
||||
# Output in fp32 to allow subsequent ops to happen in-place
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
scale_a=dummy_tensor,
|
||||
scale_b=dummy_tensor,
|
||||
out_dtype=torch.float32,
|
||||
)
|
||||
# A fix for discrepancy in scaled_mm which returns tuple
|
||||
# for torch < 2.5 and a single value in torch >= 2.5
|
||||
if type(output) is tuple and len(output) == 2:
|
||||
output = output[0]
|
||||
|
||||
# Unpad (undo num_token_padding)
|
||||
output = torch.narrow(output, 0, 0, output_shape[0])
|
||||
x_scale = torch.narrow(As, 0, 0, output_shape[0])
|
||||
|
||||
# DQ
|
||||
# C = sw * sx * (X * W) + bias
|
||||
output = output * x_scale * Bs.t()
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output.to(out_dtype).view(*output_shape)
|
||||
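A short check (assumed shapes, illustrative only) of the rewrite used by
`ChannelWiseTorchFP8ScaledMMLinearKernel.apply_scaled_mm` above: scaling the
operands before the GEMM is the same as scaling the unscaled GEMM output per
row and per column afterwards:

import torch

M, K, N = 4, 8, 6
X, W = torch.randn(M, K), torch.randn(K, N)
s_x = torch.rand(M, 1)          # per-token activation scales
s_w = torch.rand(1, N)          # per-channel weight scales
bias = torch.randn(N)

fused = (s_x * X) @ (s_w * W) + bias        # dequantize, then GEMM
unfused = (X @ W) * s_x * s_w + bias        # GEMM, then scale (as above)
assert torch.allclose(fused, unfused, atol=1e-5)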
118
vllm/model_executor/kernels/linear/scaled_mm/rocm.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.platform_utils import num_compute_units
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
if (
|
||||
A.shape[0] <= 4
|
||||
and B.shape[0] % 16 == 0 # M TODO: needed?
|
||||
and B.shape[1] % 16 == 0 # K
|
||||
and ((bias is None) or (bias.dtype == out_dtype))
|
||||
):
|
||||
output = ops.wvSplitKQ(
|
||||
B.t(),
|
||||
A,
|
||||
out_dtype,
|
||||
As,
|
||||
Bs,
|
||||
num_compute_units(),
|
||||
bias,
|
||||
)
|
||||
# Fallback
|
||||
else:
|
||||
output = torch._scaled_mm(
|
||||
A,
|
||||
B,
|
||||
out_dtype=out_dtype,
|
||||
scale_a=As,
|
||||
scale_b=Bs,
|
||||
bias=bias,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
|
||||
|
||||
|
||||
if current_platform.is_rocm():
|
||||
direct_register_custom_op(
|
||||
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
|
||||
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
|
||||
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
|
||||
)
|
||||
|
||||
|
||||
class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_rocm():
|
||||
return False, "requires ROCm."
|
||||
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
if not on_mi3xx():
|
||||
return False, "requires MI3xx."
|
||||
|
||||
if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
|
||||
return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
per_tensor_activation_scales = (
|
||||
c.activation_quant_key.scale.group_shape.is_per_tensor()
|
||||
)
|
||||
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
|
||||
|
||||
if not (per_tensor_activation_scales and per_tensor_weight_scales):
|
||||
return False, "requires per tensor activation and weight scales."
|
||||
|
||||
return True, None
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
|
||||
A, B, out_dtype, As, Bs, bias
|
||||
)
|
||||
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
|
||||
93
vllm/model_executor/kernels/linear/scaled_mm/triton.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
|
||||
triton_scaled_mm,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
convert_to_channelwise,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .cutlass import CutlassInt8ScaledMMLinearKernel
|
||||
from .ScaledMMLinearKernel import (
|
||||
Int8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if current_platform.is_cuda_alike():
|
||||
return True, None
|
||||
return False, "requires ROCm or CUDA."
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not c.input_symmetric:
|
||||
return False, "supports symmetric input only."
|
||||
return True, None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
w_q, _, i_s, _, _ = self._get_layer_params(layer)
|
||||
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
|
||||
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_q_name,
|
||||
torch.nn.Parameter(w_q.t().data, requires_grad=False),
|
||||
)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# Triton kernel supports only per-tensor and per-channel.
|
||||
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
|
||||
# scales being passed to the kernel), convert to the per-channel case.
|
||||
is_fused_module = len(layer.logical_widths) > 1
|
||||
weight_scale = getattr(layer, w_s_name)
|
||||
if is_fused_module and not self.config.is_channelwise:
|
||||
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
|
||||
replace_parameter(
|
||||
layer,
|
||||
w_s_name,
|
||||
torch.nn.Parameter(weight_scale.data, requires_grad=False),
|
||||
)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.config.is_static_input_scheme:
|
||||
assert i_s is not None
|
||||
replace_parameter(
|
||||
layer,
|
||||
i_s_name,
|
||||
torch.nn.Parameter(i_s.max(), requires_grad=False),
|
||||
)
|
||||
setattr(layer, i_zp_name, None)
|
||||
else:
|
||||
setattr(layer, i_s_name, None)
|
||||
setattr(layer, i_zp_name, None)
|
||||
|
||||
setattr(layer, azp_adj_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)
|
||||
|
||||
x_q, x_s, x_zp = ops.scaled_int8_quant(
|
||||
x.contiguous(), i_s, i_zp, symmetric=True
|
||||
)
|
||||
|
||||
assert x_zp is None, "Triton kernel only supports symmetric quantization"
|
||||
|
||||
return triton_scaled_mm(
|
||||
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
59
vllm/model_executor/kernels/linear/scaled_mm/xpu.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.kernels.linear import ( # noqa: E501
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_xpu():
|
||||
return False, "XPUFP8ScaledMM only support on XPU"
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
|
||||
return False, "XPUFP8ScaledMM only support FP8 weight dtype"
|
||||
return True, None
|
||||
|
||||
def __init__(
|
||||
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
|
||||
) -> None:
|
||||
assert self.can_implement(c)[0]
|
||||
assert self.is_supported()[0]
|
||||
self.config = c
|
||||
self.layer_param_names = layer_param_names
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale
|
||||
return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
|
||||
|
||||
def apply_scaled_mm(
|
||||
self,
|
||||
*,
|
||||
A: torch.Tensor,
|
||||
B: torch.Tensor,
|
||||
out_dtype: torch.dtype,
|
||||
As: torch.Tensor,
|
||||
Bs: torch.Tensor,
|
||||
bias: torch.Tensor | None,
|
||||
output_shape: list,
|
||||
) -> torch.Tensor:
|
||||
pass
|
||||