Add minimal vLLM 0.16.1 build repo for BI-V150

This commit is contained in:
2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.parameter import BasevLLMParameter, PackedvLLMParameter
__all__ = [
"BasevLLMParameter",
"PackedvLLMParameter",
]

View File

@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import inspect
import torch
import torch.nn as nn
from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
class PluggableLayer(nn.Module):
"""
Base class for pluggable layers.
A PluggableLayer is a *module-composing* abstraction: it may instantiate other
``torch.nn.Module`` objects as sub-layers, and its functionality depends on
these sub-layers following a generalized invocation sequence. Also, it is stateful
and may hold parameters or buffers.
Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
of the entire layer class at instantiation time, allowing customized
initialization and submodule composition.
"""
def __new__(cls, *args, **kwargs):
try:
layer_class_name = cls.__name__
except AttributeError:
raise TypeError(
f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
f"was not set, possibly because it was not decorated with "
f"@PluggableLayer.register, or it's the PluggableLayer itself."
) from None
if layer_class_name not in op_registry_oot:
layer_cls_to_instantiate = cls
else:
layer_cls_to_instantiate = op_registry_oot[layer_class_name]
logger.debug(
"Instantiating pluggable layer: %s using %s",
layer_class_name,
str(layer_cls_to_instantiate),
)
return super().__new__(layer_cls_to_instantiate)
# Decorator to register pluggable layers.
@classmethod
def register(cls, name: str):
def decorator(op_cls):
assert name not in op_registry, f"Duplicate op name: {name}"
op_cls.name = name
op_registry[name] = op_cls
return op_cls
return decorator
# Decorator to register out-of-tree(oot) pluggable layers.
# For OOT pluggable layers:
# if in-tree layer class is registered with an oot_custom_layer,
# the oot_custom_layer will be used instead.
@classmethod
def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
def decorator(layer_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
layer_cls.name = reg_name
op_registry_oot[reg_name] = layer_cls
return layer_cls
if _decorated_layer_cls is None:
# Called with parentheses: @PluggableLayer.register_oot()
# or @PluggableLayer.register_oot(name="...")
return decorator
elif isinstance(_decorated_layer_cls, type): # Check if it's a class
# Called without parentheses: @PluggableLayer.register_oot
return decorator(_decorated_layer_cls)
else:
raise TypeError("Decorator can only be applied to classes.")
class CustomOp(nn.Module):
"""
Base class for custom ops.
Dispatches the forward method to the appropriate backend.
"""
def __new__(cls, *args, **kwargs):
try:
op_name = cls.__name__
except AttributeError:
raise TypeError(
f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
f"was not set, possibly because it was not decorated with "
f"@CustomOp.register, or it's the CustomOp base class itself."
) from None
if op_name not in op_registry_oot:
op_cls_to_instantiate = cls
else:
op_cls_to_instantiate = op_registry_oot[op_name]
logger.debug(
"Instantiating custom op: %s using %s",
op_name,
str(op_cls_to_instantiate),
)
return super().__new__(op_cls_to_instantiate)
def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
super().__init__()
self._enforce_enable = enforce_enable
self._forward_method = self.dispatch_forward(compile_native=compile_native)
def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)
def forward_native(self, *args, **kwargs):
"""PyTorch-native implementation of the forward method.
This method is optional. If implemented, it can be used with compilers
such as torch.compile or PyTorch XLA. Also, it can be used for testing
purposes.
"""
raise NotImplementedError
def forward_cuda(self, *args, **kwargs):
raise NotImplementedError
def forward_hip(self, *args, **kwargs):
# By default, we assume that HIP ops are compatible with CUDA ops.
return self.forward_cuda(*args, **kwargs)
def forward_xpu(self, *args, **kwargs):
# By default, we assume that XPU ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_cpu(self, *args, **kwargs):
# By default, we assume that CPU ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def forward_tpu(self, *args, **kwargs):
# By default, we assume that TPU ops are compatible with the
# PyTorch-native implementation.
# NOTE(woosuk): This is a placeholder for future extensions.
return self.forward_native(*args, **kwargs)
def forward_oot(self, *args, **kwargs):
# By default, we assume that OOT ops are compatible with the
# PyTorch-native implementation.
return self.forward_native(*args, **kwargs)
def dispatch_forward(self, compile_native: bool):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_cached_compilation_config()
# NOTE(shen-shanshan): CustomOp object can be enforce enabled, e.g.,
# enable device-specific kernels in ViT models when enabling graph
# mode. By default, it will follow the compilation_config to determine
# whether enable itself.
# This enforce_enable mechanism will be removed after we adding a
# separate compilation_config for multi-modal part.
enabled = self._enforce_enable or self.enabled()
if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name])
else:
compilation_config.disabled_custom_ops.update([self.__class__.name])
if not enabled:
# Compile forward_native to avoid eager torch ops if inside
# opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
return self.maybe_compile(self.forward_native, enable=compile_native)
if current_platform.is_rocm():
return self.forward_hip
elif current_platform.is_cpu():
return self.forward_cpu
elif current_platform.is_tpu():
return self.forward_tpu
elif current_platform.is_xpu():
return self.forward_xpu
elif current_platform.is_out_of_tree():
return self.forward_oot
else:
return self.forward_cuda
def maybe_compile(self, fn, *, enable: bool = True):
"""
Compile fn if compilation enabled.
Useful for CustomOp instances called from within a torch custom op,
meaning the forward call is hidden from the model-level torch.compile.
NOTE: this does not enable fusion across ops, so opaque custom ops
should still be unwrapped wherever possible.
"""
from vllm.config.compilation import CompilationMode
# Do not compile if compilation disabled
if not enable:
return fn
# Do not compile if global compilation disabled
compilation_config = get_cached_compilation_config()
if compilation_config.mode == CompilationMode.NONE:
return fn
# If eager backend is used, do not compile either
if compilation_config.backend == "eager":
return fn
compile_options = maybe_disable_graph_partition(
current_platform.simple_compile_backend
)
backend = current_platform.simple_compile_backend
dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
if dynamic_arg_dims is not None:
compiled_fn = torch.compile(
fn,
dynamic=False,
backend=backend,
options=compile_options,
)
sig = inspect.signature(fn)
@functools.wraps(fn)
def wrapper(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
for name, dims in dynamic_arg_dims.items():
arg = bound.arguments.get(name)
if arg is not None and isinstance(arg, torch.Tensor):
dims_list = [dims] if isinstance(dims, int) else dims
for d in dims_list:
real_d = arg.ndim + d if d < 0 else d
torch._dynamo.mark_dynamic(arg, real_d)
return compiled_fn(*args, **kwargs)
return wrapper
# dynamic=True to avoid recompilations
return torch.compile(
fn,
dynamic=True,
backend=backend,
options=compile_options,
)
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
compilation_config = get_cached_compilation_config()
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
"Custom op %s was not registered, which means it won't appear "
"in the op registry. It will be enabled/disabled based on the "
"global settings.",
cls.__name__,
)
return CustomOp.default_on()
enabled = f"+{cls.name}" in custom_ops
disabled = f"-{cls.name}" in custom_ops
assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"
return (CustomOp.default_on() or enabled) and not disabled
@staticmethod
def default_on() -> bool:
"""
Behavior controlled by `CompilationConfig.custom_ops`: On by default if
'all', off by default if 'none'.
When PyTorch Inductor is used, 'none' is the default value,
otherwise 'all'.
"""
compilation_config = get_cached_compilation_config()
count_none = compilation_config.custom_ops.count("none")
count_all = compilation_config.custom_ops.count("all")
assert count_none + count_all == 1
return not count_none > 0 or count_all > 0
# Decorator to register custom ops.
@classmethod
def register(
cls,
name: str,
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
):
def decorator(op_cls):
assert name not in op_registry, f"Duplicate op name: {name}"
op_cls.name = name
op_cls._dynamic_arg_dims = dynamic_arg_dims
op_registry[name] = op_cls
return op_cls
return decorator
# Decorator to register out-of-tree(oot) custom ops.
# For OOT custom ops:
# if in-tree layer class is registered with an oot_custom_op layer,
# the oot_custom_op layer will be used instead.
# Example:
# - @UnquantizedFusedMoEMethod.register_oot
# class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
# or
# - @CustomOP.register_oot(name="UnquantizedFusedMoEMethod")
@classmethod
def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
def decorator(op_cls):
reg_name = name if name is not None else cls.__name__
assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
op_cls.name = reg_name
op_registry_oot[reg_name] = op_cls
return op_cls
if _decorated_op_cls is None:
# Called with parentheses: @CustomOP.register_oot()
# or @CustomOP.register_oot(name="...")
# So, _decorated_op_cls is None.
# We return the actual decorator function.
return decorator
elif isinstance(_decorated_op_cls, type): # Check if it's a class
# Called without parentheses: @CustomOP.register_oot
# The first argument is the class itself.
# We call the 'decorator' function immediately with the class.
return decorator(_decorated_op_cls)
else:
# Handle other unexpected cases if necessary
raise TypeError("Decorator can only be applied to classes.")

View File

View File

@@ -0,0 +1,397 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This module re-exports linear kernel implementations to provide a
stable import interface during an ongoing reorganization. Upcoming
PRs will remove the scaled_mm and mixed_precision subdirectories
and reorganize kernels by provider (aiter, cutlass, flashinfer, etc.)
rather than by precision type. By centralizing exports here, we
minimize the need to update imports across other modules when the
internal structure changes. If you are adding a new kernel selector
or kernel implementation, add it to this __init__.py to maintain
import stability.
"""
import os
from typing import TypeVar
import torch
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.mixed_precision import (
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
XPUFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform
logger = init_logger(__name__)
# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
PlatformEnum.CUDA: [
CutlassInt8ScaledMMLinearKernel,
TritonInt8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel],
}
# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [
FlashInferFP8ScaledMMLinearKernel,
CutlassFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
ROCmFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.CPU: [
PerTensorTorchFP8ScaledMMLinearKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
],
PlatformEnum.XPU: [
XPUFP8ScaledMMLinearKernel,
],
}
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
CutlassW4A8LinearKernel,
MacheteLinearKernel,
AllSparkLinearKernel,
MarlinLinearKernel,
ConchLinearKernel,
ExllamaLinearKernel,
],
PlatformEnum.ROCM: [
ConchLinearKernel,
ExllamaLinearKernel,
],
PlatformEnum.XPU: [
XPUwNa16LinearKernel,
],
PlatformEnum.CPU: [
Dynamic4bitLinearKernel,
CPUWNA16LinearKernel,
],
}
_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)
def is_supported_and_can_implement_kernel(
kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
) -> tuple[bool, str]:
# TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
return False, f" {kernel.__name__} is disabled by environment variable"
if compute_capability is None:
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]
is_supported, failure_reason = kernel.is_supported(compute_capability)
if not is_supported:
return False, f"{kernel.__name__} {failure_reason}."
can_implement, failure_reason = kernel.can_implement(config)
if not can_implement:
return (
False,
f"{kernel.__name__} {failure_reason}.",
)
return True, ""
def choose_scaled_mm_linear_kernel(
config: _KernelConfigT,
possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
compute_capability: int | None = None,
force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
"""
Choose a _KernelT that can implement the given config for the
given compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (_KernelConfigT): Description of the linear layer
to be implemented.
possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
dictionary of platforms and their list of possible kernels.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the
compute capability. Defaults to None.
force_kernel (Optional[type[_KernelT]]): An Optional forced kernel to override
the possible_kernels if it can be implemented. If None, it will only try the
possible kernels.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
_KernelT: Chosen kernel.
"""
failure_reason_list = []
if force_kernel is not None:
can_implement, failure_reason = is_supported_and_can_implement_kernel(
force_kernel, config, compute_capability
)
if can_implement:
return force_kernel
logger.info_once(
"Tried to force %s, but the kernel couldn't be implemented",
force_kernel.__name__,
scope="global",
)
for kernel in possible_kernels[current_platform._enum]:
is_supported_and_can_implement, failure_reason = (
is_supported_and_can_implement_kernel(kernel, config, compute_capability)
)
if is_supported_and_can_implement:
return kernel
failure_reason_list.append(failure_reason)
raise ValueError(
"Failed to find a kernel that can implement the "
"ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
)
def init_fp8_linear_kernel(
activation_quant_key: QuantKey,
weight_quant_key: QuantKey,
out_dtype: torch.dtype,
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
out_dtype=out_dtype,
)
kernel_type = choose_scaled_mm_linear_kernel(
scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
scaled_mm_linear_kernel_config,
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
)
def init_int8_linear_kernel(
is_channelwise: bool,
is_static_input_scheme: bool,
input_symmetric: bool,
module_name: str,
) -> Int8ScaledMMLinearKernel:
config = Int8ScaledMMLinearLayerConfig(
is_channelwise=is_channelwise,
is_static_input_scheme=is_static_input_scheme,
input_symmetric=input_symmetric,
)
kernel_type = choose_scaled_mm_linear_kernel(
config,
_POSSIBLE_INT8_KERNELS,
)
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
config,
layer_param_names=[
"weight",
"weight_scale",
"input_scale",
"input_zero_point",
"azp_adj",
],
)
def choose_mp_linear_kernel(
config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get
the compute capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
if _cc is not None:
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = []
for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f" {kernel.__name__} disabled by environment variable"
)
continue
if (
compute_capability is not None
and kernel.get_min_capability() > compute_capability
):
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute "
f" capability is {compute_capability}"
)
continue
can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f" {kernel.__name__} cannot implement due to: {failure_reason}"
)
raise ValueError(
"Failed to find a kernel that can implement the "
"WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
)
__all__ = [
"init_fp8_linear_kernel",
"init_int8_linear_kernel",
"choose_mp_linear_kernel",
"FP8ScaledMMLinearKernel",
"Int8ScaledMMLinearKernel",
"ScaledMMLinearKernel",
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearLayerConfig",
"AiterInt8ScaledMMLinearKernel",
"CPUInt8ScaledMMLinearKernel",
"CutlassFP8ScaledMMLinearKernel",
"CutlassInt8ScaledMMLinearKernel",
"FlashInferFP8ScaledMMLinearKernel",
"ChannelWiseTorchFP8ScaledMMLinearKernel",
"PerTensorTorchFP8ScaledMMLinearKernel",
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
"MPLinearKernel",
"MPLinearLayerConfig",
"AllSparkLinearKernel",
"ConchLinearKernel",
"CPUWNA16LinearKernel",
"CutlassW4A8LinearKernel",
"Dynamic4bitLinearKernel",
"ExllamaLinearKernel",
"MacheteLinearKernel",
"MarlinLinearKernel",
"XPUwNa16LinearKernel",
]

View File

@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
full_weight_shape: tuple[int, int] # [in, out]
partition_weight_shape: tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
zero_points: bool
has_g_idx: bool
out_type: torch.dtype | None = None
class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(
self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: str | None = None,
w_gidx_param_name: str | None = None,
) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
if c.zero_points:
assert w_zp_param_name is not None
if c.has_g_idx:
assert w_gidx_param_name is not None
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
def _transform_param(
self, layer: torch.nn.Module, name: str | None, fn: Callable
) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
)
def _get_weight_params(
self, layer: torch.nn.Module
) -> tuple[
torch.Tensor, # w_q
torch.Tensor, # w_s
torch.Tensor | None, # w_zp,
torch.Tensor | None, # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.w_zp_name or "", None),
getattr(layer, self.w_gidx_name or "", None),
)

View File

@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
MPLinearKernel,
MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUwNa16LinearKernel,
)
__all__ = [
"MPLinearKernel",
"MPLinearLayerConfig",
"AllSparkLinearKernel",
"ConchLinearKernel",
"CPUWNA16LinearKernel",
"CutlassW4A8LinearKernel",
"Dynamic4bitLinearKernel",
"ExllamaLinearKernel",
"MacheteLinearKernel",
"MarlinLinearKernel",
"XPUwNa16LinearKernel",
]

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.utils.platform_utils import num_compute_units
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class AllSparkLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if c.has_g_idx:
return False, "Act reordering currently not supported by AllSpark"
if c.zero_points:
return False, "Zero points currently not supported by AllSpark"
return check_allspark_supported_dtype_shape(
c.partition_weight_shape[0], # in_features
c.partition_weight_shape[1], # out_features
c.group_size,
c.weight_type,
c.act_type,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
# prepare the parameters required for the kernel
properties = torch.cuda.get_device_properties(device.index)
sm_count = num_compute_units(device.index)
sm_version = properties.major * 10 + properties.minor
gemm_args = {}
gemm_args["sm_count"] = sm_count
gemm_args["sm_version"] = sm_version
self.gemm_args = gemm_args
# transform param weight, scale
old_weight_param = getattr(layer, self.w_q_name)
old_scale_param = getattr(layer, self.w_s_name)
assert isinstance(old_weight_param, BasevLLMParameter)
permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)
assert isinstance(old_scale_param, BasevLLMParameter)
permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)
# unpack weight from K / 4 x N int32 to K x N uint8
new_weight_param = torch.nn.Parameter(
old_weight_param.data, requires_grad=False
)
new_weight_param.data = (
new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
)
new_weight_param.data = new_weight_param.data.t().contiguous()
new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)
# reorder K x N weight as N32K16 format for Ampere W8A16
new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
new_weight_param.data, new_scale_param.data, None, c.zero_points
)
replace_parameter(layer, self.w_q_name, new_weight_param.data)
replace_parameter(layer, self.w_s_name, new_scale_param.data)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
gemm_args = self.gemm_args
w_q, w_s, _, _ = self._get_weight_params(layer)
reshaped_x = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
output = ops.allspark_w8a16_gemm(
a=reshaped_x,
b_qweight=w_q,
b_scales=w_s,
b_qzeros=None,
n=c.partition_weight_shape[1],
group_size=c.group_size,
sm_count=gemm_args["sm_count"],
sm_version=gemm_args["sm_version"],
CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
has_zp=c.zero_points,
n32k16_reorder=True,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from importlib.util import find_spec
from typing import Final
import torch
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
scalar_types.uint4,
scalar_types.uint8,
scalar_types.uint4b8,
scalar_types.uint8b128,
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]
class ConchLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
error_msg = (
f"Weight type ({c.weight_type}) not supported by "
"ConchLinearKernel, supported types are: "
f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
)
return False, error_msg
if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
error_msg = (
f"Group size ({c.group_size}) not supported by "
"ConchLinearKernel, supported group sizes are: "
f"{_CONCH_SUPPORTED_GROUP_SIZES}"
)
return False, error_msg
if find_spec("conch") is None:
error_msg = (
"conch-triton-kernels is not installed, please "
"install it via `pip install conch-triton-kernels` "
"and try again!"
)
return False, error_msg
return True, None
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zero_point` is: {input_dim = 1, output_dim = 0, packed_dim = 0}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = x.data.contiguous()
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x
def transform_w_zp(x):
# Zero points are stored PACKED as [N//pack_factor, K//G]
# The Conch kernel expects UNPACKED zeros: [K//G, N]
# We need to unpack and reorder
assert isinstance(x, BasevLLMParameter)
packed = x.data # shape: [N//pack_factor, K//G], dtype: int32
# Determine packing based on weight bit width
size_bits = self.config.weight_type.size_bits
pack_factor = 32 // size_bits # 8 for 4-bit, 4 for 8-bit
mask = (1 << size_bits) - 1 # 0xF for 4-bit, 0xFF for 8-bit
n_packed, k_groups = packed.shape
n_full = n_packed * pack_factor
# Unpack using vectorized bitwise ops
# shifts = [0, size_bits, 2*size_bits, ...] for each packed position
shifts = torch.arange(
0, 32, size_bits, dtype=torch.int32, device=packed.device
)
# packed: [N//pack_factor, K//G] -> [N//pack_factor, K//G, 1]
# shifts: [pack_factor] -> [1, 1, pack_factor]
# Result: [N//pack_factor, K//G, pack_factor]
unpacked = (packed.unsqueeze(-1) >> shifts) & mask
# Permute to [K//G, N//pack_factor, pack_factor] then reshape to [K//G, N]
unpacked = unpacked.permute(1, 0, 2).reshape(k_groups, n_full)
x.data = unpacked.to(torch.uint8).contiguous()
# Update metadata - zeros are no longer packed
if hasattr(x, "_input_dim"):
x._input_dim = 0
if hasattr(x, "_output_dim"):
x._output_dim = 1
if hasattr(x, "_packed_factor"):
x._packed_factor = 1
return x
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
if self.config.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
from conch.ops.quantization.gemm import mixed_precision_gemm
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
output = mixed_precision_gemm(
x=x,
w_q_packed=w_q.data,
w_s=w_s.data,
w_zp=w_zp.data if w_zp is not None else None,
weight_size_bits=self.config.weight_type.size_bits,
weight_bias=self.config.weight_type.bias,
group_size=self.config.group_size,
)
if bias is not None:
output.add_(bias) # In-place add
return output

View File

@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
unpack_quantized_values_into_int32,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_CPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
class CPUWNA16LinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return -1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "CPUWNA16 only supported on CPU"
if c.weight_type not in _CPUWNA16_SUPPORTED_QUANT_TYPES:
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"CPUWNA16, supported types are: "
f"{_CPUWNA16_SUPPORTED_QUANT_TYPES}",
)
if c.group_size != -1 and c.group_size % 2 != 0:
return (
False,
f"Group size ({c.group_size}) not supported by "
"CPUWNA16, supported group sizes are multiples of 2",
)
if c.partition_weight_shape[0] % 32 != 0:
return (
False,
f"Input size ({c.partition_weight_shape[0]}) not supported by "
"CPUWNA16, supported sizes are multiples of 32",
)
if c.partition_weight_shape[1] % 32 != 0:
return (
False,
f"Output size ({c.partition_weight_shape[1]}) not supported by "
"CPUWNA16, supported sizes are multiples of 32",
)
return True, None
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def _process_gptq_weights(self, layer: torch.nn.Module):
packed_weight = layer.qweight.data
bits = self.config.weight_type.mantissa
pack_factor = 32 // bits
p_w_k, p_w_n = packed_weight.size()
input_size = p_w_k * pack_factor
output_size = p_w_n
isa_hint = _get_isa_hint(layer.scales.dtype)
layer.isa_hint = isa_hint
layer.qzeros = None
if not self.config.has_g_idx:
layer.g_idx = None
# convert input dim packed to output dim packed
weight = unpack_quantized_values_into_int32(
packed_weight, self.config.weight_type, 1
).view(p_w_k, p_w_n, pack_factor)
weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1)
# make 16 output channel as a block and transpose to the make
# the block contigous
weight = (
weight.view(input_size, -1, 16 // pack_factor)
.permute(1, 0, 2)
.reshape(-1, input_size * 16 // pack_factor)
.contiguous()
)
layer.qweight.data = weight
def process_weights_after_loading(self, layer: torch.nn.Module):
if not self.config.zero_points:
# GPTQ
self._process_gptq_weights(layer)
else:
# AWQ
raise NotImplementedError("AWQ is not supported in CPUWNA16LinearKernel")
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
x = ops.cpu_gemm_wna16(
input=x,
q_weight=layer.qweight,
scales=layer.scales,
zeros=layer.qzeros,
g_idx=layer.g_idx,
bias=bias,
pack_factor=8, # 32 // 4
isa_hint=layer.isa_hint,
)
return x
def _get_isa_hint(dtype: torch.dtype) -> str:
supports_amx = torch._C._cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,):
return "amx"
else:
return "vec"

View File

@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
convert_bf16_scales_to_fp8,
convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class CutlassW4A8LinearKernel(MPLinearKernel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# dynamic per-tok fp8 activation quantization
self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "CUTLASS only supported on CUDA"
if not current_platform.is_device_capability(90):
return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"
if c.act_type != torch.float8_e4m3fn:
return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"
if c.has_g_idx:
return False, "Act reordering not supported by CUTLASS W4A8"
if c.zero_points:
return False, "Zero points not supported by CUTLASS W4A8"
if c.weight_type != scalar_types.int4:
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"CUTLASS W4A8, only supported int4",
)
if c.group_size != 128:
return False, "Only group_size 128 is supported"
in_features, out_features = c.partition_weight_shape
if in_features % 128 or out_features % 128:
return (
False,
f"K and N must be divisible by 128, got {c.partition_weight_shape}",
)
if c.out_type != torch.bfloat16:
return (
False,
f"Only bfloat16 output type currently supportedgot {c.out_type=}",
)
return True, None
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
convert_packed_uint4b8_to_signed_int4_inplace(x.data)
torch.cuda.synchronize()
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous().to(torch.float8_e4m3fn)
x.data = ops.cutlass_pack_scale_fp8(x.data)
return x
w_s = getattr(layer, self.w_s_name)
fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
w_s.data = fp8_scales
# register per-channel scales
layer.register_parameter(
"weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
)
# Encode/reorder weights and pack scales
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_ch_s = layer.weight_chan_scale
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
x_2d, act_scales = self.quant_fp8(x_2d)
output = ops.cutlass_w4a8_mm(
a=x_2d,
b_q=w_q,
b_group_scales=w_s,
b_group_size=c.group_size,
a_token_scales=act_scales,
b_channel_scales=w_ch_s,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
# This implementation is for the KleidiAI-accelerated w4a8int quantization
# scheme on Arm CPUs:
# torch.ops.aten._dyn_quant_matmul_4bit performs dynamic quantized matmul
# it takes:
# - int4 weights packed along with bias/scales by
# torch.ops.aten._dyn_quant_pack_4bit_weight
# - float32/bfloat16 activations
# then it leverages KleidiAI ukernels that:
# - dynamically quantize the activations to int8
# - unpack the int4 weights to int8
# - perform int8 x int8 -> int32 matmul
# - dequantize the int32 output to float32/bfloat16 outputs
class Dynamic4bitLinearKernel(MPLinearKernel):
SUPPORTED_QUANT_TYPES = [scalar_types.int4]
@classmethod
def get_min_capability(cls) -> int:
return 1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "Only CPU is supported"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return False, f"Unsupported quant type {c.weight_type}"
if (
current_platform.get_cpu_architecture() == CpuArchEnum.ARM
and c.act_type
not in [
torch.float32,
torch.bfloat16,
torch.float16,
]
):
return (
False,
"Dynamic4bitLinearKernel on Arm requires Float32 or"
" BFloat16 or Float16 activations",
)
if c.full_weight_shape[0] % c.group_size != 0:
return (
False,
f"Group size ({c.group_size}) does not evenly divide"
" the number of input features "
f"({c.full_weight_shape[0]})",
)
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
try:
# Attempt to retrieve the operation
_ = torch.ops.aten._dyn_quant_matmul_4bit
except AttributeError:
return (
False,
f"PyTorch {torch.__version__} does not support"
" _dyn_quant_matmul_4bit. Install a newer version",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
packed_weight = getattr(layer, self.w_q_name)
packed_weight = packed_weight.add(8)
uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
torch.uint8
)
scales = getattr(layer, self.w_s_name)
block_size = c.group_size
# Handle scaling factors for partitioned weights
if block_size == c.partition_weight_shape[0]:
scales = scales.to(
torch.float32
) # Float32 & Bfloat16 variants requires float32 scales
scales = scales.view(-1, 1) # Channel-wise scales
if layer.bias is not None:
# Float32 & Bfloat16 variants requires float32 bias
replace_parameter(
layer,
"bias",
torch.nn.Parameter(
layer.bias.to(torch.float32), requires_grad=False
),
)
else:
# KleidiAI kernel requires bfloat16 scales with groupwise scheme
scales = scales.to(torch.bfloat16)
# Repack weights as per kernel requirement
w = torch.ops.aten._dyn_quant_pack_4bit_weight(
uint8_packed,
scales,
layer.bias,
block_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
)
setattr(layer, self.w_s_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# PyTorch / KleidiAI kernels natively support the following configs:
# - channelwise with bfloat16 / float32 activations
# - groupwise with float32 activations
# To support:
# - groupwise with bfloat16/float16 activations: we need to upcast
# activations to float32 before matmul and downcast back to bfloat16/float16
# - channelwise with float16 activations, we need to upcast activations to
# float32 before matmul and downcast back to float16
# Note: these activations will be dynamically quantized to int8 by the kernel.
c = self.config
is_groupwise = c.group_size != c.partition_weight_shape[0]
# dtype of activations before they get dynamically quantized to int8
original_pre_quant_act_dtype = x.dtype
pre_quant_act_dtype = original_pre_quant_act_dtype
if (
is_groupwise and pre_quant_act_dtype == torch.bfloat16
) or pre_quant_act_dtype == torch.float16:
pre_quant_act_dtype = torch.float32
x_2d = x.reshape(-1, x.shape[-1])
if pre_quant_act_dtype != original_pre_quant_act_dtype:
x_2d = x_2d.to(pre_quant_act_dtype)
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q = getattr(layer, self.w_q_name)
output = torch.ops.aten._dyn_quant_matmul_4bit(
x_2d,
w_q,
c.group_size,
c.partition_weight_shape[0],
c.partition_weight_shape[1],
).reshape(out_shape)
if pre_quant_act_dtype != original_pre_quant_act_dtype:
output = output.to(original_pre_quant_act_dtype)
return output

View File

@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class ExllamaLinearKernel(MPLinearKernel):
SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
# In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
# currently untested so not added to the list
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_cuda_alike():
return (
False,
"Exllama is only supported on CUDA and ROCm",
)
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
"Act reordering currently not supported by Exllama, "
"when the input features are partitioned across "
"devices",
)
if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
return (
False,
"Output features must be a multiple of the pack "
"factor (32 / num_bits) so that we can correctly "
"pack the zero points",
)
if c.act_type != torch.float16:
return False, "Exllama only supports float16 activations"
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"Exllama, supported types are: "
f"{cls.SUPPORTED_QUANT_TYPES}",
)
if c.full_weight_shape[0] % c.group_size != 0:
return (
False,
f"Group size ({c.group_size}) does not evenly divide"
" the number of input features "
f"({c.full_weight_shape[0]})",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
# For Exllama, we need to set a zero-point tensor if there is not one
if not c.zero_points:
self.w_zp_name = "qzeros"
device = getattr(layer, self.w_q_name).device
groups = c.partition_weight_shape[0] // c.group_size
out_features = c.partition_weight_shape[1]
if c.weight_type.has_bias():
# if the type has a bias we have to create a zeros tensor that
# contains the bias values repeated for each group (-1 due to
# a bug in the original GPTQ checkpoint format leading to
# exllama kernel adding 1 to the zero points during inference)
# Documentation of the bug can be found here:
# https://garden.danieldk.eu/GPTQ-Checkpoint-Format
zeros = torch.full(
(groups, out_features),
c.weight_type.bias - 1,
dtype=torch.int32,
device=device,
)
else:
raise NotImplementedError(
"A 0 zero-point is not supported by Exllama due to "
"a bug in the original GPTQ checkpoint format leading to "
"exllama kernel adding 1 to the zero points during "
"inference"
)
zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
setattr(
layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
)
if c.has_g_idx:
def transform_w_g_idx(x):
# Exllama wants the permutation array instead of the group
# indices
return torch.argsort(x).to(torch.int)
self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) # type: ignore
else:
self.w_gidx_name = "g_idx"
empty_g_idx = torch.nn.Parameter(
torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
)
setattr(layer, self.w_gidx_name, empty_g_idx)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
assert self.w_gidx_name is not None
g_idx = getattr(layer, self.w_gidx_name)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x_cont = x.data.contiguous()
ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
return x_cont
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x.to(dtype=c.act_type)
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
# gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
# However, the MPLinearLayerConfig doesn't contain format info.
# So hardcode GPTQv1 format here, to keep its behavior unchanged.
use_v2_format = False
assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)

View File

@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from functools import partial
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
check_machete_supports_shape,
query_machete_supported_group_sizes,
query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_quantized_values_into_int32,
unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
# Machete uses CUTLASS, so it can only be compatible with Nvidia
if not current_platform.is_cuda():
return False, "Machete only supported on CUDA"
if not current_platform.is_device_capability(90):
return False, "Machete requires compute capability of 90 (Hopper)"
if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
return (
False,
"Act reordering currently not supported by Machete, "
"when the input features are partitioned across "
"devices",
)
if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"Machete, supported types are: "
f"{query_machete_supported_quant_types(c.zero_points)}",
)
if c.group_size not in query_machete_supported_group_sizes(c.act_type):
return (
False,
f"Group size ({c.group_size}) not supported by "
"Machete, supported group sizes are: "
f"{query_machete_supported_group_sizes(c.act_type)}",
)
return check_machete_supports_shape(
c.partition_weight_shape[0], c.partition_weight_shape[1]
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)
self.act_perm = lambda x: x[:, perm]
# use `ops.permute_cols` if possible
if (
c.act_type in [torch.float16, torch.bfloat16]
and c.partition_weight_shape[0] % 8 == 0
):
self.act_perm = partial(ops.permute_cols, perm=perm)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=0
)
x_perm = x_unpacked[perm, :]
x.data = pack_quantized_values_into_int32(
x_perm, c.weight_type, packed_dim=0
)
x.data = ops.machete_prepack_B(
x.data.t().contiguous().t(),
a_type=c.act_type,
b_type=c.weight_type,
group_scales_type=c.act_type,
)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(
x.data, c.weight_type, packed_dim=1
)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
return x
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
if c.zero_points:
assert w_zp is not None
else:
w_zp = None
output = ops.machete_mm(
a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES,
apply_gptq_marlin_linear,
check_marlin_supports_shape,
marlin_act_int8_process_scales,
marlin_is_k_full,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_permute_bias,
marlin_permute_scales,
marlin_sort_g_idx,
marlin_zero_points,
query_marlin_supported_quant_types,
unpack_cols,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 75
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
# Marlin uses inline PTX, so it can only be compatible with Nvidia
if not current_platform.is_cuda():
return False, "Marlin only supported on CUDA"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return (
False,
f"Quant type ({c.weight_type}) not supported by"
f" Marlin, supported types are: {quant_types}",
)
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return (
False,
f"Group size ({c.group_size}) not supported by "
"Marlin, supported group sizes are: "
f"{MARLIN_SUPPORTED_GROUP_SIZES}",
)
return check_marlin_supports_shape(
c.partition_weight_shape[1], # out_features
c.partition_weight_shape[0], # in_features
c.full_weight_shape[0], # in_features
c.group_size,
)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
if is_a_8bit:
assert c.weight_type == scalar_types.uint4b8, (
"W8A8 is not supported by marlin kernel."
)
if c.act_type == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
getattr(layer, self.w_s_name).data = (
getattr(layer, self.w_s_name).data * 512
)
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace_new(device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "w_zp"
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(
x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(
x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size,
is_a_8bit=is_a_8bit,
)
if c.group_size == -1:
num_groups = 1
else:
num_groups = c.partition_weight_shape[0] // c.group_size
if c.act_type == torch.int8 and num_groups > 1:
x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
layer.register_parameter(
"input_global_scale",
torch.nn.Parameter(input_global_scale, requires_grad=False),
)
else:
layer.input_global_scale = None
return x
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name)
)
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
grouped_k = (
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
)
self._transform_param(
layer,
self.w_zp_name,
lambda x: marlin_zero_points(
unpack_cols(
x.t(),
c.weight_type.size_bits,
grouped_k,
c.partition_weight_shape[1],
),
size_k=grouped_k,
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits,
is_a_8bit=is_a_8bit,
),
)
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
weight_scale=w_s,
weight_zp=w_zp, # type: ignore
g_idx=w_gidx, # type: ignore
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=self.workspace,
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=c.act_type,
)

View File

@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.nn.parameter import Parameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)
class XPUwNa16LinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return -1
@classmethod
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_xpu():
return False, "XPUwNa16 only supported on XPU"
if c.act_type != torch.bfloat16 and c.act_type != torch.float16:
return False, "XPUwNa16 only supports BF16/FP16 activations"
if c.weight_type not in _XPUWNA16_SUPPORTED_QUANT_TYPES:
return (
False,
f"Quant type ({c.weight_type}) not supported by "
"XPUwNa16, supported types are: "
f"{_XPUWNA16_SUPPORTED_QUANT_TYPES}",
)
if c.group_size != -1 and c.group_size % 32 != 0:
return (
False,
f"Group size ({c.group_size}) not supported by "
"XPUwNa16, supported group sizes are multiples of 32",
)
if c.partition_weight_shape[0] % 32 != 0:
return (
False,
f"Input size ({c.partition_weight_shape[0]}) not supported by "
"XPUwNa16, supported sizes are multiples of 32",
)
if c.partition_weight_shape[1] % 32 != 0:
return (
False,
f"Output size ({c.partition_weight_shape[1]}) not supported by "
"XPUWNA16, supported sizes are multiples of 32",
)
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module):
layer.weight_scale.data = layer.weight_scale.t().contiguous()
if self.config.zero_points:
layer.weight_zero_point.data = layer.weight_zero_point.t().contiguous()
else:
weight_zero_point = torch.Tensor([8]).to(torch.int8).to("xpu")
layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
if self.config.has_g_idx:
layer.g_idx.data = layer.g_idx.t().contiguous()
else:
layer.g_idx = None
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
out = torch.ops._xpu_C.int4_gemm_w4a16(
reshaped_x,
layer.weight_packed.t(),
bias,
layer.weight_scale,
layer.weight_zero_point,
self.config.group_size,
layer.g_idx,
)
return out

View File

@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar
import torch
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
)
from vllm.platforms import current_platform
@dataclass
class ScaledMMLinearLayerConfig:
pass
@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
# TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
is_static_input_scheme: bool
is_channelwise: bool
input_symmetric: bool
@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
weight_quant_key: QuantKey
activation_quant_key: QuantKey
out_dtype: torch.dtype | None
_FP8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_scale_ub,
]
_Int8ParamsT = tuple[
torch.Tensor, # weight
torch.Tensor, # weight_scale
torch.Tensor | None, # input_scale,
torch.Tensor | None, # input_zp
torch.Tensor | None, # azp_adj
]
_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)
class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
@classmethod
@abstractmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
raise NotImplementedError
def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
assert self.can_implement(c)[0]
assert self.is_supported()[0]
self.config = c
self.layer_param_names = layer_param_names
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
# return a covariant type in the subclass
@abstractmethod
def _get_layer_params(self, layer) -> _ParamsT:
raise NotImplementedError
class FP8ScaledMMLinearKernel(
ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
):
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
act_scale_descriptor = c.activation_quant_key.scale
self.quant_fp8 = QuantFP8(
static=act_scale_descriptor.static,
group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_output_padding(),
)
self.fp8_dtype = current_platform.fp8_dtype()
super().__init__(c, layer_param_names)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def _get_layer_params(self, layer) -> _FP8ParamsT:
w, w_s, x_s, x_s_ub = self.layer_param_names
return (
getattr(layer, w),
getattr(layer, w_s),
getattr(layer, x_s, None),
getattr(layer, x_s_ub, None),
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
fp8_dtype = self.fp8_dtype
maybe_out_dtype = self.config.out_dtype
w, w_s, x_s, x_s_ub = self._get_layer_params(layer)
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_s computed from x.
# If static, layer.input_scale is scalar and x_s is input_scale.
# View input as 2D matrix for fp8 methods
x_2d = x.view(-1, x.shape[-1])
output_shape = [*x.shape[:-1], w.shape[1]]
out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
x_2d_q = x_2d
if x.dtype != fp8_dtype:
x_2d_q, x_s = self.quant_fp8(
x_2d,
x_s,
x_s_ub,
)
return self.apply_scaled_mm(
A=x_2d_q,
B=w,
out_dtype=out_dtype,
As=x_s,
Bs=w_s,
bias=bias,
output_shape=output_shape,
)
@abstractmethod
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
raise NotImplementedError
def get_output_padding(self) -> int | None:
return None
class Int8ScaledMMLinearKernel(
ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
):
def _get_layer_params(self, layer) -> _Int8ParamsT:
w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
return (
getattr(layer, w_q),
getattr(layer, w_s),
getattr(layer, i_s, None),
getattr(layer, i_zp, None),
getattr(layer, azp_adj, None),
)

View File

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
CutlassFP8ScaledMMLinearKernel,
CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
ChannelWiseTorchFP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
ScaledMMLinearKernel,
ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
TritonInt8ScaledMMLinearKernel,
)
__all__ = [
"FP8ScaledMMLinearKernel",
"FP8ScaledMMLinearLayerConfig",
"Int8ScaledMMLinearKernel",
"Int8ScaledMMLinearLayerConfig",
"ScaledMMLinearKernel",
"ScaledMMLinearLayerConfig",
"AiterInt8ScaledMMLinearKernel",
"CPUInt8ScaledMMLinearKernel",
"CutlassFP8ScaledMMLinearKernel",
"CutlassInt8ScaledMMLinearKernel",
"FlashInferFP8ScaledMMLinearKernel",
"ChannelWiseTorchFP8ScaledMMLinearKernel",
"PerTensorTorchFP8ScaledMMLinearKernel",
"RowWiseTorchFP8ScaledMMLinearKernel",
"ROCmFP8ScaledMMLinearKernel",
"TritonInt8ScaledMMLinearKernel",
]

View File

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig
class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return False, "Requires ROCm."
if compute_capability is not None and compute_capability < 90:
return False, "requires compute capability 90 and above."
try:
import aiter # noqa: F401 # deliberately attempt to import aiter
except Exception:
return False, "requires `aiter` to be installed."
if not rocm_aiter_ops.is_linear_enabled():
return (
False,
"requires setting `VLLM_ROCM_USE_AITER=1` "
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
"`VLLM_ROCM_USE_AITER_LINEAR` default is True.",
)
return True, None
@classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return False, "supports symmetric quantization only."
return True, None
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
`AiterInt8ScaledMMLinearKernel` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
Currently only support per-tensor-per-tensor GEMM
and per-token-per-channel GEMM through AITER
w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` also does not support
ATIER block scaled GEMM and mix-precision GEMM.
"""
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
assert symmetric, (
"AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
)
x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)
assert x_zp is None, (
"AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
)
out_dtype = x.dtype
assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
assert bias is None or bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
m = x_q.shape[0] # a
n = w_q.shape[1] # b
per_tensor_scale_a = x_s.numel() == 1
per_tensor_scale_b = w_s.numel() == 1
per_token_scale_a = x_s.numel() == m
per_channel_scale_b = w_s.numel() == n
# @TODO:
# Maybe broadcast the per-tensor-scale into per-channel-scale
# if one of the scale is a per-channel-scale.
# For now, it only supports:
# - per-tensor-per-tensor a8w8 scaled GEMM, and
# - per-token-per-channel a8w8 scaled GEMM
assert (per_tensor_scale_a and per_tensor_scale_b) or (
per_token_scale_a and per_channel_scale_b
), (
"Currently only support per-tensor-per-tensor GEMM "
" and per-token-per-channel GEMM through AITER"
" w8a8 scaled gemm. `AiterInt8ScaledMMLinearKernel` "
"does not support AITER block scaled GEMM."
)
# gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
# a to be [M, K]
# b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)

View File

@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cpu():
return False, "requires CPU."
return True, None
@classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q_name, _, _, _, _ = self.layer_param_names
weight = getattr(layer, w_q_name)
dtype = weight.dtype
N, K = weight.size()
if (
current_platform.get_cpu_architecture() == CpuArchEnum.X86
and envs.VLLM_CPU_SGL_KERNEL
and self.config.input_symmetric
and check_cpu_sgl_kernel(N, K, dtype)
):
self.linear_method = self._apply_weights_sgl
self.process_weights_for_sgl(layer)
else:
self.linear_method = self._apply_weights_onednn
self.process_weights_for_onednn(layer)
def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Transpose to [K, N] for convenience
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# oneDNN kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
input_scale = getattr(layer, i_s_name)
if self.config.input_symmetric:
replace_parameter(
layer,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
else:
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
azp = (
(int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
)
replace_parameter(
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
# Different from cutlass, oneDNN kernels only need the AZP adjustment
# term for dynamic quantization. And s_b should be folded into the
# term. Such as:
# s_a * s_b * [(A - zp_a)B] + bias =
# s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
# s_a * GEMM_output - s_a * zp_a * adj + bias
if not (self.config.input_symmetric and self.config.is_static_input_scheme):
weight = getattr(layer, w_q_name)
weight_scale = getattr(layer, w_s_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
azp_adj = azp_adj * weight_scale.squeeze()
setattr(
layer,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
weight = getattr(layer, w_q_name)
self.dnnl_handler = ops.create_onednn_scaled_mm(
weight,
getattr(layer, w_s_name),
torch.get_default_dtype(),
getattr(layer, i_s_name) is None,
not self.config.input_symmetric,
32,
)
# weight is prepacked and maintained by the dnnl_handler,
# release the original weight
setattr(layer, w_q_name, None)
del weight
def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, _, _, _ = self.layer_param_names
# WEIGHT
weight = getattr(layer, w_q_name)
packed_weight = torch.ops._C.convert_weight_packed(weight)
replace_parameter(
layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
)
if layer.bias is not None:
bias = layer.bias
layer.register_parameter(
"bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
)
# WEIGHT SCALE
# CPU SGL kernels only support per-channel.
# For per-tensor quant, convert to the per-channel case.
weight_scale = getattr(layer, w_s_name)
if not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return self.linear_method(
layer,
x,
bias,
)
def _apply_weights_onednn(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
x_shape = x.shape
x = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
x, i_s, i_zp, self.config.input_symmetric
)
m = x.size(0)
n = self.dnnl_handler.n
out = torch.empty((m, n), dtype=x.dtype)
ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
out = out.reshape(x_shape[:-1] + (n,)) if len(x_shape) > 2 else out
return out
def _apply_weights_sgl(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_layer_params(layer)
return torch.ops._C.int8_scaled_mm_with_quant(
x,
w_q,
w_s,
layer.bias_fp32 if bias is not None else None,
x.dtype,
True,
)

View File

@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
Int8ScaledMMLinearKernel,
Int8ScaledMMLinearLayerConfig,
)
import vllm.envs as envs
class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "requires CUDA."
return True, None
@classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
config = self.config
# WEIGHT
# Cutlass kernels need transposed weight.
weight = getattr(layer, w_q_name)
replace_parameter(
layer,
w_q_name,
# torch.nn.Parameter(weight.t().data, requires_grad=False),
torch.nn.Parameter(weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data, requires_grad=False),
)
# WEIGHT SCALE
# Cutlass kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if config.is_static_input_scheme:
input_scale = getattr(layer, i_s_name)
if config.input_symmetric:
replace_parameter(
layer,
i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False),
)
setattr(layer, i_zp_name, None)
else:
input_zero_point = getattr(layer, i_zp_name)
# reconstruct the ranges
int8_traits = torch.iinfo(torch.int8)
azps = input_zero_point.to(dtype=torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()
range_min = (input_scale * (int8_traits.min - azps)).min()
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
replace_parameter(
layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
)
# AZP loaded as int8 but used as int32
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
replace_parameter(
layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
)
# azp_adj is the AZP adjustment term, used to account for weights.
# It does not depend on scales or azp, so it is the same for
# static and dynamic quantization.
# For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
# https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
if not config.input_symmetric:
weight = getattr(layer, w_q_name)
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
if config.is_static_input_scheme:
# cutlass_w8a8 requires azp to be folded into azp_adj
# in the per-tensor case
azp_adj = getattr(layer, i_zp_name) * azp_adj
setattr(
layer,
azp_adj_name,
torch.nn.Parameter(azp_adj, requires_grad=False),
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)
# ops.scaled_int8_quant supports both dynamic and static quant:
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=symmetric
)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
static = i_zp is not None
azp = None if static else x_zp
return ops.cutlass_scaled_mm_azp(
x_q,
w_q,
scale_a=x_s,
scale_b=w_s,
out_dtype=x.dtype,
azp_adj=azp_adj,
azp=azp,
bias=bias,
)
return ops.cutlass_scaled_mm(
# x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias, format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN"
)
class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "requires CUDA."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
return output.view(*output_shape)

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "requires CUDA."
if not has_flashinfer():
return False, "requires FlashInfer to be installed."
if compute_capability is not None and compute_capability < 100:
return False, "requires compute capability 100 and above."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return False, "requires per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
return flashinfer_scaled_fp8_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)

View File

@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
"""
Base class for FP8 linear kernels using Torch.
Each subclass represents a kernel variant for
specific device capabilities and torch versions.
"""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
return False, "requires ROCm, CUDA or CPU."
if compute_capability is not None and compute_capability < 89:
return False, "requires compute capability 89 and above."
return True, None
def get_output_padding(self) -> int | None:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
#
# The perf gain is still relevant as of 16/1/2026
# torch version == 2.9.0. More details in the link below:
# https://github.com/vllm-project/vllm/issues/32269
vllm_config = get_current_vllm_config().compilation_config
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
return 17 if pad_output else None
class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return False, "requires per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
output = torch._scaled_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return False, "requires ROCm."
from vllm.platforms.rocm import on_mi3xx
if not on_mi3xx():
return False, "requires MI3xx."
if compute_capability is not None and compute_capability < 94:
return False, "requires compute capability 94 and above."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if c.out_dtype == torch.float16:
# hipblaslt rowwise _scaled_mm only supports BFloat16
return False, "supports BFloat16 output data type only."
if per_tensor_activation_scales or per_tensor_weight_scales:
return False, "cannot be used with per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
# Note:
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
A,
B,
out_dtype=out_dtype,
scale_a=As,
scale_b=Bs.t(),
bias=bias,
)
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if per_tensor_activation_scales and per_tensor_weight_scales:
return False, "cannot be used with per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(
A,
B,
scale_a=dummy_tensor,
scale_b=dummy_tensor,
out_dtype=torch.float32,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, output_shape[0])
x_scale = torch.narrow(As, 0, 0, output_shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * Bs.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
if (
A.shape[0] <= 4
and B.shape[0] % 16 == 0 # M TODO: needed?
and B.shape[1] % 16 == 0 # K
and ((bias is None) or (bias.dtype == out_dtype))
):
output = ops.wvSplitKQ(
B.t(),
A,
out_dtype,
As,
Bs,
num_compute_units(),
bias,
)
# Fallback
else:
output = torch._scaled_mm(
A,
B,
out_dtype=out_dtype,
scale_a=As,
scale_b=Bs,
bias=bias,
)
return output
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
)
class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return False, "requires ROCm."
from vllm.platforms.rocm import on_mi3xx
if not on_mi3xx():
return False, "requires MI3xx."
if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return False, "requires per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
A, B, out_dtype, As, Bs, bias
)
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)

View File

@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa: E501
triton_scaled_mm,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig,
)
class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if current_platform.is_cuda_alike():
return True, None
return False, "requires ROCm or CUDA."
@classmethod
def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric:
return False, "supports symmetric input only."
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w_q, _, i_s, _, _ = self._get_layer_params(layer)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
replace_parameter(
layer,
w_q_name,
torch.nn.Parameter(w_q.t().data, requires_grad=False),
)
# WEIGHT SCALE
# Triton kernel supports only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter(
layer,
w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# INPUT SCALE
if self.config.is_static_input_scheme:
assert i_s is not None
replace_parameter(
layer,
i_s_name,
torch.nn.Parameter(i_s.max(), requires_grad=False),
)
setattr(layer, i_zp_name, None)
else:
setattr(layer, i_s_name, None)
setattr(layer, i_zp_name, None)
setattr(layer, azp_adj_name, None)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)
x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=True
)
assert x_zp is None, "Triton kernel only supports symmetric quantization"
return triton_scaled_mm(
x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
)

View File

@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
import torch
from vllm.model_executor.kernels.linear import ( # noqa: E501
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from vllm.platforms import current_platform
class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_xpu():
return False, "XPUFP8ScaledMM only support on XPU"
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
return False, "XPUFP8ScaledMM only support FP8 weight dtype"
return True, None
def __init__(
self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
) -> None:
assert self.can_implement(c)[0]
assert self.is_supported()[0]
self.config = c
self.layer_param_names = layer_param_names
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
weight = layer.weight
weight_scale = layer.weight_scale
return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
pass

View File

View File

@@ -0,0 +1,708 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom activation functions."""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict
logger = init_logger(__name__)
@triton.jit
def _swiglustep_and_mul_kernel(
o_ptr,
o_stride,
x_ptr,
x_stride,
limit: tl.constexpr,
d: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
) -> None:
i = tl.program_id(axis=0).to(tl.int64)
j = tl.program_id(axis=1)
o_row_ptr = o_ptr + o_stride * i
x_row_ptr = x_ptr + x_stride * i
offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offsets < d
gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)
gate_silu = tl.sigmoid(gate) * gate
gate_clamped = tl.minimum(gate_silu, limit)
up_clamped = tl.minimum(tl.maximum(up, -limit), limit)
result = gate_clamped * up_clamped
result = result.to(x_ptr.dtype.element_ty)
tl.store(o_row_ptr + offsets, result, mask=mask)
def swiglustep_and_mul_triton(
output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
b, n = input.shape
assert input.ndim == 2
assert n % 2 == 0
d = n // 2
def grid(meta):
return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
_swiglustep_and_mul_kernel[grid](
output,
output.stride(0),
input,
input.stride(0),
limit=limit,
d=d,
BLOCK_SIZE=1024,
)
# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:fatrelu_and_mul]
def __init__(self, threshold: float = 0.0):
super().__init__()
self.threshold = threshold
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
self.op = ops.fatrelu_and_mul
elif current_platform.is_cpu():
self._forward_method = self.forward_native
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
x1 = x[..., :d]
x2 = x[..., d:]
x1 = F.threshold(x1, self.threshold, 0.0)
return x1 * x2
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x, self.threshold)
return out
# --8<-- [start:silu_and_mul]
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:silu_and_mul]
def __init__(self, *, compile_native: bool = True):
super().__init__(compile_native=compile_native)
if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
self.op = ops.silu_and_mul
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops
self.op = ipex_ops.silu_and_mul
elif current_platform.is_cpu():
self._forward_method = self.forward_native
@staticmethod
def forward_native(x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
"""An activation function for SwiGLU.
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:mul_and_silu]
def __init__(self):
super().__init__()
if current_platform.is_cuda_alike() or current_platform.is_xpu():
# self.op = torch.ops._C.mul_and_silu
from vllm import _custom_ops as ops
self.op = ops.mul_and_silu
elif current_platform.is_cpu():
self._forward_method = self.forward_native
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return x[..., :d] * F.silu(x[..., d:])
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
# --8<-- [start:gelu_and_mul_sparse]
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
"""An activation function for GeluAndMulSparse.
This activation function is used in Gemma3n. It computes:
up_proj = self.up_proj(x)
gate_proj = self.gate_proj(x)
gate_proj = self._gaussian_topk(gate_proj) # sparsity
activations = self.act_fn(gate_proj) # gelu
down_proj = self.down_proj(activations * up_proj)
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
# --8<-- [end:gelu_and_mul_sparse]
def __init__(self, activation_sparsity: float, approximate: str = "none"):
super().__init__()
# Gelu.
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
if current_platform.is_rocm() and approximate == "tanh":
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] Pytorch's native GELU with tanh approximation is currently "
"unstable and produces garbage. Fallback to 'none' approximation."
)
self.approximate = "none"
# Sparsity.
if activation_sparsity == 0.0:
raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
normal_dist = torch.distributions.normal.Normal(0, 1)
self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)
def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
"""Get % sparse percentile of the Gaussian distribution."""
# NOTE(rob): for TP>1, we could all-gather to get the means/std.
# But we do not do this because in expectation they are the same
# and in practice the eval scores are good without gathering.
mean = torch.mean(x, dim=-1, keepdim=True)
std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
cutoff_x = mean + std * self.std_multiplier
return nn.functional.relu(x - cutoff_x)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
out = self._gaussian_topk(x[..., :d])
out = F.gelu(out, approximate=self.approximate)
return out * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_native(x)
# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
# --8<-- [end:gelu_and_mul]
def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
if approximate == "none":
from vllm import _custom_ops as ops
self.op = ops.gelu_and_mul
elif approximate == "tanh":
from vllm import _custom_ops as ops
self.op = ops.gelu_tanh_and_mul
if current_platform.is_rocm() and approximate == "tanh":
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable "
"with torch.compile. For native implementation, fallback to 'none' "
"approximation. The custom kernel implementation is unaffected."
)
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
# TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
approximate = self.approximate
if current_platform.is_rocm() and approximate == "tanh":
approximate = "none"
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
def extra_repr(self) -> str:
return f"approximate={repr(self.approximate)}"
# --8<-- [start:swigluoai_and_mul]
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
# https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
# --8<-- [end:swigluoai_and_mul]
def __init__(self, alpha: float = 1.702, limit: float = 7.0):
super().__init__()
self.alpha = alpha
self.limit = limit
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x[..., ::2], x[..., 1::2]
gate = gate.clamp(min=None, max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
gated_output = (up + 1) * glu
return gated_output
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
return out
def extra_repr(self) -> str:
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
"""An activation function for SwiGLU with clamping.
Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def __init__(self, limit: float = 7.0):
super().__init__()
if limit is None:
raise ValueError("SwigluStepAndMul requires limit to be set.")
self.limit = limit
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
gate, up = x.chunk(2, dim=-1)
gate = F.silu(gate)
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
return gate * up
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
swiglustep_and_mul_triton(out, x, self.limit)
return out
def extra_repr(self) -> str:
return f"limit={repr(self.limit)}"
# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
# --8<-- [end:gelu_new]
def __init__(self):
super().__init__()
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
# self.op = torch.ops._C.gelu_new
from vllm import _custom_ops as ops
self.op = ops.gelu_new
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
# --8<-- [start:gelu_fast]
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
# --8<-- [end:gelu_fast]
def __init__(self):
super().__init__()
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
self.op = torch.ops._C.gelu_fast
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
# --8<-- [start:quick_gelu]
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
# --8<-- [end:quick_gelu]
def __init__(self):
super().__init__()
if (
current_platform.is_cuda_alike()
or current_platform.is_cpu()
or current_platform.is_xpu()
):
# self.op = torch.ops._C.gelu_quick
from vllm import _custom_ops as ops
self.op = ops.gelu_quick
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return x * torch.sigmoid(1.702 * x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
self.op(out, x)
return out
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
return self.forward_cuda(x)
# --8<-- [start:relu2]
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
"""
# --8<-- [end:relu2]
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return torch.square(F.relu(x))
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# TODO : implement cuda kernels
return self.forward_native(x)
# --8<-- [start:xielu]
@CustomOp.register("xielu")
class XIELU(CustomOp):
"""
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
Otherwise, we emit a single warning and use xIELU Python
"""
# --8<-- [end:xielu]
def __init__(
self,
alpha_p_init: float = 0.8,
alpha_n_init: float = 0.8,
beta: float = 0.5,
eps: float = -1e-6,
dtype: torch.dtype = torch.bfloat16,
with_vector_loads: bool = False,
):
super().__init__()
self.alpha_p = nn.Parameter(
torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
0
)
)
self.alpha_n = nn.Parameter(
torch.log(
torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
).unsqueeze(0)
)
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
self.with_vector_loads = with_vector_loads
# Temporary until xIELU CUDA fully implemented
self._beta_scalar = float(self.beta.detach().cpu().float().item())
self._eps_scalar = float(self.eps.detach().cpu().float().item())
self._xielu_cuda_obj = None
try:
import xielu.ops # noqa: F401
self._xielu_cuda_obj = torch.classes.xielu.XIELU()
msg = "Using experimental xIELU CUDA."
try:
from torch._dynamo import allow_in_graph
self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
msg += " Enabled torch._dynamo for xIELU CUDA."
except Exception as err:
msg += (
f" Could not enable torch._dynamo for xIELU ({err}) - "
"this may result in slower performance."
)
self._xielu_cuda_fn = self._xielu_cuda
logger.warning_once(msg)
except Exception as err:
logger.warning_once(
"CUDA-fused xIELU not available (%s) "
" falling back to a Python version.\n"
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
str(err),
)
def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
alpha_p = nn.functional.softplus(self.alpha_p)
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
return torch.where(
x > 0,
alpha_p * x * x + self.beta * x,
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
)
def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
"""Firewall function to prevent torch.compile from seeing .item()"""
assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
original_shape = x.shape
# CUDA kernel expects 3D tensors, reshape if needed
while x.dim() < 3:
x = x.unsqueeze(0)
if x.dim() > 3:
x = x.view(-1, 1, x.size(-1))
if original_shape != x.shape:
logger.warning_once(
"Warning: xIELU input tensor expects 3 dimensions"
" but got (shape: %s). Reshaping to (shape: %s).",
original_shape,
x.shape,
)
result = self._xielu_cuda_obj.forward(
x,
self.alpha_p,
self.alpha_n,
# Temporary until xIELU CUDA fully implemented ->
# self.{beta,eps}.item()
self._beta_scalar,
self._eps_scalar,
self.with_vector_loads,
)
return result.view(original_shape)
def forward_native(self, input: torch.Tensor) -> torch.Tensor:
if self._xielu_cuda_obj is not None and input.is_cuda:
if not torch._dynamo.is_compiling():
return self._xielu_cuda_fn(input)
else:
logger.warning_once(
"torch._dynamo is compiling, using Python version of xIELU."
)
return self._xielu_python(input)
def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
return self.forward_native(input)
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
This is used for some quantization methods like AWQ.
"""
def __init__(
self,
act_module: nn.Module,
intermediate_size: int,
input_is_parallel: bool = True,
params_dtype: torch.dtype | None = None,
):
super().__init__()
self.act = act_module
self.input_is_parallel = input_is_parallel
if input_is_parallel:
tp_size = get_tensor_model_parallel_world_size()
intermediate_size_per_partition = divide(intermediate_size, tp_size)
else:
intermediate_size_per_partition = intermediate_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.scales = nn.Parameter(
torch.empty(intermediate_size_per_partition, dtype=params_dtype)
)
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(x) / self.scales
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
if self.input_is_parallel:
tp_rank = get_tensor_model_parallel_rank()
shard_size = param_data.shape[0]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = LazyDict(
{
"gelu": lambda: nn.GELU(),
"gelu_fast": lambda: FastGELU(),
"gelu_new": lambda: NewGELU(),
"gelu_pytorch_tanh": lambda: (
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
),
nn.GELU(approximate="none"),
)[1]
if current_platform.is_rocm()
else nn.GELU(approximate="tanh"),
"relu": lambda: nn.ReLU(),
"relu2": lambda: ReLUSquaredActivation(),
"silu": lambda: nn.SiLU(),
"quick_gelu": lambda: QuickGELU(),
"tanh": lambda: nn.Tanh(),
"sigmoid": lambda: nn.Sigmoid(),
"xielu": lambda: XIELU(),
}
)
def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name.startswith("torch.nn.modules."):
activation_name = act_fn_name.split(".")[-1]
if activation_name == "identity":
return nn.Identity()
act_fn_name = activation_name
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
return _ACTIVATION_REGISTRY[act_fn_name]
_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
{
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
"geglu": lambda: GeluAndMul(),
"swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
}
)
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]

View File

@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.layers.attention.attention import Attention
from vllm.model_executor.layers.attention.chunked_local_attention import (
ChunkedLocalAttention,
)
from vllm.model_executor.layers.attention.cross_attention import CrossAttention
from vllm.model_executor.layers.attention.encoder_only_attention import (
EncoderOnlyAttention,
)
from vllm.model_executor.layers.attention.mla_attention import MLAAttention
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.attention.static_sink_attention import (
StaticSinkAttention,
)
__all__ = [
"Attention",
"ChunkedLocalAttention",
"CrossAttention",
"EncoderOnlyAttention",
"MLAAttention",
"MMEncoderAttention",
"StaticSinkAttention",
]

View File

@@ -0,0 +1,733 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, Any
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.kv_transfer_utils import (
maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionType,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
SlidingWindowSpec,
)
if TYPE_CHECKING:
from vllm.model_executor.layers.attention import MLAAttention
logger = init_logger(__name__)
def validate_kv_sharing_target(
current_layer_name, target_layer_name, static_forward_context
):
error_msg = (
f"Specified KV sharing target layer for {current_layer_name} "
f"is not valid: target layer {target_layer_name} "
)
if current_layer_name == target_layer_name:
raise ValueError(error_msg + "cannot be the same as the current layer.")
if target_layer_name not in static_forward_context:
from vllm.model_executor.models.utils import extract_layer_index
# If target layer name is not in the static fwd context, it means either
# a) the target layer does not come BEFORE the current layer, or
# b) the target layer is not an Attention layer that exists in the model
current_layer_idx = extract_layer_index(current_layer_name)
target_layer_idx = extract_layer_index(target_layer_name)
if current_layer_idx <= target_layer_idx:
raise ValueError(error_msg + "must come before the current layer.")
else:
raise ValueError(error_msg + "is not a valid Attention layer in the model.")
# Currently KV sharing is only supported between layers of the same type
target_layer_attn_type = static_forward_context[target_layer_name].attn_type
expected = static_forward_context[current_layer_name].attn_type
if target_layer_attn_type != expected:
raise ValueError(
error_msg + f"must be the same type as the current layer ({expected})."
)
def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
"""Returns whether the quantization method should load quantized weights."""
return quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod
)
def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
"""Sets default quantization scales for the layer."""
if register_buffer:
layer.register_buffer("_k_scale", torch.tensor(1.0, dtype=torch.float32))
layer.register_buffer("_v_scale", torch.tensor(1.0, dtype=torch.float32))
layer.register_buffer("_q_scale", torch.tensor(1.0, dtype=torch.float32))
layer.register_buffer("_prob_scale", torch.tensor(1.0, dtype=torch.float32))
else:
layer._k_scale.fill_(1.0)
layer._v_scale.fill_(1.0)
layer._q_scale.fill_(1.0)
layer._prob_scale.fill_(1.0)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
layer._q_scale_float = 1.0
layer._k_scale_float = 1.0
layer._v_scale_float = 1.0
layer._prob_scale_float = 1.0
# Initialize q/k/v range constants used by calc_kv_scales
layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
def _init_kv_cache_quant(
layer: nn.Module,
quant_config: QuantizationConfig | None,
prefix: str,
) -> None:
"""Initializes KV cache scaling factors and quantization method.
This helper function sets up the KV cache quantization attributes that are
shared between Attention and MLAAttention layers. It initializes scale
tensors for query, key, value, and probability, and configures the
quantization method if applicable.
Args:
layer: The attention layer instance to initialize.
quant_config: Optional quantization configuration.
prefix: Layer name prefix for quantization method lookup.
"""
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
# Note [Register q/k/v/prob scales in state dict]
# When calling model.to(device), only parameters/buffers in state dict are
# moved. If not registering q/k/v/prob scales in state dict, there would
# be an IMA error when a cuda kernel (e.g., quant_fp8) accesses the tensor
# on cpu.
# Registering in state dict means it interacts with weight loading. One edge
# case is when quant_method is None, or quant_method is UnquantizedLinearMethod
# (i.e., should_load_quant_weights(quant_method) == False).
# In this case, the checkpoint does not have the scales. We need to
# initialize the scales to 1.0 and update the scales after weight loading.
# This is espectially important when we load dummy weights first (providing
# wrong scales) and then load real weights (which misses scales and keeps the
# wrong scales from dummy load).
set_default_quant_scales(layer, register_buffer=True)
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
layer._o_scale_float = None
quant_method = (
quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
)
# See [Note: Register q/k/v/prob scales in state dict]
if should_load_quant_weights(quant_method):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if layer.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
layer.quant_method = quant_method
layer.quant_method.create_weights(layer)
class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
use_alibi_sqrt: bool | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
logits_soft_cap: float | None = None,
per_layer_sliding_window: int | None = None,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
head_size_v: int | None = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
super().__init__()
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None
vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
# llm-compressor mdls need to set cache_dtype to "fp8" manually.
kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
if kv_cache_scheme is not None:
kv_cache_dtype = "fp8"
calculate_kv_scales = False
if cache_config is not None:
cache_config.cache_dtype = "fp8"
cache_config.calculate_kv_scales = False
# Check if per-head quant scales are required based on kv_cache_scheme
use_per_head_quant_scales = (
kv_cache_scheme is not None
and kv_cache_scheme.get("strategy") == "attn_head"
)
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, (
f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
)
self.quant_config = quant_config
self.layer_name = prefix
self.num_heads = num_heads
self.head_size = head_size
self.head_size_v = self.head_size if head_size_v is None else head_size_v
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
# NOTE: model_config may be None during certain tests
model_config = vllm_config.model_config
self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
if attn_backend is None:
self.attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=False,
has_sink=self.has_sink,
use_mm_prefix=self.use_mm_prefix,
use_per_head_quant_scales=use_per_head_quant_scales,
attn_type=attn_type,
)
else:
self.attn_backend = attn_backend
backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
if use_alibi_sqrt and not backend_supports_alibi_sqrt:
raise ValueError(
f"use_alibi_sqrt is not supported by backend "
f"{self.attn_backend.get_name()}."
)
self.use_alibi_sqrt = bool(use_alibi_sqrt)
if backend_supports_alibi_sqrt:
extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
# prefix caching + batch invariance is currently not supported for
# FLASHINFER and TRITON_MLA.
if (
cache_config is not None
and cache_config.enable_prefix_caching
and vllm_is_batch_invariant()
and (
self.attn_backend.get_name() == "FLASHINFER"
or self.attn_backend.get_name() == "TRITON_MLA"
)
):
logger.warning_once(
"Disabling prefix caching for FLASHINFER/TRITON_MLA "
"with batch invariance, as it is not yet supported.",
scope="local",
)
cache_config.enable_prefix_caching = False
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
num_heads,
head_size,
scale,
num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
**extra_impl_args,
)
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize KV cache quantization attributes
_init_kv_cache_quant(self, quant_config, prefix)
# for attn backends supporting query quantization
self.query_quant = None
if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
"fp8"
):
is_per_head = (
hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
)
block_size = self.head_size * self.num_heads // self.num_kv_heads
self.query_quant = QuantFP8(
static=True,
group_shape=GroupShape(-1, block_size)
if is_per_head
else GroupShape.PER_TENSOR,
)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
Attention metadata (`attn_metadata`) is set using a context manager in
the model runner's `execute_model` method. It is accessed via forward
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
# which reduces overheads during decoding.
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
# check if query quantization is supported
if self.impl.supports_quant_query_input:
query, _ = self.query_quant(query, self._q_scale)
if self.use_output:
if output_shape is None:
# Handle both 2D [num_tokens, hidden] and
# 3D [num_tokens, heads, head_dim] query
num_tokens = query.shape[0]
output_shape = torch.Size(
(num_tokens, self.num_heads * self.head_size_v)
)
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size_v)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size_v)
kv_cache_dummy_dep = None
if self.use_direct_call:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = unified_kv_cache_update(
key, value, self.layer_name
)
unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
else:
# Skip this if sharing KV cache with an earlier attention layer.
if (
not self.attn_backend.forward_includes_kv_cache_update
and self.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
key, value, self.layer_name
)
torch.ops.vllm.unified_attention_with_output(
query,
key,
value,
output,
self.layer_name,
kv_cache_dummy_dep=kv_cache_dummy_dep,
)
return output.view(-1, hidden_size)
else:
assert self.attn_backend.forward_includes_kv_cache_update, (
"Split KV cache update not supported when output tensor not provided."
)
if self.use_direct_call:
return unified_attention(query, key, value, self.layer_name)
else:
return torch.ops.vllm.unified_attention(
query, key, value, self.layer_name
)
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
s += f", num_heads={self.impl.num_heads}" # type: ignore
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
s += f", scale={self.impl.scale}" # type: ignore
s += f", backend={self.impl.__class__.__name__}"
return s
def process_weights_after_loading(self, act_dtype: torch.dtype):
self.impl.process_weights_after_loading(act_dtype)
# If we should not load quant weights, we initialize the scales to 1.0
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
# for more details.
quant_method = (
self.quant_config.get_quant_method(self, prefix=self.layer_name)
if self.quant_config
else None
)
if not should_load_quant_weights(quant_method):
set_default_quant_scales(self, register_buffer=False)
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow"
)
return SlidingWindowSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
else:
return FullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size_v,
dtype=self.kv_cache_torch_dtype,
)
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=["query", "key", "value"],
fake_impl=maybe_calc_kv_scales_fake,
)
def get_attention_context(
layer_name: str,
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
"""Extract attention context for a given layer.
This helper function extracts the attention metadata, attention layer
instance, KV cache tensor, and slot mapping for a specific layer.
Args:
layer_name: The name/identifier of the attention layer.
Returns:
A tuple containing:
- attn_metadata: Attention metadata for this specific layer, or None if
no metadata available
- attn_layer: The attention layer instance (Attention or MLAAttention)
- kv_cache: The KV cache tensor for current virtual engine
- slot_mapping: The slot mapping for this specific layer
Note: attn_metadata may be None, but attn_layer and kv_cache are always
extracted from the forward context.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
slot_mapping = forward_context.slot_mapping
assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
layer_slot_mapping = slot_mapping.get(layer_name)
return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
@maybe_transfer_kv_layer
def unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)
return output
def unified_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty_like(query).contiguous()
direct_register_custom_op(
op_name="unified_attention",
op_func=unified_attention,
fake_impl=unified_attention_fake,
)
def unified_kv_cache_update(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
"""
Returns a dummy that is passed to unified_attention to signal a side effect and
the data dependency between them to ensure torch.compile preserves ordering.
"""
_, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
if layer_slot_mapping is not None:
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
)
attn_layer.impl.do_kv_cache_update(
attn_layer,
key,
value,
kv_cache,
layer_slot_mapping,
)
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def unified_kv_cache_update_fake(
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> torch.Tensor:
return torch.empty(0, device=key.device, dtype=key.dtype)
direct_register_custom_op(
op_name="unified_kv_cache_update",
op_func=unified_kv_cache_update,
fake_impl=unified_kv_cache_update_fake,
mutates_args=[],
)
@maybe_transfer_kv_layer
def unified_attention_with_output(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
# kv_cache_dummy_dep is not used but accepting it creates a data dependency
# that ensures torch.compile preserves ordering between KV cache update and
# attention forward.
del kv_cache_dummy_dep
attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
def unified_attention_with_output_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output: torch.Tensor,
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
return
direct_register_custom_op(
op_name="unified_attention_with_output",
op_func=unified_attention_with_output,
mutates_args=["output", "output_block_scale"],
fake_impl=unified_attention_with_output_fake,
)

View File

@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
AttentionSpec,
ChunkedLocalAttentionSpec,
KVCacheSpec,
)
@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
attention_chunk_size: int,
block_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"
underlying_builder = underlying_attn_backend.get_builder_cls()
assert issubclass(underlying_builder, AttentionMetadataBuilder)
class ChunkedLocalAttentionBuilder(underlying_builder): # type: ignore
@classmethod
def get_cudagraph_support(
cls: type["AttentionMetadataBuilder"],
vllm_config: VllmConfig,
kv_cache_spec: AttentionSpec,
) -> AttentionCGSupport:
# Explicit override in case the underlying builder specialized this getter.
# @override omitted only because of mypy limitation due to type variable.
return AttentionCGSupport.NEVER
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
):
cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
attention_chunk_size, common_attn_metadata, block_size
)
metadata = super().build(common_prefix_len, cm, fast_build)
metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
return metadata
def update_block_table(
self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
):
blk_table = metadata.make_virtual_batches_block_table(blk_table)
return super().update_block_table(metadata, blk_table, slot_mapping)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=ChunkedLocalAttentionBuilder,
)
return attn_backend
class ChunkedLocalAttention(Attention):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
attention_chunk_size: int,
num_kv_heads: int | None = None,
alibi_slopes: list[float] | None = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
kv_sharing_target_layer_name: str | None = None,
prefix: str = "",
):
self.attention_chunk_size = attention_chunk_size
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_chunked_local_attention_backend(
underlying_attn_backend, attention_chunk_size, block_size
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
assert self.attention_chunk_size
return ChunkedLocalAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
attention_chunk_size=self.attention_chunk_size,
)

View File

@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
import numpy as np
import torch
from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__)
def _get_cross_slot_mapping(
encoder_seq_lens: np.ndarray,
block_table_tensor: torch.Tensor,
kv_cache_spec: CrossAttentionSpec,
device: torch.device,
) -> torch.Tensor:
"""Get cross-attention slot mappings."""
block_size = kv_cache_spec.block_size
slot_mappings = []
# Find indices with non-zero encoder sequence lengths
# The majority of parallel requests will be running the
# decoder, so this list should be relatively small.
active_indices = np.nonzero(encoder_seq_lens)[0]
for req_index in active_indices:
encoder_seq_len = encoder_seq_lens[req_index].item()
# Calculate the number of blocks needed for this request
num_blocks_needed = cdiv(encoder_seq_len, block_size)
# Get the block IDs for this request from the tensor
req_block_ids = block_table_tensor[req_index]
# Get only the blocks we need (first num_blocks_needed blocks)
needed_block_ids = req_block_ids[:num_blocks_needed]
# All needed blocks are allocated
i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device)
block_indices = i_values // block_size
block_offsets = i_values % block_size
block_numbers = needed_block_ids[block_indices]
slot_mapping = block_numbers * block_size + block_offsets
slot_mappings.append(slot_mapping)
if slot_mappings:
return torch.cat(slot_mappings)
else:
return torch.empty(0, dtype=torch.int64, device=device)
@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
underlying_impl = underlying_attn_backend.get_impl_cls()
class CrossAttentionBuilder(underlying_builder): # type: ignore
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicated decode step>1 (no chunked prefill)
num_cache_decodes = (
(common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
)
if num_cache_decodes > 0:
# CrossAttn KV cache has already been populated on first decoder step,
# skip slot_mapping calculation for requests that do not need
# reshape_and_cache.
num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
new_metadata.encoder_seq_lens_cpu = np.where(
num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
)
# seq_lens is provided by model runner: initial encoder input length is
# needed here to know how many tokens to attend to from the cached
# cross-attention KV cache.
new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
new_metadata._seq_lens_cpu = torch.from_numpy(
common_attn_metadata.encoder_seq_lens_cpu
)
# NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
slot_mapping = _get_cross_slot_mapping(
new_metadata.encoder_seq_lens_cpu,
new_metadata.block_table_tensor,
self.kv_cache_spec,
self.device,
)
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata.slot_mapping = slot_mapping
return attn_metadata
# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
# `CrossAttentionBuilder` instead of the one computed by `BlockTable`
# (gpu_model_runner)
class CrossAttentionImpl(underlying_impl): # type: ignore[valid-type,misc]
def forward(
self,
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if (
not underlying_attn_backend.forward_includes_kv_cache_update
and attn_metadata is not None
and layer.kv_sharing_target_layer_name is None
and key is not None
and value is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
)
return super().forward(
layer,
query,
key,
value,
kv_cache,
attn_metadata,
output,
output_scale,
output_block_scale,
)
attn_backend = subclass_attention_backend_with_overrides(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
overrides={
"get_builder_cls": lambda: CrossAttentionBuilder,
"get_impl_cls": lambda: CrossAttentionImpl,
"forward_includes_kv_cache_update": True,
},
)
return attn_backend
class CrossAttention(Attention):
"""
Cross-attention for encoder-decoder models.
Handles attention between decoder queries and encoder keys/values.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
cache_config: CacheConfig | None = None,
attn_type: str | None = None,
**kwargs,
):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_DECODER, (
"CrossAttention only supports AttentionType.ENCODER_DECODER"
)
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_DECODER,
)
attn_backend = create_cross_attention_backend(underlying_attn_backend)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_DECODER,
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return CrossAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)

View File

@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy
import torch
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import KVCacheSpec
@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class EncoderOnlyAttentionBuilder(underlying_builder): # type: ignore
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
new_common_attn_metadata = copy(common_attn_metadata)
new_common_attn_metadata.causal = False
return super().build(
common_prefix_len, new_common_attn_metadata, fast_build
)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=EncoderOnlyAttentionBuilder,
)
return attn_backend
class EncoderOnlyAttention(Attention):
"""
Encoder attention is a special case that doesn't need a KV Cache.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
cache_config: CacheConfig | None = None,
attn_type: str | None = None,
**kwargs,
):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
underlying_attn_backend = get_attn_backend(
head_size,
dtype,
kv_cache_dtype,
block_size,
attn_type=AttentionType.ENCODER_ONLY,
)
attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)
if attn_type is not None:
assert attn_type == AttentionType.ENCODER_ONLY, (
"EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
)
super().__init__(
num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
attn_type=AttentionType.ENCODER_ONLY,
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Does not need KV cache
return None

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group,
)
def maybe_transfer_kv_layer(func: Callable) -> Callable:
"""Decorator that handles KV layer transfer prior and after execution of
an attention layer, if enabled. Otherwise, the wrapper is a no-op.
On entry: waits for the KV layer from the connector.
On exit: saves the KV layer to the connector.
"""
# Import at runtime to avoid circular dependency
from vllm.model_executor.layers.attention.attention import get_attention_context
# Inspect the signature ONCE when the decorator is applied.
sig = inspect.signature(func)
param_names = list(sig.parameters.keys())
# Find the index of 'layer_name' parameter.
try:
layer_name_index = param_names.index("layer_name")
except ValueError as e:
raise TypeError(
f"Function {func.__name__} must have a 'layer_name' parameter"
) from e
@wraps(func)
def wrapper(*args, **kwargs):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
return func(*args, **kwargs)
layer_name: str = args[layer_name_index]
# Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
connector = get_kv_transfer_group()
if attn_metadata is None or not connector.has_connector_metadata():
return func(*args, **kwargs)
# Wait for KV layer on entry
connector.wait_for_layer_load(layer_name)
# Execute the function
result = func(*args, **kwargs)
# Save KV cache layer on exit
connector.save_kv_layer(layer_name, kv_cache, attn_metadata)
return result
return wrapper

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
vit_flash_attn_wrapper,
vit_torch_sdpa_wrapper,
vit_triton_attn_wrapper,
)
logger = init_logger(__name__)
# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
"""Multi-headed attention without any cache, used for multimodal encoder."""
# --8<-- [end:mm_encoder_attn]
def __init__(
self,
num_heads: int,
head_size: int,
scale: float | None = None,
num_kv_heads: int | None = None,
prefix: str = "",
) -> None:
"""
Args:
num_heads: number of attention heads per partition.
head_size: hidden_size per attention head.
scale: scale factor.
num_kv_heads: number of kv heads.
prefix: This has no effect, it is only here to make it easier to
swap between Attention and MultiHeadAttention
"""
super().__init__()
self.num_heads = num_heads
self.head_size = head_size
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.layer_name = prefix
assert self.num_heads % self.num_kv_heads == 0, (
f"num_heads ({self.num_heads}) is not "
f"divisible by num_kv_heads ({self.num_kv_heads})"
)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
# Get device-specific vision attention backend.
self.attn_backend = get_vit_attn_backend(
head_size=head_size,
dtype=dtype,
)
self.is_flash_attn_backend = self.attn_backend in {
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.ROCM_AITER_FA,
}
self._fa_version = (
get_flash_attn_version() if self.is_flash_attn_backend else None
)
logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")
@classmethod
def enabled(cls) -> bool:
return True
def view_qkv_to_4d(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
bsz: int,
q_len: int,
kv_len: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Reshape query, key, value to 4D tensors:
(batch_size, seq_len, num_heads, head_size)
"""
query = query.view(bsz, q_len, self.num_heads, self.head_size)
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
return query, key, value
def _forward_sdpa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
output = vit_torch_sdpa_wrapper(
q=query,
k=key,
v=value,
scale=self.scale,
cu_seqlens=cu_seqlens,
enable_gqa=self.num_heads > self.num_kv_heads,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output
def _forward_fa(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
output = vit_flash_attn_wrapper(
q=query,
k=key,
v=value,
batch_size=bsz,
is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
fa_version=self._fa_version,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output
def _forward_triton(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
"""Input shape:
(batch_size x seq_len x hidden_size) or
(batch_size x seq_len x num_heads x head_size)
"""
assert (cu_seqlens is not None and max_seqlen is not None) or (
cu_seqlens is None and max_seqlen is None
), "cu_seqlens and max_seqlen should be both set or both None."
bsz, q_len = query.size()[:2]
kv_len = key.size(1)
is_reshaped = query.dim() != 4
query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
output = vit_triton_attn_wrapper(
q=query,
k=key,
v=value,
batch_size=bsz,
scale=self.scale,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
if is_reshaped:
output = output.reshape(bsz, q_len, -1)
return output
def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.is_flash_attn_backend:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for CUDA: "
f"{self.attn_backend}."
)
def forward_cpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
return self._forward_sdpa(query, key, value, cu_seqlens)
def forward_xpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
return self._forward_sdpa(query, key, value, cu_seqlens)
else:
raise ValueError(
f"Unsupported multi-modal encoder attention backend for XPU: "
f"{self.attn_backend}."
)

View File

@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import torch
from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
triton_reshape_and_cache_flash_diffkv,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
AttentionSpec,
KVCacheSpec,
SinkFullAttentionSpec,
)
logger = init_logger(__name__)
@functools.lru_cache
def create_static_sink_attention_backend(
underlying_attn_backend: type[AttentionBackend],
sink_len: int = 0,
) -> type[AttentionBackend]:
prefix = "StaticSink_"
underlying_builder = underlying_attn_backend.get_builder_cls()
class StaticSinkAttentionBuilder(underlying_builder): # type: ignore
def __init__(
self,
kv_cache_spec: AttentionSpec,
layer_names: list[str],
vllm_config: VllmConfig,
device: torch.device,
):
super().__init__(kv_cache_spec, layer_names, vllm_config, device)
model_config = vllm_config.model_config
scheduler_config = vllm_config.scheduler_config
self.sink_len = sink_len
self.block_size = vllm_config.cache_config.block_size
self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
self.max_num_blocks = cdiv(
model_config.max_model_len, vllm_config.cache_config.block_size
)
self.block_table_with_sink = torch.zeros(
(
scheduler_config.max_num_seqs,
self.max_num_blocks + self.num_sink_blocks,
),
device=device,
dtype=torch.int32,
)
self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
1,
self.num_sink_blocks + 1,
device=device,
dtype=torch.int32,
)
def build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> AttentionMetadata:
common_attn_metadata.seq_lens[:] = (
common_attn_metadata.seq_lens + self.sink_len
)
common_attn_metadata.seq_lens[
common_attn_metadata.seq_lens == self.sink_len
] = 0
common_attn_metadata.max_seq_len = (
common_attn_metadata.max_seq_len + self.sink_len
)
max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
num_reqs = common_attn_metadata.num_reqs
self.block_table_with_sink[
:num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
common_attn_metadata.block_table_tensor = self.block_table_with_sink[
:num_reqs
]
return super().build(common_prefix_len, common_attn_metadata, fast_build)
attn_backend = subclass_attention_backend(
name_prefix=prefix,
attention_backend_cls=underlying_attn_backend,
builder_cls=StaticSinkAttentionBuilder,
)
return attn_backend
@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
"""
Attention with static sink tokens
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
sink_len: int,
attn_backend: type[AttentionBackend] | None = None,
cache_config: CacheConfig | None = None,
**kwargs,
):
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
else:
kv_cache_dtype = "auto"
block_size = 16
if attn_backend is not None:
underlying_attn_backend = attn_backend
else:
underlying_attn_backend = get_attn_backend(
head_size, dtype, kv_cache_dtype, block_size
)
attn_backend = create_static_sink_attention_backend(
underlying_attn_backend, # type: ignore[arg-type]
sink_len=sink_len,
)
Attention.__init__(
self=self,
num_heads=num_heads,
head_size=head_size,
scale=scale,
cache_config=cache_config,
attn_backend=attn_backend,
**kwargs,
)
CustomOp.__init__(self)
self.sink_len = sink_len
self.block_size = block_size
self.sink_populated = False
self.sink_key = None
self.sink_value = None
def update_sink_kv(self, sink_key, sink_value) -> None:
self.sink_key = sink_key
self.sink_value = sink_value
def forward_native(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
assert self.sink_key is not None and self.sink_value is not None, (
"sink_key and sink_value have not been prepared"
)
if not self.sink_populated:
forward_context: ForwardContext = get_forward_context()
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)
return super().forward(query, key, value, output_shape)
def forward_cuda(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
output_shape: torch.Size | None = None,
) -> torch.Tensor:
return self.forward_native(query, key, value, output_shape)
def forward(self, *args, **kwargs):
return self._forward_method(*args, **kwargs)
def populate_sink_kv(self, self_kv_cache):
sink_kv_slot_mapping = torch.arange(
self.block_size,
self.sink_len + self.block_size,
device=torch.cuda.current_device(),
dtype=torch.long,
)
triton_reshape_and_cache_flash_diffkv(
self.sink_key,
self.sink_value,
self_kv_cache,
sink_kv_slot_mapping,
self.kv_cache_dtype,
self._k_scale,
self._v_scale,
)
# We only populate the sink_key and sink_value once
self.sink_populated = True
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
return SinkFullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
head_size_v=self.head_size_v,
sink_len=self.sink_len,
dtype=self.kv_cache_torch_dtype,
)
def maybe_populate_sink(
self_kv_cache: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if self.sink_populated or self_kv_cache.numel() == 0:
return
self.populate_sink_kv(self_kv_cache)
def maybe_populate_sink_fake(
self_kv_cache: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_populate_sink",
op_func=maybe_populate_sink,
mutates_args=["self_kv_cache"],
fake_impl=maybe_populate_sink_fake,
)

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend, AttentionImpl
from vllm.v1.kv_cache_interface import KVCacheSpec
class AttentionLayerBase(ABC):
"""
Base class for attention-like layers (Attention, Mamba, etc.)
that support the v1 engine.
This provides a common interface for getting attention backends
from different layer types.
"""
impl: "AttentionImpl"
@abstractmethod
def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer."""
pass
@abstractmethod
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
"""
Get the KV cache spec for this layer.
May be None if the layer does not need KV cache.
"""
pass

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Conv Layer Class."""
import math
from typing import Literal
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import is_torch_equal
class ConvLayerBase(CustomOp):
"""Conv layer base class."""
num_dim: int
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, ...],
stride: int | tuple[int, ...] = 1,
padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
dilation: int | tuple[int, ...] = 1,
groups: int = 1,
bias: bool = True,
padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
valid_padding_strings = {"same", "valid"}
if isinstance(padding, str) and padding not in valid_padding_strings:
raise ValueError(
f"Invalid padding string '{padding}'. "
f"Expected one of {valid_padding_strings}."
)
if padding == "same":
padding = (
kernel_size // 2
if isinstance(kernel_size, int)
else tuple(k // 2 for k in kernel_size)
)
elif padding == "valid":
padding = 0
kernel_size = (
(kernel_size,) * self.num_dim
if isinstance(kernel_size, int)
else kernel_size
)
stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation
if padding == "same" and any(s != 1 for s in stride):
raise ValueError("padding='same' is not supported for strided convolutions")
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.padding_mode = padding_mode
self.enable_linear = (
(self.kernel_size == self.stride)
and not any(self.padding)
and self.groups == 1
)
self.input_size = in_channels * math.prod(self.kernel_size)
self.weight = nn.Parameter(
torch.empty(
out_channels,
in_channels // groups,
*kernel_size,
dtype=params_dtype,
),
)
if bias:
self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
else:
self.register_parameter("bias", None)
def extra_repr(self) -> str:
s = f"in_channels={self.in_channels}, "
s += f"out_channels={self.out_channels}, "
s += f"kernel_size={self.kernel_size}, "
s += f"stride={self.stride}, "
s += f"padding={self.padding}, "
s += f"bias={self.bias is not None}"
return s
# --8<-- [start:conv2d]
@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
"""Conv layer with Conv2d."""
# --8<-- [end:conv2d]
num_dim = 2
def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 4
B, C, H, W = x.shape
K1, K2 = self.kernel_size
H, W = H // K1, W // K2
x = x.unfold(2, K1, K1).unfold(3, K2, K2)
x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
x = F.linear(
x,
self.weight.view(self.out_channels, self.input_size),
self.bias,
)
x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
return x
def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 4
x = F.conv2d(
x,
self.weight,
self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return x
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""Expected input shape: (batch_size, in_channels, height, width)"""
assert x.dim() == 4
if self.enable_linear:
return self._forward_mulmat(x)
else:
return self._forward_conv(x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# By default, we use CUDNN's convolution ops with optimization.
return self._forward_conv(x)
class CausalConv2dLayer(Conv2dLayer):
"""
A causal version of nn.Conv2d where each location in the 2D matrix would
have no access to locations on its right or down
All arguments are the same as nn.Conv2d except padding which should be
set as None
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
padding: int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
*,
params_dtype: torch.dtype | None = None,
) -> None:
if padding is not None:
raise ValueError(
"Argument padding should be set to None for CausalConv2dLayer."
)
self._left_padding: int = kernel_size - 1
self._right_padding: int = stride - 1
padding = 0
super().__init__(
in_channels,
out_channels,
kernel_size,
stride,
padding,
dilation,
groups,
bias,
padding_mode,
params_dtype=params_dtype,
)
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
x = super().forward(x)
return x
# --8<-- [start:conv3d]
@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
"""Conv layer with Conv3d."""
# --8<-- [end:conv3d]
num_dim = 3
def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 5
B, C, T, H, W = x.shape
K1, K2, K3 = self.kernel_size
T, H, W = T // K1, H // K2, W // K3
x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)
x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)
x = F.linear(
x,
self.weight.view(self.out_channels, self.input_size),
self.bias,
)
x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3)
return x
def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 5
x = F.conv3d(
x,
self.weight,
self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return x
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""Expected input shape: (batch_size, in_channels, time, height, width)"""
if self.enable_linear:
return self._forward_mulmat(x)
else:
return self._forward_conv(x)
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch2.9.0 disabled CUDNN's Conv3D, which caused a
# significant performance regression.
# See: https://github.com/vllm-project/vllm/issues/27406
# and https://github.com/pytorch/pytorch/issues/166122
# By default, we use CUDNN's convolution ops with optimization.
if self.enable_linear and (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
return self._forward_mulmat(x)
return self._forward_conv(x)

View File

@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

View File

@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .layernorm_guard import RMSNormGated
__all__ = [
"RMSNormGated",
"chunk_gated_delta_rule",
"fused_recurrent_gated_delta_rule",
]

View File

@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
import torch
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
def chunk_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
beta=beta,
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
w=w,
u=u,
g=g,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
o = chunk_fwd_o(
q=q,
k=k,
v=v_new,
h=h,
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
elif SUPPRESS_LEVEL >= 3:
return g, o, A, final_state, w, h, v_new
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
@staticmethod
@input_guard
@torch.amp.custom_fwd(device_type="cuda")
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.LongTensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
q = l2norm_fwd(q)
k = l2norm_fwd(k)
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=scale,
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
return o.to(q.dtype), final_state
@torch.compiler.disable
def chunk_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float = None,
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.LongTensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
Args:
q (torch.Tensor):
Queries of shape `[B, T, H, K]`.
k (torch.Tensor):
Keys of shape `[B, T, H, K]`.
v (torch.Tensor):
Values of shape `[B, T, H, V]`.
g (torch.Tensor):
(forget) Gating tensor (in log space!) of shape `[B, T, H]`.
beta (torch.Tensor):
Betas of shape `[B, T, H]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, V, K]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, H, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,
output_final_state=True,
cu_seqlens=cu_seqlens
)
"""
assert q.dtype == k.dtype == v.dtype
assert q.dtype != torch.float32, (
"ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
)
assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
if q.shape[1] < q.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
if q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
raise ValueError(
f"The number of initial states is expected to be equal to the number of input sequences, "
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
)
if scale is None:
scale = k.shape[-1] ** -0.5
o, final_state = ChunkGatedDeltaRuleFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
output_final_state,
cu_seqlens,
use_qk_l2norm_in_kernel,
)
return o, final_state

View File

@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp
from .utils import use_cuda_graph
NUM_WARPS = [2, 4, 8, 16]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"USE_GK": lambda args: args["gk"] is not None,
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"STORE_FINAL_STATE": lambda args: args["ht"] is not None,
"SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4]
for num_stages in [2, 3, 4]
for BV in [32, 64]
],
key=["H", "K", "V", "BT"],
use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
k,
v,
w,
v_new,
g,
gk,
h,
h0,
ht,
cu_seqlens,
chunk_offsets,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
USE_GK: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr,
STORE_FINAL_STATE: tl.constexpr,
SAVE_NEW_VALUE: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_h = i_nh // H, i_nh % H
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
NT = tl.cdiv(T, BT)
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
else:
bos, eos = i_n * T, i_n * T + T
NT = tl.cdiv(T, BT)
boh = i_n * NT
# [BV, BK]
b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 64:
b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 128:
b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 192:
b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
# calculate offset
h += ((boh * H + i_h) * V * K).to(tl.int64)
v += ((bos * H + i_h) * V).to(tl.int64)
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE:
v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V
stride_h = H * V * K
stride_k = Hg * K
stride_w = H * K
if USE_INITIAL_STATE:
h0 = h0 + i_nh * V * K
if STORE_FINAL_STATE:
ht = ht + i_nh * V * K
# load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(
h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
)
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(
h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
)
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(
h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
)
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# main recurrence
for i_t in range(NT):
p_h1 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
)
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
)
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
)
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v
if SAVE_NEW_VALUE:
p_v = tl.make_block_ptr(
v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))
last_idx = min((i_t + 1) * BT, T) - 1
if USE_G:
m_t = (i_t * BT + tl.arange(0, BT)) < T
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
p_g = tl.make_block_ptr(
g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_g = tl.load(p_g, boundary_check=(0,))
b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
b_g_last = exp(b_g_last)
b_h1 *= b_g_last
if K > 64:
b_h2 *= b_g_last
if K > 128:
b_h3 *= b_g_last
if K > 192:
b_h4 *= b_g_last
if USE_GK:
o_k1 = tl.arange(0, 64)
b_gk_last1 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k1,
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[None, :]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k2,
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[None, :]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k3,
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[None, :]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
gk + (bos + last_idx) * H * K + i_h * K + o_k4,
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[None, :]
b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.trans(tl.dot(b_k, b_v))
if K > 64:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.trans(tl.dot(b_k, b_v))
if K > 128:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.trans(tl.dot(b_k, b_v))
if K > 192:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.trans(tl.dot(b_k, b_v))
# epilogue
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(
ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
)
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(
ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
)
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(
ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
)
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
def chunk_gated_delta_rule_fwd_h(
k: torch.Tensor,
w: torch.Tensor,
u: torch.Tensor,
g: torch.Tensor | None = None,
gk: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
save_new_value: bool = True,
cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
B, T, Hg, K, V = *k.shape, u.shape[-1]
H = u.shape[-2]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = (
len(cu_seqlens) - 1,
len(chunk_indices),
prepare_chunk_offsets(cu_seqlens, BT),
)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, V, K)
final_state = (
k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
)
v_new = torch.empty_like(u) if save_new_value else None
def grid(meta):
return (triton.cdiv(V, meta["BV"]), N * H)
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
k=k,
v=u,
w=w,
v_new=v_new,
g=g,
gk=gk,
h=h,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
chunk_offsets=chunk_offsets,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
)
return h, v_new, final_state

View File

@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
for BK in BKV_LIST
for BV in BKV_LIST
for num_warps in NUM_WARPS
for num_stages in [2, 3, 4]
],
key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
USE_G: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_tg = i_t
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
NT = tl.cdiv(T, BT)
else:
NT = tl.cdiv(T, BT)
i_tg = i_b * NT + i_t
bos, eos = i_b * T, i_b * T + T
# offset calculation
q += (bos * Hg + i_h // (H // Hg)) * K
k += (bos * Hg + i_h // (H // Hg)) * K
v += (bos * H + i_h) * V
o += (bos * H + i_h) * V
h += (i_tg * H + i_h).to(tl.int64) * V * K
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_q = tl.make_block_ptr(
q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
)
p_k = tl.make_block_ptr(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(
h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, tl.trans(b_h))
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k)
if USE_G:
g += bos * H + i_h
p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_o = b_o * exp(b_g)[:, None]
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_v = tl.make_block_ptr(
v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
p_o = tl.make_block_ptr(
o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
b_v = tl.load(p_v, boundary_check=(0, 1))
# to fix mma -> mma layout conversion
# already solved by triton v3.2 or higher
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_fwd_o(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
h: torch.Tensor,
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
if scale is None:
scale = k.shape[-1] ** -0.5
o = torch.empty_like(v)
def grid(meta):
return (triton.cdiv(V, meta["BV"]), NT, B * H)
chunk_fwd_kernel_o[grid](
q,
k,
v,
h,
g,
o,
cu_seqlens,
chunk_indices,
scale,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
)
return o

View File

@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
@triton.heuristics(
{
"USE_G": lambda args: args["g"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
}
)
@triton.autotune(
configs=[
triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
for BK in [32, 64, 128]
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["H", "K", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
k,
beta,
g,
A,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
IS_VARLEN: tl.constexpr,
USE_G: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_t = i_t * BT + tl.arange(0, BT)
m_t = o_t < T
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.zeros([BT, BT], dtype=tl.float32)
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * Hg + i_h // (H // Hg)) * K,
(T, K),
(Hg * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = b_k * b_beta[:, None]
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
if USE_G:
p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
b_g = tl.load(p_g, boundary_check=(0,))
b_g_diff = b_g[:, None] - b_g[None, :]
b_A = b_A * exp(b_g_diff)
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
b_A = tl.where(m_A, b_A, 0)
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
def chunk_scaled_dot_kkt_fwd(
k: torch.Tensor,
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
cu_seqlens: torch.LongTensor | None = None,
chunk_size: int = 64,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
Compute beta * K * K^T.
Args:
k (torch.Tensor):
The key tensor of shape `[B, T, H, K]`.
beta (torch.Tensor):
The beta tensor of shape `[B, T, H]`.
g (torch.Tensor):
The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
cu_seqlens (torch.LongTensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float32`
Returns:
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
"""
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
B, T, Hg, K = k.shape
H = beta.shape[-1]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
k=k,
g=g,
beta=beta,
A=A,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
BT=BT,
)
return A

View File

@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .utils import check_shared_mem, input_guard
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
s,
o,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
BT: tl.constexpr,
REVERSE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
if HEAD_FIRST:
p_s = tl.make_block_ptr(
s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
)
p_o = tl.make_block_ptr(
o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
)
else:
p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
# [BT]
b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
b_o = tl.cumsum(b_s, axis=0)
if REVERSE:
b_z = tl.sum(b_s, axis=0)
b_o = -b_o + b_z[None] + b_s
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({"BS": BS}, num_warps=num_warps)
for BS in BS_LIST
for num_warps in [2, 4, 8]
],
key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
s,
o,
cu_seqlens,
chunk_indices,
T,
B: tl.constexpr,
H: tl.constexpr,
S: tl.constexpr,
BT: tl.constexpr,
BS: tl.constexpr,
REVERSE: tl.constexpr,
IS_VARLEN: tl.constexpr,
HEAD_FIRST: tl.constexpr,
):
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, BT)
if REVERSE:
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
else:
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)
if HEAD_FIRST:
p_s = tl.make_block_ptr(
s + (bos * H + i_h * T) * S,
(T, S),
(S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
p_o = tl.make_block_ptr(
o + (bos * H + i_h * T) * S,
(T, S),
(S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
else:
p_s = tl.make_block_ptr(
s + (bos * H + i_h) * S,
(T, S),
(H * S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
p_o = tl.make_block_ptr(
o + (bos * H + i_h) * S,
(T, S),
(H * S, 1),
(i_t * BT, i_s * BS),
(BT, BS),
(1, 0),
)
# [BT, BS]
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
b_o = tl.dot(m_s, b_s, allow_tf32=False)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
def chunk_local_cumsum_scalar(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T = g.shape
else:
B, T, H = g.shape
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H)
chunk_local_cumsum_scalar_kernel[grid](
g_org,
g,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
)
return g
def chunk_local_cumsum_vector(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
if head_first:
B, H, T, S = g.shape
else:
B, T, H, S = g.shape
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
def grid(meta):
return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)
# keep cumulative normalizer in fp32
# this kernel is equivalent to
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
chunk_local_cumsum_vector_kernel[grid](
g_org,
g,
cu_seqlens,
chunk_indices,
T=T,
B=B,
H=H,
S=S,
BT=BT,
HEAD_FIRST=head_first,
REVERSE=reverse,
)
return g
@input_guard
def chunk_local_cumsum(
g: torch.Tensor,
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
**kwargs,
) -> torch.Tensor:
if not head_first and g.shape[1] < g.shape[2]:
warnings.warn(
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
"when head_first=False was specified. "
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
stacklevel=2,
)
if cu_seqlens is not None:
assert g.shape[0] == 1, (
"Only batch size 1 is supported when cu_seqlens are provided"
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
)
elif len(g.shape) == 4:
return chunk_local_cumsum_vector(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
)
else:
raise ValueError(
f"Unsupported input shape {g.shape}. "
f"which should be (B, T, H, D) if `head_first=False` "
f"or (B, H, T, D) otherwise"
)

View File

@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .op import exp
@triton.heuristics(
{
"USE_INITIAL_STATE": lambda args: args["h0"] is not None,
"IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
"IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
}
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
q,
k,
v,
g,
beta,
o,
h0,
ht,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
scale,
N: tl.int64, # num of sequences
T: tl.int64, # num of tokens
B: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
stride_indices_tok: tl.constexpr,
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
IS_VARLEN: tl.constexpr,
IS_CONTINUOUS_BATCHING: tl.constexpr,
IS_SPEC_DECODING: tl.constexpr,
IS_KDA: tl.constexpr,
):
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
if IS_VARLEN:
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int64),
tl.load(cu_seqlens + i_n + 1).to(tl.int64),
)
all = T
T = eos - bos
else:
bos, eos = i_n * T, i_n * T + T
all = B * T
if T == 0:
# no tokens to process for this sequence
return
o_k = i_k * BK + tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
p_q = q + (bos * H + i_h) * K + o_k
p_k = k + (bos * H + i_h) * K + o_k
p_v = v + (bos * HV + i_hv) * V + o_v
if IS_BETA_HEADWISE:
p_beta = beta + (bos * HV + i_hv) * V + o_v
else:
p_beta = beta + bos * HV + i_hv
if not IS_KDA:
p_g = g + bos * HV + i_hv
else:
p_gk = g + (bos * HV + i_hv) * K + o_k
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_v[:, None] & mask_k[None, :]
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING:
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
else:
i_t = 0
# Load state index and check for PAD_SLOT_ID (-1)
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
tl.int64
)
# Skip if state index is invalid (PAD_SLOT_ID = -1)
if state_idx < 0:
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
p_h0 = h0 + bos * HV * V * K
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T):
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BV, BK]
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[None, :])
# [BV]
b_v -= tl.sum(b_h * b_k[None, :], 1)
if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else:
b_beta = tl.load(p_beta).to(tl.float32)
b_v *= b_beta
# [BV, BK]
b_h += b_v[:, None] * b_k[None, :]
# [BV]
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens
if INPLACE_FINAL_STATE:
# Load state index and check for PAD_SLOT_ID (-1)
final_state_idx = tl.load(
ssm_state_indices + i_n * stride_indices_seq + i_t
).to(tl.int64)
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
p_q += H * K
p_k += H * K
p_o += HV * V
p_v += HV * V
if not IS_KDA:
p_g += HV
else:
p_gk += HV * K
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
def fused_recurrent_gated_delta_rule_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, H, K, V = *k.shape, v.shape[-1]
HV = v.shape[2]
N = B if cu_seqlens is None else len(cu_seqlens) - 1
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
assert NK == 1, "NK > 1 is not supported yet"
num_stages = 3
num_warps = 1
o = q.new_empty(NK, *v.shape)
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
if ssm_state_indices is None:
stride_indices_seq, stride_indices_tok = 1, 1
elif ssm_state_indices.ndim == 1:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
else:
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
grid = (NK, NV, N * HV)
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
q=q,
k=k,
v=v,
g=g,
beta=beta,
o=o,
h0=initial_state,
ht=final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
scale=scale,
N=N,
T=T,
B=B,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
stride_indices_tok=stride_indices_tok,
IS_BETA_HEADWISE=beta.ndim == v.ndim,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
INPLACE_FINAL_STATE=inplace_final_state,
IS_KDA=False,
num_warps=num_warps,
num_stages=num_stages,
)
o = o.squeeze(0)
return o, final_state
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
o, final_state = fused_recurrent_gated_delta_rule_fwd(
q=q.contiguous(),
k=k.contiguous(),
v=v.contiguous(),
g=g.contiguous(),
beta=beta.contiguous(),
scale=scale,
initial_state=initial_state,
inplace_final_state=inplace_final_state,
cu_seqlens=cu_seqlens,
ssm_state_indices=ssm_state_indices,
num_accepted_tokens=num_accepted_tokens,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
return o, final_state
def fused_recurrent_gated_delta_rule(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
g: torch.Tensor,
beta: torch.Tensor = None,
scale: float = None,
initial_state: torch.Tensor = None,
inplace_final_state: bool = True,
cu_seqlens: torch.LongTensor | None = None,
ssm_state_indices: torch.Tensor | None = None,
num_accepted_tokens: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
Args:
q (torch.Tensor):
queries of shape `[B, T, H, K]`.
k (torch.Tensor):
keys of shape `[B, T, H, K]`.
v (torch.Tensor):
values of shape `[B, T, HV, V]`.
GVA is applied if `HV > H`.
g (torch.Tensor):
g (decays) of shape `[B, T, HV]`.
beta (torch.Tensor):
betas of shape `[B, T, HV]`.
scale (Optional[int]):
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, V, K]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
inplace_final_state: bool:
Whether to store the final state in-place to save memory.
Default: `True`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
ssm_state_indices (Optional[torch.Tensor]):
Indices to map the input sequences to the initial/final states.
num_accepted_tokens (Optional[torch.Tensor]):
Number of accepted tokens for each sequence during decoding.
Returns:
o (torch.Tensor):
Outputs of shape `[B, T, HV, V]`.
final_state (torch.Tensor):
Final state of shape `[N, HV, V, K]`.
Examples::
>>> import torch
>>> import torch.nn.functional as F
>>> from einops import rearrange
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
# inputs with equal lengths
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
>>> q = torch.randn(B, T, H, K, device='cuda')
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, V, K, device='cuda')
>>> o, ht = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
)
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
q, k, v, g, beta,
initial_state=h0,
cu_seqlens=cu_seqlens
)
"""
if cu_seqlens is not None and q.shape[0] != 1:
raise ValueError(
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
f"Please flatten variable-length inputs before processing."
)
if scale is None:
scale = k.shape[-1] ** -0.5
else:
assert scale > 0, "scale must be positive"
if beta is None:
beta = torch.ones_like(q[..., 0])
o, final_state = FusedRecurrentFunction.apply(
q,
k,
v,
g,
beta,
scale,
initial_state,
inplace_final_state,
cu_seqlens,
ssm_state_indices,
num_accepted_tokens,
use_qk_l2norm_in_kernel,
)
return o, final_state

View File

@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import triton
from .utils import tensor_cache
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
return cu_seqlens[1:] - cu_seqlens[:-1]
@tensor_cache
def prepare_chunk_indices(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
indices = torch.cat(
[
torch.arange(n)
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
]
)
return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
@tensor_cache
def prepare_chunk_offsets(
cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
return torch.cat(
[cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
).cumsum(-1)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os
import torch
from vllm.triton_utils import tl, triton
BT_LIST = [8, 16, 32, 64, 128]
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
],
key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
x,
y,
D,
BD: tl.constexpr,
eps,
):
i_t = tl.program_id(0)
x += i_t * D
y += i_t * D
# Compute mean and variance
cols = tl.arange(0, BD)
mask = cols < D
b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
b_var = tl.sum(b_x * b_x, axis=0)
b_rstd = 1 / tl.sqrt(b_var + eps)
# tl.store(Rstd + i_t, rstd)
# Normalize and apply linear transformation
b_y = b_x * b_rstd
tl.store(y + cols, b_y, mask=mask)
@triton.autotune(
configs=[
triton.Config({"BT": BT}, num_warps=num_warps)
for num_warps in [1, 2, 4, 8, 16]
for BT in BT_LIST
],
key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
x,
y,
eps,
NB,
T,
D: tl.constexpr,
BT: tl.constexpr,
BD: tl.constexpr,
):
i_t = tl.program_id(0)
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
b_var = tl.sum(b_x * b_x, axis=1)
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
xoffset = tl.program_id(0) * MBLOCK
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
xmask = row_idx < M
rindex = tl.arange(0, N)[None, :]
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
rsqrt = tl.rsqrt(square_sum + eps)
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
def l2norm_fwd(
x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None
):
x_shape_og = x.shape
x = x.view(-1, x.shape[-1])
# allocate output
if output_dtype is None:
y = torch.empty_like(x)
else:
y = torch.empty_like(x, dtype=output_dtype)
assert y.stride(-1) == 1
T, D = x.shape[0], x.shape[-1]
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
if D > BD:
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
if not USE_DEFAULT_FLA_NORM:
MBLOCK = 32
# M, N = x.shape
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)](
x,
y,
eps,
T,
D,
MBLOCK,
)
else:
if D <= 512:
NB = triton.cdiv(T, 2048)
def grid(meta):
return (triton.cdiv(T, meta["BT"]),)
l2norm_fwd_kernel[grid](
x,
y,
eps,
NB=NB,
T=T,
D=D,
BD=BD,
)
else:
l2norm_fwd_kernel1[(T,)](
x,
y,
eps=eps,
D=D,
BD=BD,
)
return y.view(x_shape_og)

View File

@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Tri Dao
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2024, Tri Dao.
# ruff: noqa: E501
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, next_power_of_2
from vllm.utils.platform_utils import num_compute_units
from .utils import input_guard
def rms_norm_ref(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
upcast=True,
):
dtype = x.dtype
weight = weight.float()
bias = bias.float() if bias is not None else None
if upcast:
x = x.float()
z = z.float() if z is not None else z
if z is not None and not norm_before_gate:
x = x * F.silu(z)
if group_size is None:
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
else:
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
if bias is not None:
out = out + bias
if z is not None and norm_before_gate:
out *= F.silu(z)
return out.to(dtype)
@triton.heuristics(
{
"HAS_BIAS": lambda args: args["B"] is not None,
"HAS_Z": lambda args: args["Z"] is not None,
}
)
@triton.jit
def layer_norm_fwd_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the other branch
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_z_row,
M, # number of rows in X
N: tl.constexpr, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
ROWS_PER_BLOCK: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
NORM_BEFORE_GATE: tl.constexpr,
IS_RMS_NORM: tl.constexpr,
):
# Map the program id to the starting row of X and Y it should compute.
row_start = tl.program_id(0) * ROWS_PER_BLOCK
group = tl.program_id(1)
# Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N]
rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
cols = tl.arange(0, BLOCK_N)
# Compute offsets for 2D tile
row_offsets = rows[:, None] * stride_x_row
col_offsets = cols[None, :] + group * N
# Base pointers
X_base = X + row_offsets + col_offsets
Y_base = Y + rows[:, None] * stride_y_row + col_offsets
# Create mask for valid rows and columns
row_mask = rows[:, None] < M
col_mask = cols[None, :] < N
mask = row_mask & col_mask
# Load input data with 2D tile
x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)
if HAS_Z and not NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
x *= z * tl.sigmoid(z)
# Compute mean and variance per row (reduce along axis 1)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=1) / N # Shape: [ROWS_PER_BLOCK]
# Store mean for each row
mean_offsets = group * M + rows
mean_mask = rows < M
tl.store(Mean + mean_offsets, mean, mask=mean_mask)
# Broadcast mean back to 2D for subtraction
xbar = tl.where(mask, x - mean[:, None], 0.0)
var = tl.sum(xbar * xbar, axis=1) / N # Shape: [ROWS_PER_BLOCK]
else:
xbar = tl.where(mask, x, 0.0)
var = tl.sum(xbar * xbar, axis=1) / N # Shape: [ROWS_PER_BLOCK]
mean = 0.0 # Placeholder for RMS norm
rstd = tl.rsqrt(var + eps) # Shape: [ROWS_PER_BLOCK]
# Store rstd for each row
rstd_offsets = group * M + rows
rstd_mask = rows < M
tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask)
# Load weights and biases (broadcast across rows)
w_offsets = cols + group * N
w_mask = cols < N
w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
if HAS_BIAS:
b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
# Normalize and apply linear transformation
if not IS_RMS_NORM:
x_hat = (x - mean[:, None]) * rstd[:, None]
else:
x_hat = x * rstd[:, None]
y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :]
if HAS_Z and NORM_BEFORE_GATE:
Z_base = Z + rows[:, None] * stride_z_row + col_offsets
z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
y *= z * tl.sigmoid(z)
# Write output
tl.store(Y_base, y, mask=mask)
def calc_rows_per_block(M: int, device: torch.device) -> int:
sm_count = num_compute_units(device.index)
rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
rows_per_block = min(rows_per_block, 4)
return rows_per_block
def layer_norm_fwd(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
eps: float,
z: torch.Tensor = None,
out: torch.Tensor = None,
group_size: int = None,
norm_before_gate: bool = True,
is_rms_norm: bool = False,
):
M, N = x.shape
if group_size is None:
group_size = N
assert N % group_size == 0
ngroups = N // group_size
assert x.stride(-1) == 1
if z is not None:
assert z.stride(-1) == 1
assert z.shape == (M, N)
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
mean = (
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
if not is_rms_norm
else None
)
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
if group_size > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
# Calculate rows per block based on SM count
rows_per_block = calc_rows_per_block(M, x.device)
# Update grid to use rows_per_block
grid = (cdiv(M, rows_per_block), ngroups)
layer_norm_fwd_kernel[grid](
x,
out,
weight,
bias,
z,
mean,
rstd,
x.stride(0),
out.stride(0),
z.stride(0) if z is not None else 0,
M,
group_size,
eps,
BLOCK_N=BLOCK_N,
ROWS_PER_BLOCK=rows_per_block,
NORM_BEFORE_GATE=norm_before_gate,
IS_RMS_NORM=is_rms_norm,
num_warps=num_warps,
)
return out, mean, rstd
class LayerNormFn(torch.autograd.Function):
@input_guard
@staticmethod
def forward(
ctx,
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = layer_norm_fwd(
x,
weight,
bias,
eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
ctx.x_shape_og = x_shape_og
ctx.eps = eps
ctx.group_size = group_size
ctx.norm_before_gate = norm_before_gate
ctx.is_rms_norm = is_rms_norm
return y.reshape(x_shape_og)
def layernorm_fn(
x,
weight,
bias,
z=None,
eps=1e-6,
group_size=None,
norm_before_gate=True,
is_rms_norm=False,
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNormGated(nn.Module):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
group_size: int | None = None,
norm_before_gate: bool = True,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.group_size = group_size
self.norm_before_gate = norm_before_gate
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
torch.nn.init.zeros_(self.bias)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return layernorm_fn(
x,
self.weight,
self.bias,
z=z,
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
)
class RMSNormGated(nn.Module):
def __init__(
self,
hidden_size,
eps: float = 1e-5,
group_size: int | None = None,
norm_before_gate: bool = False,
device: torch.device | None = None,
dtype: torch.dtype | None = None,
):
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
self.register_parameter("bias", None)
self.group_size = group_size
self.norm_before_gate = norm_before_gate
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.ones_(self.weight)
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
x,
self.weight,
self.bias,
z=z,
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
)

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
import os
from vllm.triton_utils import tl, tldevice, triton
from .utils import is_gather_supported
if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
exp = tldevice.fast_expf
log = tldevice.fast_logf
log2 = tldevice.fast_log2f
else:
exp = tl.exp
log = tl.log
log2 = tl.log2
if not is_gather_supported:
@triton.jit
def gather(src, index, axis, _builder=None):
"""
Gather operation that works when tl.gather is not supported.
This is a fallback implementation that returns None.
Just to make triton compiler happy.
"""
return None
else:
gather = tl.gather
if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
# For Triton 3.3.x
make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
# For Triton 3.4.x and later
make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
"""
Fallback implementation when TMA is not supported.
Returns None to indicate TMA descriptors are unavailable.
Just make triton compiler happy.
"""
@triton.jit
def make_tensor_descriptor(
base,
shape,
strides,
block_shape,
_builder=None,
):
return None

View File

@@ -0,0 +1,556 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import os
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import make_tensor_descriptor
from .utils import input_guard, is_amd, is_tma_supported
FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"]
assert FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS, (
f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}"
)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A = A + (bos * H + i_h) * BT
Ai = Ai + (bos * H + i_h) * 16
offset = (i_t * 16) % BT
if not USE_TMA:
p_A = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
)
# [16, 16]
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
b_A = desc.load([i_t * 16, offset]).to(tl.float32)
b_A = -tl.where(m_A, b_A, 0)
for i in range(2, min(16, T - i_t * 16)):
# [16]
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
b_A = tl.where((o_i == i)[:, None], b_a, b_A)
b_A += m_I
if not USE_TMA:
p_Ai = tl.make_block_ptr(
Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)
)
tl.store(
p_Ai,
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
else:
desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [1, 2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
)
p_A_22 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
)
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
# [16, 16]
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
b_Ai_11 += m_I
b_Ai_22 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
)
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_Ai_21 = -tl.dot(
tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
b_Ai_11,
input_precision=DOT_PRECISION,
)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
)
p_Ai_21 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
)
p_Ai_22 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
)
tl.store(
p_Ai_11,
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_22,
b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_21,
b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
else:
desc_o.store(
[i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4, 5]
],
key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
A,
Ai,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
BT: tl.constexpr,
USE_TMA: tl.constexpr,
IS_VARLEN: tl.constexpr,
DOT_PRECISION: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
o_i = tl.arange(0, 16)
m_A = o_i[:, None] > o_i[None, :]
m_I = o_i[:, None] == o_i[None, :]
A += (bos * H + i_h) * BT
Ai += (bos * H + i_h) * BT
if not USE_TMA:
p_A_11 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
)
p_A_22 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
)
p_A_33 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
)
p_A_44 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
)
b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
else:
desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)
# [16, 16]
b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)
for i in range(2, min(16, T - i_t * BT)):
b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
for i in range(16 + 2, min(32, T - i_t * BT)):
b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
for i in range(32 + 2, min(48, T - i_t * BT)):
b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
for i in range(48 + 2, min(64, T - i_t * BT)):
b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
b_Ai_11 += m_I
b_Ai_22 += m_I
b_Ai_33 += m_I
b_Ai_44 += m_I
if not USE_TMA:
p_A_21 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
)
p_A_31 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
)
p_A_32 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
)
p_A_41 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
)
p_A_42 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
)
p_A_43 = tl.make_block_ptr(
A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
)
b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
else:
b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)
b_Ai_21 = -tl.dot(
tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
b_Ai_11,
input_precision=DOT_PRECISION,
)
b_Ai_32 = -tl.dot(
tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
b_Ai_22,
input_precision=DOT_PRECISION,
)
b_Ai_43 = -tl.dot(
tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
b_Ai_33,
input_precision=DOT_PRECISION,
)
b_Ai_31 = -tl.dot(
b_Ai_33,
tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
+ tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
input_precision=DOT_PRECISION,
)
b_Ai_42 = -tl.dot(
b_Ai_44,
tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
+ tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
input_precision=DOT_PRECISION,
)
b_Ai_41 = -tl.dot(
b_Ai_44,
tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
+ tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
+ tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
input_precision=DOT_PRECISION,
)
if not USE_TMA:
p_Ai_11 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
)
p_Ai_22 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
)
p_Ai_33 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
)
p_Ai_44 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
)
p_Ai_21 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
)
p_Ai_31 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
)
p_Ai_32 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
)
p_Ai_41 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
)
p_Ai_42 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
)
p_Ai_43 = tl.make_block_ptr(
Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
)
tl.store(
p_Ai_11,
b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_22,
b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_33,
b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_44,
b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_21,
b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_31,
b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_32,
b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_41,
b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_42,
b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
tl.store(
p_Ai_43,
b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
boundary_check=(0, 1),
)
else:
desc_o.store(
[i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
desc_o.store(
[i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
)
@input_guard
def solve_tril(
A: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""
Compute the inverse of the matrix I + A
A should be strictly lower triangular, i.e., A.triu() == 0.
Args:
A (torch.Tensor):
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
Returns:
(I + A)^-1 with the same shape as A
"""
assert A.shape[-1] in [16, 32, 64]
output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype)
if BT == 16:
merge_fn = solve_tril_16x16_kernel
elif BT == 32:
merge_fn = merge_16x16_to_32x32_inverse_kernel
elif BT == 64:
merge_fn = merge_16x16_to_64x64_inverse_kernel
merge_fn[NT, B * H](
A=A,
Ai=Ai,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
BT=BT,
USE_TMA=is_tma_supported,
DOT_PRECISION=FLA_TRIL_PRECISION,
)
return Ai

View File

@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import contextlib
import functools
import logging
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, Literal
import torch
from vllm.platforms import current_platform
from vllm.triton_utils import triton
logger = logging.getLogger(__name__)
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator that caches the most recent results of a function with tensor inputs.
This decorator will store the output of the decorated function for the most recent set of input tensors.
The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
Args:
fn (Callable[..., torch.Tensor]):
The function to be decorated. It should take tensor inputs and return tensor outputs.
Returns:
Callable[..., torch.Tensor]:
A wrapped version of the input function with single-entry caching.
"""
cache_entries: tuple[tuple | None, dict | None, Any] = []
cache_size = 8
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
nonlocal cache_entries, cache_size
for i, entry in enumerate(cache_entries):
last_args, last_kwargs, last_result = entry
if (
len(args) == len(last_args)
and len(kwargs) == len(last_kwargs)
and all(a is b for a, b in zip(args, last_args))
and all(
k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
)
):
cache_entries = (
cache_entries[:i]
+ cache_entries[i + 1 :]
+ [(args, kwargs, last_result)]
)
return last_result
result = fn(*args, **kwargs)
if len(cache_entries) >= cache_size:
cache_entries = cache_entries[1:]
cache_entries.append((args, kwargs, result))
return result
return wrapper
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
"""
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
contiguous_args = (
i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
)
contiguous_kwargs = {
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
for k, v in kwargs.items()
}
tensor = None
for arg in args:
if isinstance(arg, torch.Tensor):
tensor = arg
break
if tensor is None:
for value in kwargs.values():
if isinstance(value, torch.Tensor):
tensor = value
break
if tensor is not None:
ctx = torch.cuda.device(tensor.device.index)
else:
ctx = contextlib.nullcontext()
with ctx:
return fn(*contiguous_args, **contiguous_kwargs)
return wrapper
@functools.cache
def get_available_device() -> str:
try:
return triton.runtime.driver.active.get_current_target().backend
except (RuntimeError, AttributeError):
return "cpu"
@functools.cache
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
device = get_available_device()
mapping = {
"cuda": "nvidia",
"hip": "amd",
"xpu": "intel",
}
# return the mapped value, or the original if not found
return mapping.get(device, device)
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
device_torch_lib = getattr(torch, device, None)
device_platform = _check_platform()
is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
"NVIDIA H" in torch.cuda.get_device_name(0)
or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
is_gather_supported = hasattr(triton.language, "gather")
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
hasattr(triton.language, "_experimental_make_tensor_descriptor")
or hasattr(triton.language, "make_tensor_descriptor")
)
def get_all_max_shared_mem():
try:
return [
triton.runtime.driver.active.utils.get_device_properties(i)[
"max_shared_mem"
]
for i in range(device_torch_lib.device_count())
]
except BaseException:
return [-1]
class Backend(Enum):
ADA = 101376 # RTX 4090
AMPERE = 166912 # A100
HOPPER = 232448 # H100
DEFAULT = 102400 # Default
@classmethod
def get_shared_memory(cls, arch: str) -> int:
try:
return cls[arch.upper()].value
except KeyError:
return cls.DEFAULT.value
@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
try:
device_shared_mem_list = get_all_max_shared_mem()
max_shared_memory = device_shared_mem_list[tensor_idx]
return max_shared_memory >= Backend.get_shared_memory(arch)
except Exception:
return False

View File

@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
configs=[
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
for num_warps in [2, 4, 8]
for num_stages in [2, 3, 4]
],
key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
k,
v,
beta,
w,
u,
A,
g,
cu_seqlens,
chunk_indices,
T,
H: tl.constexpr,
Hg: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BT: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
i_t, i_bh = tl.program_id(0), tl.program_id(1)
i_b, i_h = i_bh // H, i_bh % H
if IS_VARLEN:
i_n, i_t = (
tl.load(chunk_indices + i_t * 2).to(tl.int32),
tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
)
bos, eos = (
tl.load(cu_seqlens + i_n).to(tl.int32),
tl.load(cu_seqlens + i_n + 1).to(tl.int32),
)
T = eos - bos
else:
bos, eos = i_b * T, i_b * T + T
p_beta = tl.make_block_ptr(
beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
)
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
p_A = tl.make_block_ptr(
A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
)
b_beta = tl.load(p_beta, boundary_check=(0,))
b_A = tl.load(p_A, boundary_check=(0, 1))
b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))
for i_v in range(tl.cdiv(V, BV)):
p_v = tl.make_block_ptr(
v + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
p_u = tl.make_block_ptr(
u + (bos * H + i_h) * V,
(T, V),
(H * V, 1),
(i_t * BT, i_v * BV),
(BT, BV),
(1, 0),
)
b_v = tl.load(p_v, boundary_check=(0, 1))
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
for i_k in range(tl.cdiv(K, BK)):
p_k = tl.make_block_ptr(
k + (bos * Hg + i_h // (H // Hg)) * K,
(T, K),
(Hg * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
p_w = tl.make_block_ptr(
w + (bos * H + i_h) * K,
(T, K),
(H * K, 1),
(i_t * BT, i_k * BK),
(BT, BK),
(1, 0),
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
b_w = tl.dot(b_A, b_kb)
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
def recompute_w_u_fwd(
k: torch.Tensor,
v: torch.Tensor,
beta: torch.Tensor,
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: torch.LongTensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64
u = torch.empty_like(v)
w = k.new_empty(B, T, H, K)
recompute_w_u_fwd_kernel[(NT, B * H)](
k=k,
v=v,
beta=beta,
w=w,
u=u,
A=A,
g=g_cumsum,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
T=T,
H=H,
Hg=Hg,
K=K,
V=V,
BT=BT,
BK=BK,
BV=BV,
)
return w, u

View File

@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from contextlib import contextmanager
from typing import Any
from vllm.model_executor.layers.fused_moe.activation import (
MoEActivation,
activation_without_mul,
apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
FusedMoEPermuteExpertsUnpermute,
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
ZeroExpertFusedMoE,
)
from vllm.triton_utils import HAS_TRITON
_config: dict[str, Any] | None = None
@contextmanager
def override_config(config):
global _config
old_config = _config
_config = config
yield
_config = old_config
def get_config() -> dict[str, Any] | None:
return _config
__all__ = [
"FusedMoE",
"FusedMoERouter",
"FusedMoEConfig",
"FusedMoEMethodBase",
"MoEActivation",
"UnquantizedFusedMoEMethod",
"FusedMoeWeightScaleSupported",
"FusedMoEPermuteExpertsUnpermute",
"FusedMoEActivationFormat",
"FusedMoEPrepareAndFinalize",
"RoutingMethodType",
"SharedFusedMoE",
"ZeroExpertFusedMoE",
"activation_without_mul",
"apply_moe_activation",
"override_config",
"get_config",
]
if HAS_TRITON:
# import to register the custom ops
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
CutlassExpertsW4A8Fp8,
cutlass_moe_w4a8_fp8,
)
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.fused_moe import (
TritonExperts,
TritonWNA16Experts,
fused_experts,
get_config_file_name,
)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
AiterExperts,
)
from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
fused_topk,
)
from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
GroupedTopk,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
XPUExperts,
XPUExpertsFp8,
)
__all__ += [
"AiterExperts",
"fused_topk",
"fused_experts",
"get_config_file_name",
"GroupedTopk",
"cutlass_moe_w4a8_fp8",
"CutlassExpertsFp8",
"CutlassBatchedExpertsFp8",
"CutlassExpertsW4A8Fp8",
"TritonExperts",
"TritonWNA16Experts",
"BatchedTritonExperts",
"DeepGemmExperts",
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"XPUExperts",
"XPUExpertsFp8",
]
else:
# Some model classes directly use the custom ops. Add placeholders
# to avoid import errors.
def _raise_exception(method: str):
raise NotImplementedError(f"{method} is not implemented as lack of triton.")
fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")

View File

@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MoE activation function enum and utilities."""
from enum import Enum
import torch
import torch.nn.functional as F
from vllm._custom_ops import silu_and_mul, gelu_and_mul, swigluoai_and_mul
class MoEActivation(Enum):
"""Activation functions for MoE layers."""
# Gated activations (gate * activation(up)) expect input of shape [..., 2*d]
# and produce output of shape [..., d]
SILU = "silu"
GELU = "gelu"
RELU2 = "relu2"
SWIGLUOAI = "swigluoai"
SWIGLUSTEP = "swiglustep"
# Non-gated activations (no mul with gate) expect input of shape [..., d]
# and produce output of shape [..., d].
# NOTE: Non-gated activations require the "_no_mul" suffix to be present.
SILU_NO_MUL = "silu_no_mul"
GELU_NO_MUL = "gelu_no_mul"
RELU2_NO_MUL = "relu2_no_mul"
@property
def is_gated(self) -> bool:
"""Returns True if activation expects gate*activation(up) pattern.
Gated activations expect input tensor with 2x the output size,
where the first half is the gate and second half is the up projection.
"""
return not self.value.endswith("_no_mul")
@property
def custom_op_name(self) -> str:
"""Maps to the CustomOp name of activations
in vllm/model_executor/layers/activation.py."""
return _CUSTOM_OP_NAMES[self]
def without_mul(self) -> "MoEActivation":
"""Get the non-gated variant of this activation.
For activations that have a _no_mul variant, returns that variant.
For activations without a _no_mul variant (or already _no_mul),
returns self.
"""
return _WITHOUT_MUL.get(self, self)
@classmethod
def from_str(cls, s: str) -> "MoEActivation":
"""Parse from string for backward compatibility."""
for member in cls:
if member.value == s:
return member
valid = [m.value for m in cls]
raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")
# Module-level lookup tables used by MoEActivation functions.
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
MoEActivation.SILU: "silu_and_mul",
MoEActivation.GELU: "gelu_and_mul",
MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
MoEActivation.RELU2: "relu2",
MoEActivation.SILU_NO_MUL: "silu_and_mul",
MoEActivation.GELU_NO_MUL: "gelu_and_mul",
MoEActivation.RELU2_NO_MUL: "relu2",
}
_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}
def activation_without_mul(activation: str) -> str:
"""Get the non-gated variant of an activation function.
Args:
activation: The activation function name (e.g., "silu", "gelu")
Returns:
The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul")
"""
return MoEActivation.from_str(activation).without_mul().value
def apply_moe_activation(
activation: MoEActivation,
output: torch.Tensor,
input: torch.Tensor,
) -> torch.Tensor:
"""Apply MoE activation function."""
assert input.dim() == 2, "Input must be 2D"
assert output.dim() == 2, "Output must be 2D"
if activation.is_gated:
assert output.size(-1) * 2 == input.size(-1), (
f"{activation.value} expects 2x ratio: "
f"{output.size(-1) * 2} vs {input.size(-1)}"
)
else:
assert output.size(-1) == input.size(-1), (
f"{activation.value} expects equal sizes: "
f"{output.size(-1)} vs {input.size(-1)}"
)
# Activations with gated multiplication (gate × activation(up))
if activation == MoEActivation.SILU:
# torch.ops._C.silu_and_mul(output, input)
silu_and_mul(output, input)
elif activation == MoEActivation.GELU:
# torch.ops._C.gelu_and_mul(output, input)
gelu_and_mul(output, input)
elif activation == MoEActivation.SWIGLUOAI:
# torch.ops._C.swigluoai_and_mul(output, input)
swigluoai_and_mul(output, input)
elif activation == MoEActivation.SWIGLUSTEP:
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
swiglustep_and_mul_triton(output, input)
# Activations without gated multiplication
elif activation == MoEActivation.SILU_NO_MUL:
output.copy_(F.silu(input))
elif activation == MoEActivation.GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == MoEActivation.RELU2_NO_MUL:
F.relu(input, inplace=True)
torch.square(input, out=output)
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
return output

View File

@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.distributed import (
get_ep_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNaiveEP,
MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx
logger = init_logger(__name__)
if current_platform.is_cuda_alike():
if has_pplx():
from .pplx_prepare_finalize import (
PplxPrepareAndFinalize,
pplx_hidden_dim_scale_bytes,
)
if has_deep_ep():
from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
from .deepep_ll_prepare_finalize import (
DEEPEP_QUANT_BLOCK_SHAPE,
DeepEPLLPrepareAndFinalize,
)
if has_mori():
from .mori_prepare_finalize import MoriPrepareAndFinalize
def maybe_roundup_layer_hidden_size(
hidden_size: int,
act_dtype: torch.dtype,
moe_parallel_config: FusedMoEParallelConfig,
) -> int:
"""
Given layer hidden size and MoE configurations, round up hidden_size
if necessary.
Args:
hidden_size: Layer hidden-size
act_dtype: Data type of the layer activations.
moe_parallel_config: Fused MoE parallelization strategy configuration.
Return:
Rounded up hidden_size if rounding up is required based on the configs
and all2all backend.
Original hidden size otherwise.
"""
if moe_parallel_config.use_deepep_ht_kernels:
hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size, act_dtype
)
if moe_parallel_config.use_deepep_ll_kernels:
hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
hidden_size
)
return hidden_size
def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
# NOTE(rob): we are migrating each quant_method to hold the MK
# in all cases. The allow_new_interface=False flag allow us to fall
# back to the old method for methods that have not yet been migrated.
#
# In old method:
# * maybe_init_modular_kernel() calls this function. If we are
# using no Dp/Ep or naive all2all, we return None this function
# returns None and no ModularKernelMethod is created. If non-naive
# all2all is used, this returns a PrepareAndFinalize object and
# a ModularKernelMethod is created.
# In new method:
# * maybe_make_prepare_finalize() is called from the oracle. We
# always return a PrepareAndFinalize object and the quant method
# holds the ModularKernel.
if not moe.moe_parallel_config.use_all2all_kernels:
if not allow_new_interface:
return None
# For DP/TP case, fall back to naive P/F.
if moe.moe_parallel_config.dp_size > 1:
logger.info_once(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
return MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
)
else:
return MoEPrepareAndFinalizeNoEP()
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize | None = None
if moe.use_pplx_kernels:
assert quant_config is not None
hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
moe.max_num_tokens,
moe.hidden_dim,
moe.in_dtype,
quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=hidden_dim_bytes,
hidden_dim_scale_bytes=hidden_scale_bytes,
)
num_dispatchers = (
all2all_manager.world_size // all2all_manager.tp_group.world_size
)
# Intranode pplx a2a takes a group name while internode does not.
if not all2all_manager.internode:
all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
num_local_experts=moe.num_local_experts,
num_dispatchers=num_dispatchers,
)
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
all_to_all_args = dict()
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = DeepEPHTPrepareAndFinalize(
handle,
num_dispatchers=all2all_manager.world_size,
dp_size=all2all_manager.dp_world_size,
rank_expert_offset=all2all_manager.rank * moe.num_local_experts,
)
elif moe.use_deepep_ll_kernels:
assert quant_config is not None
global_to_physical = physical_to_global = local_expert_global_ids = None
if routing_tables is not None:
(
global_to_physical,
physical_to_global,
local_expert_global_ids,
) = routing_tables
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
num_ep_ranks=all2all_manager.world_size,
num_global_experts=moe.num_experts,
num_local_experts=moe.num_experts // all2all_manager.world_size,
)
handle = all2all_manager.get_handle(all_to_all_args)
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.quant_dtype == current_platform.fp8_dtype()
and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
)
prepare_finalize = DeepEPLLPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
global_to_physical=global_to_physical,
physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids,
)
elif moe.use_mori_kernels:
assert quant_config is not None
# Note: We may want to use FP8 dispatch just to reduce
# data movement.
use_fp8_dispatch = (
quant_config.is_per_act_token or quant_config.is_block_quantized
)
# For PTPC (per token per channel) quant, the scale dim for each token is 1
# For 1x128 quant, the scale dim for each token is hidden_dim // 128
scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
all_to_all_args = dict(
rank=all2all_manager.rank,
num_ep_ranks=all2all_manager.world_size,
quant_dtype=quant_config.quant_dtype,
token_hidden_size=moe.hidden_dim,
scale_dim=scale_dim,
scale_type_size=torch.float32.itemsize,
max_num_tokens_per_dp_rank=moe.max_num_tokens,
input_dtype=moe.in_dtype,
num_local_experts=moe.num_experts // all2all_manager.world_size,
num_experts_per_token=moe.experts_per_token,
)
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = MoriPrepareAndFinalize(
handle,
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
)
elif moe.use_fi_all2allv_kernels:
assert quant_config is not None
prepare_finalize = FlashInferA2APrepareAndFinalize(
num_dispatchers=all2all_manager.world_size,
)
elif moe.use_naive_all2all_kernels and allow_new_interface:
prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
num_dispatchers=all2all_manager.world_size,
)
return prepare_finalize

View File

@@ -0,0 +1,447 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
kFp8Dynamic128Sym,
kFp8Static128BlockSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
DeepGemmQuantScaleFMT,
fp8_m_grouped_gemm_nt_masked,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
def scales_shape_stride_dtype(
E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
shape = (E, T, G)
strides = (T * G, 1, T)
if quant_scale_fmt in [
DeepGemmQuantScaleFMT.FLOAT32,
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
]:
return shape, strides, torch.float32
assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
shape = (E, T, cdiv(G, 4))
strides = (T * cdiv(G, 4), 1, T)
return shape, strides, torch.int32
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
# Pointers ------------------------------------------------------------
input_ptr, # 16-bit activations (E, T, 2*H)
y_q_ptr, # fp8 quantized activations (E, T, H)
y_s_ptr, # 16-bit scales (E, T, G)
counts_ptr, # int32 num tokens per expert (E)
# Sizes ---------------------------------------------------------------
H: tl.constexpr, # hidden dimension (per output)
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
# Strides for input (elements) ---------------------------------------
stride_i_e,
stride_i_t,
stride_i_h,
# Strides for y_q (elements) -----------------------------------------
stride_yq_e,
stride_yq_t,
stride_yq_h,
# Strides for y_s (elements) -----------------------------------------
stride_ys_e,
stride_ys_t,
stride_ys_g,
# Stride for counts (elements)
stride_counts_e,
# Numeric params ------------------------------------------------------
eps: tl.constexpr,
fp8_min: tl.constexpr,
fp8_max: tl.constexpr,
ceil_ue8m0: tl.constexpr,
# Meta ---------------------------------------------------------------
BLOCK: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
G = H // GROUP_SIZE
# map program id -> (e, g)
pid = tl.program_id(0)
e = pid // G
g = pid % G
e = e.to(tl.int64)
g = g.to(tl.int64)
# number of valid tokens for this expert
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
cols = tl.arange(0, BLOCK).to(tl.int64)
mask = cols < BLOCK
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
base_gate_offset = base_input_offset + cols * stride_i_h
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
base_ys_offset = e * stride_ys_e + g * stride_ys_g
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
gate = tl.load(
input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
).to(tl.float32)
up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
y = gate * up
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
if ceil_ue8m0:
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
def persistent_masked_m_silu_mul_quant(
y: torch.Tensor, # (E, T, 2*H)
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
num_parallel_tokens=16,
group_size: int = 128,
quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
y has shape (E, T, 2*H). The first half of the last dimension is
silu-activated, multiplied by the second half, then quantized into FP8.
We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2`
be a parallelization factor for persistent_masked_m_silu_mul_quant over the
hidden dimension.
Let `expert_offsets = [0] + [num_tokens.cumsum()]` and
`total_tokens = expert_offsets[-1]`.
persistent_masked_m_silu_mul_quant launches `total_tokens x P2` number of
thread blocks. Each thread block contains `NUM_WARPS` warps.
Every thread block needs to find it's corresponding expert by warp-parallel scanning
over the `expert_offsets` array.
The i-th warp in the first thread block processes
`[i * warp_chunk_size, (i + 1) * warp_chunk_size]` groups
sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`,
pipelining loads and computes.
The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2
can is visualized like so:
stage0 stage1
┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐
│gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│
└─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘
with the main difference between V1 and V2 being the global load
stride between warps, and between half-warps. Regarding the latter stride,
we assign the first half warp of every warp for `gate` loads and the second
half-warp to `up` loads.
Returns `(y_q, y_s)` where
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
* `y_s` depends on quant_scale_fmt,
- quant_scale_fmt == FLOAT32,
`y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
- quant_scale_fmt == E8M0,
`y_s`: Int32 tensor, shape (E, T, H // group_size // 4), strides (T*G, 1, T)
- quant_scale_fmt == E8M0_FLOAT32_SPARSE
`y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
Let NUM_WARPS be the number of warps in a single thread block and
`GROUP_SIZE = 128` be the size of the quantization group.
"""
assert y.ndim == 3, "y must be (E, T, 2*H)"
E, T, H2 = y.shape
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
H = H2 // 2
G = (H + group_size - 1) // group_size
assert H % 8 == 0, "H must be divisible by 8"
assert group_size == 128, "H must be divisible by 8"
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E
tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)
fp8_dtype = torch.float8_e4m3fn
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
y_s = torch.empty_strided(
ys_shape,
ys_strides,
dtype=ys_dtype,
device=y.device,
)
ceil_ue8m0 = quant_scale_fmt in [
DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
DeepGemmQuantScaleFMT.UE8M0,
]
cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
).to_int()
if cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant(
y, tokens_per_expert, y_q, y_s, ceil_ue8m0
)
else:
stride_cnt_e = tokens_per_expert.stride()[0]
# Static grid over experts and H-groups.
# A loop inside the kernel handles the token dim
grid = (E * G,)
# strides (elements)
stride_i_e, stride_i_t, stride_i_h = y.stride()
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
f_info = torch.finfo(fp8_dtype)
fp8_max = f_info.max
fp8_min = f_info.min
eps: float = 1e-10
assert y_s.dtype == torch.float32, (
"_silu_mul_fp8_quant_deep_gemm does"
"not support {y_s.dtype} scales. Only torch.float32 supported."
)
_silu_mul_fp8_quant_deep_gemm[grid](
y,
y_q,
y_s,
tokens_per_expert,
H,
group_size,
stride_i_e,
stride_i_t,
stride_i_h,
stride_yq_e,
stride_yq_t,
stride_yq_h,
ys_strides[0],
ys_strides[1],
ys_strides[2],
stride_cnt_e,
eps,
fp8_min,
fp8_max,
ceil_ue8m0,
BLOCK=group_size,
NUM_STAGES=4,
num_warps=1,
)
return y_q, y_s
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
moe_config: FusedMoEConfig,
quant_config: FusedMoEQuantConfig,
max_num_tokens: int,
num_dispatchers: int,
):
"""
max_num_tokens: Maximum number of tokens from a DP Rank
num_dispatchers: The number of DP dispatchers.
quant_config: Quantization configuration
"""
super().__init__(
moe_config=moe_config,
quant_config=quant_config,
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
)
assert self.block_shape == get_mk_alignment_for_contiguous_layout()
assert self.quant_config.use_fp8_w8a8
@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.BatchedExperts
@staticmethod
def _supports_current_device() -> bool:
return is_deep_gemm_supported()
@staticmethod
def _supports_no_act_and_mul() -> bool:
return False
@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)]
return (weight_key, activation_key) in SUPPORTED_W_A
@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
return activation == MoEActivation.SILU
@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
return True
def supports_chunking(self) -> bool:
return False
def supports_expert_map(self) -> bool:
return False
def supports_packed_ue8m0_act_scales(self) -> bool:
"""
DeepGemm supports packed ue8m0 activation scales format in devices == sm100
"""
return (
is_deep_gemm_e8m0_used()
and current_platform.is_device_capability_family(100)
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Let PrepareAndFinalize::finalize() decide the impl.
return TopKWeightAndReduceDelegate()
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# FIXME (varun): We should be able to dispatch only from the leader
# DP ranks in the case of TP > 1. At the moment, all the Ranks
# end up sending their tokens. This needs to be fixed.
assert self.num_dispatchers is not None
assert self.max_num_tokens is not None
num_dispatchers = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim)
output = (num_experts, max_num_tokens * num_dispatchers, K)
return (workspace13, workspace2, output)
def estimate_expected_m(
self, global_num_experts: int, max_tokens_per_expert: int, topk: int
) -> int:
dp_meta = (
get_forward_context().dp_metadata
if is_forward_context_available()
else None
)
if dp_meta is None:
logger.warning_once(
"DPMetadata unavailable. Defaulting expected_m to "
f"{max_tokens_per_expert}.",
scope="local",
)
return max_tokens_per_expert
total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
total_num_tokens_replicated = total_num_tokens * topk
# Assume even load balancing
assert global_num_experts != 0
estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
# clamp estimate
estimate = max(estimate, 16)
estimate = min(max_tokens_per_expert, estimate)
return estimate
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert expert_tokens_meta is not None
expert_num_tokens = expert_tokens_meta.expert_num_tokens
assert hidden_states.ndim == 3
assert self.block_shape is not None
a1q = hidden_states
_, N, K = w1.size()
assert w2.size(1) == K
E, max_num_tokens, N, K, _ = self.moe_problem_size(
hidden_states, w1, w2, topk_ids
)
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
expected_m = self.estimate_expected_m(
global_num_experts=global_num_experts,
max_tokens_per_expert=max_num_tokens,
topk=topk_ids.size(-1),
)
fp8_m_grouped_gemm_nt_masked(
(a1q, a1q_scale),
(w1, self.w1_scale),
workspace1,
expert_num_tokens,
expected_m,
)
quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
workspace1,
expert_num_tokens,
quant_scale_fmt=quant_scale_fmt,
)
fp8_m_grouped_gemm_nt_masked(
(a2q, a2q_scale),
(w2, self.w2_scale),
output,
expert_num_tokens,
expected_m,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"5120": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"9216": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"13312": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"17408": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"25600": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"33792": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"41984": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"50176": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"58368": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"5120": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"9216": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"13312": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"17408": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"25600": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"33792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"41984": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"50176": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"58368": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"5120": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"9216": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"13312": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"17408": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"25600": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"33792": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"41984": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"50176": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"58368": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"5120": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"9216": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"13312": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"17408": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"25600": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"33792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"41984": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"50176": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"58368": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 256,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"5120": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"9216": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"13312": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"17408": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"25600": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"33792": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"41984": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"50176": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"58368": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"5120": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"9216": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"13312": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"17408": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"25600": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"33792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"41984": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"50176": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"58368": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"5120": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"9216": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"13312": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"17408": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"25600": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"33792": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"41984": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"50176": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"58368": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,218 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"5120": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"9216": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"13312": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"17408": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"25600": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"33792": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"41984": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"50176": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"58368": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}

View File

@@ -0,0 +1,200 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}

View File

@@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}

View File

@@ -0,0 +1,123 @@
{
"triton_version": "3.4.0",
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"16384": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,139 @@
{
"triton_version": "3.5.1",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"320": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"768": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,147 @@
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,147 @@
{
"triton_version": "3.6.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
}
}

View File

@@ -0,0 +1,147 @@
{
"triton_version": "3.5.0",
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 5
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
}
}

View File

@@ -0,0 +1,122 @@
{
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8192": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
}
}

View File

@@ -0,0 +1,164 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 5
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

Some files were not shown because too many files have changed in this diff Show More