Add minimal vLLM 0.16.1 build repo for BI-V150
9
vllm/model_executor/__init__.py
Normal file
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.parameter import BasevLLMParameter, PackedvLLMParameter

__all__ = [
    "BasevLLMParameter",
    "PackedvLLMParameter",
]
346
vllm/model_executor/custom_op.py
Normal file
@@ -0,0 +1,346 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
import inspect

import torch
import torch.nn as nn

from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform

logger = init_logger(__name__)

# Dictionary of all custom ops (classes, indexed by registered name).
# To check if an op with a name is enabled, call .enabled() on the class.
# Examples:
# - MyOp.enabled()
# - op_registry["my_op"].enabled()
op_registry: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}
op_registry_oot: dict[str, type["CustomOp"] | type["PluggableLayer"]] = {}


class PluggableLayer(nn.Module):
    """
    Base class for pluggable layers.

    A PluggableLayer is a *module-composing* abstraction: it may instantiate
    other ``torch.nn.Module`` objects as sub-layers, and its functionality
    depends on these sub-layers following a generalized invocation sequence.
    It is also stateful and may hold parameters or buffers.

    Unlike :class:`CustomOp`, PluggableLayer does NOT provide per-platform
    ``forward_*`` dispatch. Instead, it supports out-of-tree (OOT) replacement
    of the entire layer class at instantiation time, allowing customized
    initialization and submodule composition.
    """

    def __new__(cls, *args, **kwargs):
        try:
            layer_class_name = cls.__name__
        except AttributeError:
            raise TypeError(
                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
                f"was not set, possibly because it was not decorated with "
                f"@PluggableLayer.register, or it's the PluggableLayer itself."
            ) from None

        if layer_class_name not in op_registry_oot:
            layer_cls_to_instantiate = cls
        else:
            layer_cls_to_instantiate = op_registry_oot[layer_class_name]
            logger.debug(
                "Instantiating pluggable layer: %s using %s",
                layer_class_name,
                str(layer_cls_to_instantiate),
            )
        return super().__new__(layer_cls_to_instantiate)

    # Decorator to register pluggable layers.
    @classmethod
    def register(cls, name: str):
        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree (OOT) pluggable layers.
    # For OOT pluggable layers:
    #   if an in-tree layer class is registered with an oot_custom_layer,
    #   the oot_custom_layer will be used instead.
    @classmethod
    def register_oot(cls, _decorated_layer_cls=None, name: str | None = None):
        def decorator(layer_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate layer name: {reg_name}"
            layer_cls.name = reg_name
            op_registry_oot[reg_name] = layer_cls
            return layer_cls

        if _decorated_layer_cls is None:
            # Called with parentheses: @PluggableLayer.register_oot()
            # or @PluggableLayer.register_oot(name="...")
            return decorator
        elif isinstance(_decorated_layer_cls, type):  # Check if it's a class
            # Called without parentheses: @PluggableLayer.register_oot
            return decorator(_decorated_layer_cls)
        else:
            raise TypeError("Decorator can only be applied to classes.")
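A minimal usage sketch of the OOT replacement flow above (illustrative only, not part of this commit; MyBlock and MyBlockOOT are hypothetical names):

# Hypothetical sketch: an out-of-tree plugin replaces an in-tree layer class.
import torch.nn as nn

from vllm.model_executor.custom_op import PluggableLayer


@PluggableLayer.register("my_block")
class MyBlock(PluggableLayer):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)


# The plugin registers its subclass under the in-tree class name
# ("MyBlock"), so existing call sites pick it up transparently.
@MyBlock.register_oot
class MyBlockOOT(MyBlock):
    def __init__(self, hidden_size: int):
        super().__init__(hidden_size)
        # customized initialization / submodule composition goes here


block = MyBlock(hidden_size=128)
assert type(block) is MyBlockOOT  # __new__ redirected to the OOT class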
class CustomOp(nn.Module):
    """
    Base class for custom ops.
    Dispatches the forward method to the appropriate backend.
    """

    def __new__(cls, *args, **kwargs):
        try:
            op_name = cls.__name__
        except AttributeError:
            raise TypeError(
                f"Cannot instantiate '{cls.__name__}': its 'name' attribute "
                f"was not set, possibly because it was not decorated with "
                f"@CustomOp.register, or it's the CustomOp base class itself."
            ) from None

        if op_name not in op_registry_oot:
            op_cls_to_instantiate = cls
        else:
            op_cls_to_instantiate = op_registry_oot[op_name]
            logger.debug(
                "Instantiating custom op: %s using %s",
                op_name,
                str(op_cls_to_instantiate),
            )
        return super().__new__(op_cls_to_instantiate)

    def __init__(self, *, enforce_enable: bool = False, compile_native: bool = False):
        super().__init__()
        self._enforce_enable = enforce_enable
        self._forward_method = self.dispatch_forward(compile_native=compile_native)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.
        This method is optional. If implemented, it can be used with compilers
        such as torch.compile or PyTorch XLA. Also, it can be used for testing
        purposes.
        """
        raise NotImplementedError

    def forward_cuda(self, *args, **kwargs):
        raise NotImplementedError

    def forward_hip(self, *args, **kwargs):
        # By default, we assume that HIP ops are compatible with CUDA ops.
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        # By default, we assume that XPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        # By default, we assume that CPU ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        # By default, we assume that TPU ops are compatible with the
        # PyTorch-native implementation.
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_oot(self, *args, **kwargs):
        # By default, we assume that OOT ops are compatible with the
        # PyTorch-native implementation.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self, compile_native: bool):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        compilation_config = get_cached_compilation_config()

        # NOTE(shen-shanshan): A CustomOp object can be enforce-enabled, e.g.,
        # to enable device-specific kernels in ViT models when enabling graph
        # mode. By default, it follows the compilation_config to determine
        # whether to enable itself.
        # This enforce_enable mechanism will be removed after we add a
        # separate compilation_config for the multi-modal part.
        enabled = self._enforce_enable or self.enabled()
        if enabled:
            compilation_config.enabled_custom_ops.update([self.__class__.name])
        else:
            compilation_config.disabled_custom_ops.update([self.__class__.name])

        if not enabled:
            # Compile forward_native to avoid eager torch ops if inside an
            # opaque torch custom op (e.g. fused_moe, unified_attention, etc.)
            return self.maybe_compile(self.forward_native, enable=compile_native)

        if current_platform.is_rocm():
            return self.forward_hip
        elif current_platform.is_cpu():
            return self.forward_cpu
        elif current_platform.is_tpu():
            return self.forward_tpu
        elif current_platform.is_xpu():
            return self.forward_xpu
        elif current_platform.is_out_of_tree():
            return self.forward_oot
        else:
            return self.forward_cuda

    def maybe_compile(self, fn, *, enable: bool = True):
        """
        Compile fn if compilation is enabled.
        Useful for CustomOp instances called from within a torch custom op,
        meaning the forward call is hidden from the model-level torch.compile.

        NOTE: this does not enable fusion across ops, so opaque custom ops
        should still be unwrapped wherever possible.
        """
        from vllm.config.compilation import CompilationMode

        # Do not compile if compilation is disabled for this op
        if not enable:
            return fn

        # Do not compile if global compilation is disabled
        compilation_config = get_cached_compilation_config()
        if compilation_config.mode == CompilationMode.NONE:
            return fn

        # If the eager backend is used, do not compile either
        if compilation_config.backend == "eager":
            return fn

        compile_options = maybe_disable_graph_partition(
            current_platform.simple_compile_backend
        )
        backend = current_platform.simple_compile_backend

        dynamic_arg_dims = getattr(self.__class__, "_dynamic_arg_dims", None)
        if dynamic_arg_dims is not None:
            compiled_fn = torch.compile(
                fn,
                dynamic=False,
                backend=backend,
                options=compile_options,
            )
            sig = inspect.signature(fn)

            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                bound = sig.bind(*args, **kwargs)
                bound.apply_defaults()
                for name, dims in dynamic_arg_dims.items():
                    arg = bound.arguments.get(name)
                    if arg is not None and isinstance(arg, torch.Tensor):
                        dims_list = [dims] if isinstance(dims, int) else dims
                        for d in dims_list:
                            real_d = arg.ndim + d if d < 0 else d
                            torch._dynamo.mark_dynamic(arg, real_d)
                return compiled_fn(*args, **kwargs)

            return wrapper

        # dynamic=True to avoid recompilations
        return torch.compile(
            fn,
            dynamic=True,
            backend=backend,
            options=compile_options,
        )

    @classmethod
    def enabled(cls) -> bool:
        # If there is no name, then the op was not registered.
        compilation_config = get_cached_compilation_config()
        custom_ops = compilation_config.custom_ops
        if not hasattr(cls, "name"):
            logger.warning_once(
                "Custom op %s was not registered, which means it won't appear "
                "in the op registry. It will be enabled/disabled based on the "
                "global settings.",
                cls.__name__,
            )
            return CustomOp.default_on()

        enabled = f"+{cls.name}" in custom_ops
        disabled = f"-{cls.name}" in custom_ops
        assert not (enabled and disabled), f"Cannot enable and disable {cls.name}"

        return (CustomOp.default_on() or enabled) and not disabled

    @staticmethod
    def default_on() -> bool:
        """
        Behavior controlled by `CompilationConfig.custom_ops`: on by default
        if 'all', off by default if 'none'.
        When PyTorch Inductor is used, 'none' is the default value,
        otherwise 'all'.
        """
        compilation_config = get_cached_compilation_config()
        count_none = compilation_config.custom_ops.count("none")
        count_all = compilation_config.custom_ops.count("all")
        assert count_none + count_all == 1

        return not count_none > 0 or count_all > 0

    # Decorator to register custom ops.
    @classmethod
    def register(
        cls,
        name: str,
        dynamic_arg_dims: dict[str, int | list[int]] | None = None,
    ):
        def decorator(op_cls):
            assert name not in op_registry, f"Duplicate op name: {name}"
            op_cls.name = name
            op_cls._dynamic_arg_dims = dynamic_arg_dims
            op_registry[name] = op_cls
            return op_cls

        return decorator

    # Decorator to register out-of-tree (OOT) custom ops.
    # For OOT custom ops:
    #   if an in-tree layer class is registered with an oot_custom_op layer,
    #   the oot_custom_op layer will be used instead.
    # Example:
    # - @UnquantizedFusedMoEMethod.register_oot
    #   class HPUUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod)
    # or
    # - @CustomOp.register_oot(name="UnquantizedFusedMoEMethod")
    @classmethod
    def register_oot(cls, _decorated_op_cls=None, name: str | None = None):
        def decorator(op_cls):
            reg_name = name if name is not None else cls.__name__
            assert reg_name not in op_registry_oot, f"Duplicate op name: {reg_name}"
            op_cls.name = reg_name
            op_registry_oot[reg_name] = op_cls
            return op_cls

        if _decorated_op_cls is None:
            # Called with parentheses: @CustomOp.register_oot()
            # or @CustomOp.register_oot(name="...")
            # So, _decorated_op_cls is None.
            # We return the actual decorator function.
            return decorator
        elif isinstance(_decorated_op_cls, type):  # Check if it's a class
            # Called without parentheses: @CustomOp.register_oot
            # The first argument is the class itself.
            # We call the 'decorator' function immediately with the class.
            return decorator(_decorated_op_cls)
        else:
            # Handle other unexpected cases if necessary
            raise TypeError("Decorator can only be applied to classes.")
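A minimal sketch of how a subclass hooks into this dispatch machinery (illustrative only, not part of this commit; MyGELU is a hypothetical op, and the snippet assumes a vLLM context where the cached compilation config is available):

# Hypothetical sketch: registration, per-platform dispatch, dynamic dims.
import torch

from vllm.model_executor.custom_op import CustomOp


# The registered name makes the op controllable via
# CompilationConfig.custom_ops, e.g. ["all"], ["none"],
# ["none", "+my_gelu"] or ["all", "-my_gelu"]. dynamic_arg_dims marks
# dim 0 of `x` as dynamic when forward_native is compiled by maybe_compile().
@CustomOp.register("my_gelu", dynamic_arg_dims={"x": 0})
class MyGELU(CustomOp):
    def __init__(self):
        super().__init__(compile_native=True)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        # Portable PyTorch implementation: used (optionally compiled) when
        # the op is disabled, and as the CPU/XPU/TPU fallback.
        return x * 0.5 * (1.0 + torch.erf(x * 0.7071067811865475))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # A device-specific kernel would be called here; this sketch just
        # falls back to the native path.
        return self.forward_native(x)


op = MyGELU()             # dispatch_forward() resolves the backend once here
y = op(torch.randn(4, 8))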
0
vllm/model_executor/kernels/__init__.py
Normal file
397
vllm/model_executor/kernels/linear/__init__.py
Normal file
@@ -0,0 +1,397 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""
This module re-exports linear kernel implementations to provide a
stable import interface during an ongoing reorganization. Upcoming
PRs will remove the scaled_mm and mixed_precision subdirectories
and reorganize kernels by provider (aiter, cutlass, flashinfer, etc.)
rather than by precision type. By centralizing exports here, we
minimize the need to update imports across other modules when the
internal structure changes. If you are adding a new kernel selector
or kernel implementation, add it to this __init__.py to maintain
import stability.
"""

import os
from typing import TypeVar

import torch

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear.mixed_precision import (
    MPLinearKernel,
    MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
    AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
    ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
    CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
    CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
    Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
    ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
    MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
    MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
    XPUwNa16LinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
    Int8ScaledMMLinearKernel,
    Int8ScaledMMLinearLayerConfig,
    ScaledMMLinearKernel,
    ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
    AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
    CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
    CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
    TritonInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.xpu import (
    XPUFP8ScaledMMLinearKernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import PlatformEnum, current_platform

logger = init_logger(__name__)

# in priority/performance order (when available)
_POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]] = {
    PlatformEnum.CPU: [CPUInt8ScaledMMLinearKernel],
    PlatformEnum.CUDA: [
        CutlassInt8ScaledMMLinearKernel,
        TritonInt8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [AiterInt8ScaledMMLinearKernel, TritonInt8ScaledMMLinearKernel],
}

# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
    PlatformEnum.CUDA: [
        FlashInferFP8ScaledMMLinearKernel,
        CutlassFP8ScaledMMLinearKernel,
        PerTensorTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [
        ROCmFP8ScaledMMLinearKernel,
        PerTensorTorchFP8ScaledMMLinearKernel,
        RowWiseTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.CPU: [
        PerTensorTorchFP8ScaledMMLinearKernel,
        ChannelWiseTorchFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.XPU: [
        XPUFP8ScaledMMLinearKernel,
    ],
}

# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
    PlatformEnum.CUDA: [
        CutlassW4A8LinearKernel,
        MacheteLinearKernel,
        AllSparkLinearKernel,
        MarlinLinearKernel,
        ConchLinearKernel,
        ExllamaLinearKernel,
    ],
    PlatformEnum.ROCM: [
        ConchLinearKernel,
        ExllamaLinearKernel,
    ],
    PlatformEnum.XPU: [
        XPUwNa16LinearKernel,
    ],
    PlatformEnum.CPU: [
        Dynamic4bitLinearKernel,
        CPUWNA16LinearKernel,
    ],
}

_KernelT = TypeVar("_KernelT", bound=ScaledMMLinearKernel)
_KernelConfigT = TypeVar("_KernelConfigT", bound=ScaledMMLinearLayerConfig)


def is_supported_and_can_implement_kernel(
    kernel: type[_KernelT], config: _KernelConfigT, compute_capability: int | None
) -> tuple[bool, str]:
    # TODO: Fetch `VLLM_DISABLED_KERNELS` from vllm.envs instead.
    if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "").split(","):
        return False, f" {kernel.__name__} is disabled by environment variable"

    if compute_capability is None:
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    is_supported, failure_reason = kernel.is_supported(compute_capability)
    if not is_supported:
        return False, f"{kernel.__name__} {failure_reason}."

    can_implement, failure_reason = kernel.can_implement(config)
    if not can_implement:
        return (
            False,
            f"{kernel.__name__} {failure_reason}.",
        )

    return True, ""


def choose_scaled_mm_linear_kernel(
    config: _KernelConfigT,
    possible_kernels: dict[PlatformEnum, list[type[_KernelT]]],
    compute_capability: int | None = None,
    force_kernel: type[_KernelT] | None = None,
) -> type[_KernelT]:
    """
    Choose a _KernelT that can implement the given config for the
    given compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (_KernelConfigT): Description of the linear layer
            to be implemented.
        possible_kernels (dict[PlatformEnum, list[_KernelT]]): A
            dictionary of platforms and their list of possible kernels.
        compute_capability (Optional[int], optional): The compute capability
            of the target device; if None, uses `current_platform` to get the
            compute capability. Defaults to None.
        force_kernel (Optional[type[_KernelT]]): An optional kernel that
            overrides possible_kernels if it can be implemented; if None,
            only the possible kernels are tried.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        _KernelT: Chosen kernel.
    """

    failure_reason_list = []

    if force_kernel is not None:
        can_implement, failure_reason = is_supported_and_can_implement_kernel(
            force_kernel, config, compute_capability
        )
        if can_implement:
            return force_kernel

        logger.info_once(
            "Tried to force %s, but the kernel couldn't be implemented",
            force_kernel.__name__,
            scope="global",
        )

    for kernel in possible_kernels[current_platform._enum]:
        is_supported_and_can_implement, failure_reason = (
            is_supported_and_can_implement_kernel(kernel, config, compute_capability)
        )
        if is_supported_and_can_implement:
            return kernel
        failure_reason_list.append(failure_reason)

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "ScaledMM linear layer. Reasons: \n" + "\n".join(failure_reason_list)
    )
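An illustrative call of the selector above (not part of this commit; the field values are assumptions). On CUDA it walks the INT8 candidates (Cutlass, then Triton) in priority order and returns the first class whose is_supported() and can_implement() checks pass:

config = Int8ScaledMMLinearLayerConfig(
    is_channelwise=True,           # per-channel weight scales (assumed)
    is_static_input_scheme=False,  # dynamic activation scales (assumed)
    input_symmetric=True,
)
kernel_cls = choose_scaled_mm_linear_kernel(config, _POSSIBLE_INT8_KERNELS)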
def init_fp8_linear_kernel(
    activation_quant_key: QuantKey,
    weight_quant_key: QuantKey,
    out_dtype: torch.dtype,
    force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
    module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
    scaled_mm_linear_kernel_config = FP8ScaledMMLinearLayerConfig(
        weight_quant_key=weight_quant_key,
        activation_quant_key=activation_quant_key,
        out_dtype=out_dtype,
    )

    kernel_type = choose_scaled_mm_linear_kernel(
        scaled_mm_linear_kernel_config, _POSSIBLE_FP8_KERNELS, force_kernel=force_kernel
    )

    if module_name:
        logger.info_once(
            "Selected %s for %s",
            kernel_type.__name__,
            module_name,
            scope="global",
        )

    return kernel_type(
        scaled_mm_linear_kernel_config,
        layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
    )


def init_int8_linear_kernel(
    is_channelwise: bool,
    is_static_input_scheme: bool,
    input_symmetric: bool,
    module_name: str,
) -> Int8ScaledMMLinearKernel:
    config = Int8ScaledMMLinearLayerConfig(
        is_channelwise=is_channelwise,
        is_static_input_scheme=is_static_input_scheme,
        input_symmetric=input_symmetric,
    )

    kernel_type = choose_scaled_mm_linear_kernel(
        config,
        _POSSIBLE_INT8_KERNELS,
    )

    logger.info_once(
        "Selected %s for %s",
        kernel_type.__name__,
        module_name,
        scope="global",
    )

    return kernel_type(
        config,
        layer_param_names=[
            "weight",
            "weight_scale",
            "input_scale",
            "input_zero_point",
            "azp_adj",
        ],
    )


def choose_mp_linear_kernel(
    config: MPLinearLayerConfig, compute_capability: int | None = None
) -> type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
            implemented.
        compute_capability (Optional[int], optional): The compute capability
            of the target device; if None, uses `current_platform` to get
            the compute capability. Defaults to None.

    Raises:
        ValueError: If no kernel can implement the given config.

    Returns:
        type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        if _cc is not None:
            compute_capability = _cc[0] * 10 + _cc[1]

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS[current_platform._enum]:
        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
            failure_reasons.append(
                f" {kernel.__name__} disabled by environment variable"
            )
            continue
        if (
            compute_capability is not None
            and kernel.get_min_capability() > compute_capability
        ):
            failure_reasons.append(
                f"{kernel.__name__} requires capability "
                f"{kernel.get_min_capability()}, current compute "
                f"capability is {compute_capability}"
            )
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        else:
            failure_reasons.append(
                f" {kernel.__name__} cannot implement due to: {failure_reason}"
            )

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + "\n".join(failure_reasons)
    )


__all__ = [
    "init_fp8_linear_kernel",
    "init_int8_linear_kernel",
    "choose_mp_linear_kernel",
    "FP8ScaledMMLinearKernel",
    "Int8ScaledMMLinearKernel",
    "ScaledMMLinearKernel",
    "FP8ScaledMMLinearLayerConfig",
    "Int8ScaledMMLinearLayerConfig",
    "ScaledMMLinearLayerConfig",
    "AiterInt8ScaledMMLinearKernel",
    "CPUInt8ScaledMMLinearKernel",
    "CutlassFP8ScaledMMLinearKernel",
    "CutlassInt8ScaledMMLinearKernel",
    "FlashInferFP8ScaledMMLinearKernel",
    "ChannelWiseTorchFP8ScaledMMLinearKernel",
    "PerTensorTorchFP8ScaledMMLinearKernel",
    "RowWiseTorchFP8ScaledMMLinearKernel",
    "ROCmFP8ScaledMMLinearKernel",
    "TritonInt8ScaledMMLinearKernel",
    "MPLinearKernel",
    "MPLinearLayerConfig",
    "AllSparkLinearKernel",
    "ConchLinearKernel",
    "CPUWNA16LinearKernel",
    "CutlassW4A8LinearKernel",
    "Dynamic4bitLinearKernel",
    "ExllamaLinearKernel",
    "MacheteLinearKernel",
    "MarlinLinearKernel",
    "XPUwNa16LinearKernel",
]
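An illustrative call of the mixed-precision selector above (not part of this commit; shapes and field values are assumptions). On CUDA the chooser walks Cutlass W4A8, Machete, AllSpark, Marlin, Conch, and Exllama in priority order; kernels can also be vetoed by name via VLLM_DISABLED_KERNELS:

import torch

from vllm.scalar_type import scalar_types

# GPTQ-style 4-bit layer description (assumed values)
mp_config = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),
    partition_weight_shape=(4096, 4096),
    weight_type=scalar_types.uint4b8,
    act_type=torch.float16,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)
kernel_cls = choose_mp_linear_kernel(mp_config)
# e.g. VLLM_DISABLED_KERNELS=MacheteLinearKernel,MarlinLinearKernel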
94
vllm/model_executor/kernels/linear/mixed_precision/MPLinearKernel.py
Normal file
@@ -0,0 +1,94 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass

import torch

from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType


@dataclass
class MPLinearLayerConfig:
    full_weight_shape: tuple[int, int]  # [in, out]
    partition_weight_shape: tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool
    out_type: torch.dtype | None = None


class MPLinearKernel(ABC):
    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        raise NotImplementedError

    def __init__(
        self,
        c: MPLinearLayerConfig,
        w_q_param_name: str,
        w_s_param_name: str,
        w_zp_param_name: str | None = None,
        w_gidx_param_name: str | None = None,
    ) -> None:
        assert self.can_implement(c)
        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        if c.zero_points:
            assert w_zp_param_name is not None
        if c.has_g_idx:
            assert w_gidx_param_name is not None
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    def _transform_param(
        self, layer: torch.nn.Module, name: str | None, fn: Callable
    ) -> None:
        if name is not None and getattr(layer, name, None) is not None:
            old_param = getattr(layer, name)
            new_param = fn(old_param)
            # replace the parameter with torch.nn.Parameter for TorchDynamo
            # compatibility
            replace_parameter(
                layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
            )

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> tuple[
        torch.Tensor,  # w_q
        torch.Tensor,  # w_s
        torch.Tensor | None,  # w_zp
        torch.Tensor | None,  # w_gidx
    ]:
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
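A skeletal subclass showing the MPLinearKernel contract (illustrative only, not part of this commit; NaiveLinearKernel is hypothetical and the dequantize-then-matmul body assumes an unpacked [K, N] weight):

import torch

from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
    MPLinearKernel,
    MPLinearLayerConfig,
)


class NaiveLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 0  # no hardware requirement in this sketch

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.zero_points or c.has_g_idx:
            return False, "zero points / g_idx not handled in this sketch"
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Real kernels repack weights here via self._transform_param(...).
        pass

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        w_q, w_s, _, _ = self._get_weight_params(layer)
        # Placeholder math; real kernels fuse dequantization into the GEMM.
        out = x @ (w_q.to(x.dtype) * w_s.to(x.dtype))
        return out if bias is None else out + bias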
48
vllm/model_executor/kernels/linear/mixed_precision/__init__.py
Normal file
@@ -0,0 +1,48 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.kernels.linear.mixed_precision.allspark import (
    AllSparkLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.conch import (
    ConchLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cpu import (
    CPUWNA16LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
    CutlassW4A8LinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.dynamic_4bit import (
    Dynamic4bitLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.exllama import (
    ExllamaLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.machete import (
    MacheteLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.marlin import (
    MarlinLinearKernel,
)
from vllm.model_executor.kernels.linear.mixed_precision.MPLinearKernel import (
    MPLinearKernel,
    MPLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
    XPUwNa16LinearKernel,
)

__all__ = [
    "MPLinearKernel",
    "MPLinearLayerConfig",
    "AllSparkLinearKernel",
    "ConchLinearKernel",
    "CPUWNA16LinearKernel",
    "CutlassW4A8LinearKernel",
    "Dynamic4bitLinearKernel",
    "ExllamaLinearKernel",
    "MacheteLinearKernel",
    "MarlinLinearKernel",
    "XPUwNa16LinearKernel",
]
116
vllm/model_executor/kernels/linear/mixed_precision/allspark.py
Normal file
@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.allspark_utils import (
    ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
    check_allspark_supported_dtype_shape,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.utils.platform_utils import num_compute_units

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class AllSparkLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.has_g_idx:
            return False, "Act reordering currently not supported by AllSpark"

        if c.zero_points:
            return False, "Zero points currently not supported by AllSpark"

        return check_allspark_supported_dtype_shape(
            c.partition_weight_shape[0],  # in_features
            c.partition_weight_shape[1],  # out_features
            c.group_size,
            c.weight_type,
            c.act_type,
        )

    # note: assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config

        # prepare the parameters required for the kernel
        properties = torch.cuda.get_device_properties(device.index)
        sm_count = num_compute_units(device.index)
        sm_version = properties.major * 10 + properties.minor
        gemm_args = {}
        gemm_args["sm_count"] = sm_count
        gemm_args["sm_version"] = sm_version

        self.gemm_args = gemm_args

        # transform param weight, scale
        old_weight_param = getattr(layer, self.w_q_name)
        old_scale_param = getattr(layer, self.w_s_name)

        assert isinstance(old_weight_param, BasevLLMParameter)
        permute_param_layout_(old_weight_param, input_dim=0, output_dim=1, packed_dim=0)

        assert isinstance(old_scale_param, BasevLLMParameter)
        permute_param_layout_(old_scale_param, input_dim=0, output_dim=1)

        # unpack weight from K / 4 x N int32 to K x N uint8
        new_weight_param = torch.nn.Parameter(
            old_weight_param.data, requires_grad=False
        )
        new_weight_param.data = (
            new_weight_param.data.t().contiguous().view(dtype=torch.uint8)
        )
        new_weight_param.data = new_weight_param.data.t().contiguous()

        new_scale_param = torch.nn.Parameter(old_scale_param.data, requires_grad=False)

        # reorder K x N weight as N32K16 format for Ampere W8A16
        new_weight_param.data, new_scale_param.data, _ = ops.allspark_repack_weight(
            new_weight_param.data, new_scale_param.data, None, c.zero_points
        )

        replace_parameter(layer, self.w_q_name, new_weight_param.data)
        replace_parameter(layer, self.w_s_name, new_scale_param.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        gemm_args = self.gemm_args
        w_q, w_s, _, _ = self._get_weight_params(layer)

        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        output = ops.allspark_w8a16_gemm(
            a=reshaped_x,
            b_qweight=w_q,
            b_scales=w_s,
            b_qzeros=None,
            n=c.partition_weight_shape[1],
            group_size=c.group_size,
            sm_count=gemm_args["sm_count"],
            sm_version=gemm_args["sm_version"],
            CUBLAS_M_THRESHOLD=ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD,
            has_zp=c.zero_points,
            n32k16_reorder=True,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
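The "unpack" step above relies on torch's dtype reinterpretation: viewing a contiguous int32 tensor as uint8 splits the last dimension by 4, hence the transpose-view-transpose dance. A standalone sketch of the shape bookkeeping (sizes are assumptions for illustration):

import torch

K, N = 64, 32
packed = torch.randint(-2**31, 2**31 - 1, (K // 4, N), dtype=torch.int32)

# view(dtype=torch.uint8) splits the *last* dim by 4, so transpose first:
# [K/4, N] -t-> [N, K/4] -view-> [N, K] -t-> [K, N]
unpacked = packed.t().contiguous().view(dtype=torch.uint8).t().contiguous()
assert unpacked.shape == (K, N)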
140
vllm/model_executor/kernels/linear/mixed_precision/conch.py
Normal file
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from importlib.util import find_spec
from typing import Final

import torch

from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CONCH_SUPPORTED_WEIGHT_TYPES: Final = [
    scalar_types.uint4,
    scalar_types.uint8,
    scalar_types.uint4b8,
    scalar_types.uint8b128,
]
_CONCH_SUPPORTED_GROUP_SIZES: Final = [-1, 128]


class ConchLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if c.weight_type not in _CONCH_SUPPORTED_WEIGHT_TYPES:
            error_msg = (
                f"Weight type ({c.weight_type}) not supported by "
                "ConchLinearKernel, supported types are: "
                f"{_CONCH_SUPPORTED_WEIGHT_TYPES}"
            )
            return False, error_msg

        if c.group_size not in _CONCH_SUPPORTED_GROUP_SIZES:
            error_msg = (
                f"Group size ({c.group_size}) not supported by "
                "ConchLinearKernel, supported group sizes are: "
                f"{_CONCH_SUPPORTED_GROUP_SIZES}"
            )
            return False, error_msg

        if find_spec("conch") is None:
            error_msg = (
                "conch-triton-kernels is not installed, please "
                "install it via `pip install conch-triton-kernels` "
                "and try again!"
            )
            return False, error_msg

        return True, None

    # note: assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale` is: {input_dim = 0, output_dim = 1}
    #   `weight_zero_point` is: {input_dim = 1, output_dim = 0, packed_dim = 0}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = x.data.contiguous()
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        def transform_w_zp(x):
            # Zero points are stored PACKED as [N//pack_factor, K//G].
            # The Conch kernel expects UNPACKED zeros: [K//G, N].
            # We need to unpack and reorder.
            assert isinstance(x, BasevLLMParameter)
            packed = x.data  # shape: [N//pack_factor, K//G], dtype: int32

            # Determine packing based on weight bit width
            size_bits = self.config.weight_type.size_bits
            pack_factor = 32 // size_bits  # 8 for 4-bit, 4 for 8-bit
            mask = (1 << size_bits) - 1  # 0xF for 4-bit, 0xFF for 8-bit

            n_packed, k_groups = packed.shape
            n_full = n_packed * pack_factor

            # Unpack using vectorized bitwise ops
            # shifts = [0, size_bits, 2*size_bits, ...] for each packed position
            shifts = torch.arange(
                0, 32, size_bits, dtype=torch.int32, device=packed.device
            )
            # packed: [N//pack_factor, K//G] -> [N//pack_factor, K//G, 1]
            # shifts: [pack_factor] -> [1, 1, pack_factor]
            # Result: [N//pack_factor, K//G, pack_factor]
            unpacked = (packed.unsqueeze(-1) >> shifts) & mask

            # Permute to [K//G, N//pack_factor, pack_factor] then reshape to [K//G, N]
            unpacked = unpacked.permute(1, 0, 2).reshape(k_groups, n_full)

            x.data = unpacked.to(torch.uint8).contiguous()

            # Update metadata - zeros are no longer packed
            if hasattr(x, "_input_dim"):
                x._input_dim = 0
            if hasattr(x, "_output_dim"):
                x._output_dim = 1
            if hasattr(x, "_packed_factor"):
                x._packed_factor = 1
            return x

        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if self.config.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        from conch.ops.quantization.gemm import mixed_precision_gemm

        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        output = mixed_precision_gemm(
            x=x,
            w_q_packed=w_q.data,
            w_s=w_s.data,
            w_zp=w_zp.data if w_zp is not None else None,
            weight_size_bits=self.config.weight_type.size_bits,
            weight_bias=self.config.weight_type.bias,
            group_size=self.config.group_size,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output
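A tiny numeric sketch of the shift-and-mask unpacking used in transform_w_zp above (illustrative only, not part of this commit): one int32 packs eight 4-bit values, lowest nibble first.

import torch

size_bits = 4
mask = (1 << size_bits) - 1                               # 0xF
packed = torch.tensor([[0x76543210]], dtype=torch.int32)  # nibbles 0..7

shifts = torch.arange(0, 32, size_bits, dtype=torch.int32)
unpacked = (packed.unsqueeze(-1) >> shifts) & mask        # shape [1, 1, 8]
print(unpacked.flatten().tolist())                        # [0, 1, ..., 7]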
126
vllm/model_executor/kernels/linear/mixed_precision/cpu.py
Normal file
@@ -0,0 +1,126 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_CPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)


class CPUWNA16LinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "CPUWNA16 only supported on CPU"

        if c.weight_type not in _CPUWNA16_SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CPUWNA16, supported types are: "
                f"{_CPUWNA16_SUPPORTED_QUANT_TYPES}",
            )

        if c.group_size != -1 and c.group_size % 2 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "CPUWNA16, supported group sizes are multiples of 2",
            )

        if c.partition_weight_shape[0] % 32 != 0:
            return (
                False,
                f"Input size ({c.partition_weight_shape[0]}) not supported by "
                "CPUWNA16, supported sizes are multiples of 32",
            )

        if c.partition_weight_shape[1] % 32 != 0:
            return (
                False,
                f"Output size ({c.partition_weight_shape[1]}) not supported by "
                "CPUWNA16, supported sizes are multiples of 32",
            )

        return True, None

    # note: assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale` is: {input_dim = 0, output_dim = 1}
    #   `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def _process_gptq_weights(self, layer: torch.nn.Module):
        packed_weight = layer.qweight.data
        bits = self.config.weight_type.mantissa
        pack_factor = 32 // bits
        p_w_k, p_w_n = packed_weight.size()
        input_size = p_w_k * pack_factor
        output_size = p_w_n
        isa_hint = _get_isa_hint(layer.scales.dtype)
        layer.isa_hint = isa_hint

        layer.qzeros = None
        if not self.config.has_g_idx:
            layer.g_idx = None

        # convert input dim packed to output dim packed
        weight = unpack_quantized_values_into_int32(
            packed_weight, self.config.weight_type, 1
        ).view(p_w_k, p_w_n, pack_factor)
        weight = weight.permute(0, 2, 1).reshape(input_size, output_size).contiguous()
        weight = pack_quantized_values_into_int32(weight, self.config.weight_type, 1)
        # make every 16 output channels a block and transpose to make
        # the blocks contiguous
        weight = (
            weight.view(input_size, -1, 16 // pack_factor)
            .permute(1, 0, 2)
            .reshape(-1, input_size * 16 // pack_factor)
            .contiguous()
        )
        layer.qweight.data = weight

    def process_weights_after_loading(self, layer: torch.nn.Module):
        if not self.config.zero_points:
            # GPTQ
            self._process_gptq_weights(layer)
        else:
            # AWQ
            raise NotImplementedError("AWQ is not supported in CPUWNA16LinearKernel")

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = ops.cpu_gemm_wna16(
            input=x,
            q_weight=layer.qweight,
            scales=layer.scales,
            zeros=layer.qzeros,
            g_idx=layer.g_idx,
            bias=bias,
            pack_factor=8,  # 32 // 4
            isa_hint=layer.isa_hint,
        )
        return x


def _get_isa_hint(dtype: torch.dtype) -> str:
    supports_amx = torch._C._cpu._is_amx_tile_supported()
    if supports_amx and dtype in (torch.bfloat16,):
        return "amx"
    else:
        return "vec"
131
vllm/model_executor/kernels/linear/mixed_precision/cutlass.py
Normal file
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    convert_bf16_scales_to_fp8,
    convert_packed_uint4b8_to_signed_int4_inplace,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class CutlassW4A8LinearKernel(MPLinearKernel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # dynamic per-token fp8 activation quantization
        self.quant_fp8 = QuantFP8(static=False, group_shape=GroupShape.PER_TOKEN)

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "CUTLASS only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "CUTLASS W4A8 requires compute capability of 90 (Hopper)"

        if c.act_type != torch.float8_e4m3fn:
            return False, "CUTLASS W4A8 only supports FP8 (e4m3) activations"

        if c.has_g_idx:
            return False, "Act reordering not supported by CUTLASS W4A8"

        if c.zero_points:
            return False, "Zero points not supported by CUTLASS W4A8"

        if c.weight_type != scalar_types.int4:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "CUTLASS W4A8, only int4 is supported",
            )

        if c.group_size != 128:
            return False, "Only group_size 128 is supported"

        in_features, out_features = c.partition_weight_shape
        if in_features % 128 or out_features % 128:
            return (
                False,
                f"K and N must be divisible by 128, got {c.partition_weight_shape}",
            )

        if c.out_type != torch.bfloat16:
            return (
                False,
                f"Only bfloat16 output type currently supported, got {c.out_type=}",
            )

        return True, None

    # note: assumes that
    #   `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #   `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            convert_packed_uint4b8_to_signed_int4_inplace(x.data)
            torch.cuda.synchronize()
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.cutlass_encode_and_reorder_int4b(x.data.t().contiguous().t())
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous().to(torch.float8_e4m3fn)
            x.data = ops.cutlass_pack_scale_fp8(x.data)
            return x

        w_s = getattr(layer, self.w_s_name)
        fp8_scales, chan_scales = convert_bf16_scales_to_fp8(self.quant_fp8, w_s.data)
        w_s.data = fp8_scales

        # register per-channel scales
        layer.register_parameter(
            "weight_chan_scale", torch.nn.Parameter(chan_scales, requires_grad=False)
        )

        # Encode/reorder weights and pack scales
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)
        w_ch_s = layer.weight_chan_scale

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        x_2d, act_scales = self.quant_fp8(x_2d)
        output = ops.cutlass_w4a8_mm(
            a=x_2d,
            b_q=w_q,
            b_group_scales=w_s,
            b_group_size=c.group_size,
            a_token_scales=act_scales,
            b_channel_scales=w_ch_s,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
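An illustrative config that satisfies the constraints enforced above (not part of this commit; shapes are assumptions): int4 weights, fp8 activations, bf16 output, group_size 128, K and N divisible by 128.

import torch

from vllm.model_executor.kernels.linear.mixed_precision import MPLinearLayerConfig
from vllm.model_executor.kernels.linear.mixed_precision.cutlass import (
    CutlassW4A8LinearKernel,
)
from vllm.scalar_type import scalar_types

w4a8_config = MPLinearLayerConfig(
    full_weight_shape=(4096, 4096),
    partition_weight_shape=(4096, 4096),
    weight_type=scalar_types.int4,
    act_type=torch.float8_e4m3fn,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
    out_type=torch.bfloat16,
)
ok, reason = CutlassW4A8LinearKernel.can_implement(w4a8_config)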
159
vllm/model_executor/kernels/linear/mixed_precision/dynamic_4bit.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
|
||||
|
||||
|
||||
# This implementation is for the KleidiAI-accelerated w4a8int quantization
|
||||
# scheme on Arm CPUs:
|
||||
# torch.ops.aten._dyn_quant_matmul_4bit performs dynamic quantized matmul
|
||||
# it takes:
|
||||
# - int4 weights packed along with bias/scales by
|
||||
# torch.ops.aten._dyn_quant_pack_4bit_weight
|
||||
# - float32/bfloat16 activations
|
||||
# then it leverages KleidiAI ukernels that:
|
||||
# - dynamically quantize the activations to int8
|
||||
# - unpack the int4 weights to int8
|
||||
# - perform int8 x int8 -> int32 matmul
|
||||
# - dequantize the int32 output to float32/bfloat16 outputs
|
||||
class Dynamic4bitLinearKernel(MPLinearKernel):
|
||||
SUPPORTED_QUANT_TYPES = [scalar_types.int4]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 1
|
||||
|
||||
@classmethod
|
||||
def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
|
||||
if not current_platform.is_cpu():
|
||||
return False, "Only CPU is supported"
|
||||
if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
|
||||
return False, f"Unsupported quant type {c.weight_type}"
|
||||
if (
|
||||
current_platform.get_cpu_architecture() == CpuArchEnum.ARM
|
||||
and c.act_type
|
||||
not in [
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
]
|
||||
):
|
||||
return (
|
||||
False,
|
||||
"Dynamic4bitLinearKernel on Arm requires Float32 or"
|
||||
" BFloat16 or Float16 activations",
|
||||
)
|
||||
if c.full_weight_shape[0] % c.group_size != 0:
|
||||
return (
|
||||
False,
|
||||
f"Group size ({c.group_size}) does not evenly divide"
|
||||
" the number of input features "
|
||||
f"({c.full_weight_shape[0]})",
|
||||
)
|
||||
if current_platform.get_cpu_architecture() == CpuArchEnum.ARM:
|
||||
try:
|
||||
# Attempt to retrieve the operation
|
||||
_ = torch.ops.aten._dyn_quant_matmul_4bit
|
||||
except AttributeError:
|
||||
return (
|
||||
False,
|
||||
f"PyTorch {torch.__version__} does not support"
|
||||
" _dyn_quant_matmul_4bit. Install a newer version",
|
||||
)
|
||||
return True, None
|
||||
|
||||
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        packed_weight = getattr(layer, self.w_q_name)
        packed_weight = packed_weight.add(8)
        uint8_packed = (packed_weight[::, 1::2] << 4 | packed_weight[::, ::2]).to(
            torch.uint8
        )

        scales = getattr(layer, self.w_s_name)
        block_size = c.group_size

        # Handle scaling factors for partitioned weights
        if block_size == c.partition_weight_shape[0]:
            scales = scales.to(
                torch.float32
            )  # Float32 & BFloat16 variants require float32 scales
            scales = scales.view(-1, 1)  # Channel-wise scales
            if layer.bias is not None:
                # Float32 & BFloat16 variants require a float32 bias
                replace_parameter(
                    layer,
                    "bias",
                    torch.nn.Parameter(
                        layer.bias.to(torch.float32), requires_grad=False
                    ),
                )
        else:
            # The KleidiAI kernel requires bfloat16 scales with the groupwise scheme
            scales = scales.to(torch.bfloat16)

        # Repack weights as required by the kernel
        w = torch.ops.aten._dyn_quant_pack_4bit_weight(
            uint8_packed,
            scales,
            layer.bias,
            block_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        )
        replace_parameter(
            layer, self.w_q_name, torch.nn.Parameter(w, requires_grad=False)
        )
        setattr(layer, self.w_s_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # PyTorch / KleidiAI kernels natively support the following configs:
        # - channelwise with bfloat16 / float32 activations
        # - groupwise with float32 activations
        # To support:
        # - groupwise with bfloat16/float16 activations: we need to upcast
        #   activations to float32 before the matmul and downcast the result
        #   back to bfloat16/float16
        # - channelwise with float16 activations: we need to upcast activations
        #   to float32 before the matmul and downcast the result back to float16
        # Note: these activations will be dynamically quantized to int8 by the kernel.

        c = self.config
        is_groupwise = c.group_size != c.partition_weight_shape[0]
        # dtype of activations before they get dynamically quantized to int8
        original_pre_quant_act_dtype = x.dtype
        pre_quant_act_dtype = original_pre_quant_act_dtype
        if (
            is_groupwise and pre_quant_act_dtype == torch.bfloat16
        ) or pre_quant_act_dtype == torch.float16:
            pre_quant_act_dtype = torch.float32

        x_2d = x.reshape(-1, x.shape[-1])
        if pre_quant_act_dtype != original_pre_quant_act_dtype:
            x_2d = x_2d.to(pre_quant_act_dtype)

        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q = getattr(layer, self.w_q_name)
        output = torch.ops.aten._dyn_quant_matmul_4bit(
            x_2d,
            w_q,
            c.group_size,
            c.partition_weight_shape[0],
            c.partition_weight_shape[1],
        ).reshape(out_shape)

        if pre_quant_act_dtype != original_pre_quant_act_dtype:
            output = output.to(original_pre_quant_act_dtype)
        return output
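The nibble interleave in `process_weights_after_loading` above packs two int4 values per uint8: adding 8 maps the signed range [-8, 7] onto [0, 15], then odd columns go to the high nibble and even columns to the low nibble. A standalone sketch of that step with made-up values (illustrative, not part of the diff):

import torch

# Hypothetical toy weights: one row of four int4 values in [-8, 7].
w = torch.tensor([[-8, 7, 0, -1]], dtype=torch.int8)

shifted = w.add(8).to(torch.uint8)  # map [-8, 7] -> [0, 15], one nibble each
packed = shifted[:, 1::2] << 4 | shifted[:, ::2]  # odd -> high, even -> low

# Unpack to verify the round trip.
low = (packed & 0xF).to(torch.int8) - 8
high = (packed >> 4).to(torch.int8) - 8
restored = torch.stack([low, high], dim=-1).reshape(w.shape)
assert torch.equal(restored, w)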
168
vllm/model_executor/kernels/linear/mixed_precision/exllama.py
Normal file
@@ -0,0 +1,168 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class ExllamaLinearKernel(MPLinearKernel):
    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
    # In theory this also supports `scalar_types.uint2b2, scalar_types.uint3b4`,
    # but these are currently untested, so they are not added to the list

    @classmethod
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_cuda_alike():
            return (
                False,
                "Exllama is only supported on CUDA and ROCm",
            )

        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering is currently not supported by Exllama "
                "when the input features are partitioned across "
                "devices",
            )

        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
            return (
                False,
                "Output features must be a multiple of the pack "
                "factor (32 / num_bits) so that we can correctly "
                "pack the zero points",
            )

        if c.act_type != torch.float16:
            return False, "Exllama only supports float16 activations"

        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Exllama, supported types are: "
                f"{cls.SUPPORTED_QUANT_TYPES}",
            )

        if c.full_weight_shape[0] % c.group_size != 0:
            return (
                False,
                f"Group size ({c.group_size}) does not evenly divide"
                " the number of input features "
                f"({c.full_weight_shape[0]})",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config
        device = getattr(layer, self.w_q_name).device

        # For Exllama, we need to set a zero-point tensor if there is not one
        if not c.zero_points:
            self.w_zp_name = "qzeros"
            groups = c.partition_weight_shape[0] // c.group_size
            out_features = c.partition_weight_shape[1]

            if c.weight_type.has_bias():
                # If the type has a bias, we have to create a zeros tensor that
                # contains the bias values repeated for each group (-1 due to
                # a bug in the original GPTQ checkpoint format leading to the
                # exllama kernel adding 1 to the zero points during inference).
                # Documentation of the bug can be found here:
                # https://garden.danieldk.eu/GPTQ-Checkpoint-Format
                zeros = torch.full(
                    (groups, out_features),
                    c.weight_type.bias - 1,
                    dtype=torch.int32,
                    device=device,
                )
            else:
                raise NotImplementedError(
                    "A 0 zero-point is not supported by Exllama due to "
                    "a bug in the original GPTQ checkpoint format leading to "
                    "the exllama kernel adding 1 to the zero points during "
                    "inference"
                )
            zeros = pack_quantized_values_into_int32(zeros, c.weight_type, packed_dim=1)
            setattr(
                layer, self.w_zp_name, torch.nn.Parameter(zeros, requires_grad=False)
            )

        if c.has_g_idx:

            def transform_w_g_idx(x):
                # Exllama wants the permutation array instead of the group
                # indices
                return torch.argsort(x).to(torch.int)

            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)  # type: ignore
        else:
            self.w_gidx_name = "g_idx"
            empty_g_idx = torch.nn.Parameter(
                torch.empty((0,), dtype=torch.int, device=device), requires_grad=False
            )
            setattr(layer, self.w_gidx_name, empty_g_idx)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            assert self.w_gidx_name is not None
            g_idx = getattr(layer, self.w_gidx_name)

            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x_cont = x.data.contiguous()
            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
            return x_cont

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x.to(dtype=c.act_type)

        # Repack weights and scales for Exllama
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)

        # gptq_gemm supports the GPTQv2 format by passing use_v2_format=True.
        # However, MPLinearLayerConfig doesn't contain format info,
        # so hardcode the GPTQv1 format here to keep the behavior unchanged.
        use_v2_format = False

        assert w_zp is not None, "Zero points are required by Exllama"
        assert w_g_idx is not None, "Group index is required by Exllama"
        output = ops.gptq_gemm(
            x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
        )

        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
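The `c.weight_type.bias - 1` above is the GPTQ checkpoint quirk the comment documents: the exllama kernel unconditionally adds 1 to every stored zero point, so the stored value must be one less than the logical one. A toy check of the identity this preserves (hypothetical numbers, not part of the diff):

# For scalar_types.uint4b8 the logical zero point equals the type bias, 8.
type_bias = 8
stored_zp = type_bias - 1      # what process_weights_after_loading writes
effective_zp = stored_zp + 1   # what the exllama kernel adds back

# Dequantization of a nibble q with scale s therefore stays correct:
q, s = 11, 0.25
assert s * (q - effective_zp) == s * (q - type_bias) == 0.75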
159
vllm/model_executor/kernels/linear/mixed_precision/machete.py
Normal file
@@ -0,0 +1,159 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import partial

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
    check_machete_supports_shape,
    query_machete_supported_group_sizes,
    query_machete_supported_quant_types,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_quantized_values_into_int32,
    unpack_quantized_values_into_int32,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MacheteLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # Machete uses CUTLASS, so it is only compatible with Nvidia GPUs
        if not current_platform.is_cuda():
            return False, "Machete only supported on CUDA"

        if not current_platform.is_device_capability(90):
            return False, "Machete requires compute capability 90 (Hopper)"

        if c.has_g_idx and c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return (
                False,
                "Act reordering is currently not supported by Machete "
                "when the input features are partitioned across "
                "devices",
            )

        if c.weight_type not in query_machete_supported_quant_types(c.zero_points):
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "Machete, supported types are: "
                f"{query_machete_supported_quant_types(c.zero_points)}",
            )

        if c.group_size not in query_machete_supported_group_sizes(c.act_type):
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Machete, supported group sizes are: "
                f"{query_machete_supported_group_sizes(c.act_type)}",
            )

        return check_machete_supports_shape(
            c.partition_weight_shape[0], c.partition_weight_shape[1]
        )

    # Note: this assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    # `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name)).to(torch.int)

            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if (
                c.act_type in [torch.float16, torch.bfloat16]
                and c.partition_weight_shape[0] % 8 == 0
            ):
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                x_unpacked = unpack_quantized_values_into_int32(
                    x.data, c.weight_type, packed_dim=0
                )
                x_perm = x_unpacked[perm, :]
                x.data = pack_quantized_values_into_int32(
                    x_perm, c.weight_type, packed_dim=0
                )
            x.data = ops.machete_prepack_B(
                x.data.t().contiguous().t(),
                a_type=c.act_type,
                b_type=c.weight_type,
                group_scales_type=c.act_type,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        def transform_w_zp(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
            x_unpacked = unpack_quantized_values_into_int32(
                x.data, c.weight_type, packed_dim=1
            )
            w_s = getattr(layer, self.w_s_name).data
            # pre-apply scales to zero-points
            x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)
        if c.zero_points:
            self._transform_param(layer, self.w_zp_name, transform_w_zp)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1],)

        if c.has_g_idx:
            x_2d = self.act_perm(x_2d)

        if c.zero_points:
            assert w_zp is not None
        else:
            w_zp = None

        output = ops.machete_mm(
            a=x_2d,
            b_q=w_q,
            b_type=c.weight_type,
            b_group_zeros=w_zp,
            b_group_scales=w_s,
            b_group_size=c.group_size,
        )

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
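Since Machete cannot apply `g_idx` reordering inside the kernel, the code above bakes the permutation into the packed weights and permutes activation columns at runtime; `x[:, perm] @ w[perm, :]` equals `x @ w`, so results are unchanged. A small self-contained check of that identity (illustrative only):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)
w = torch.randn(8, 3)

# A hypothetical g_idx assigning each input feature to a quantization group.
g_idx = torch.tensor([1, 0, 2, 0, 1, 2, 0, 1])
perm = torch.argsort(g_idx)  # groups become contiguous along the K dim

# Permuting both activation columns and weight rows is a no-op overall.
ref = x @ w
permuted = x[:, perm] @ w[perm, :]
assert torch.allclose(ref, permuted, atol=1e-6)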
200
vllm/model_executor/kernels/linear/mixed_precision/marlin.py
Normal file
@@ -0,0 +1,200 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    MARLIN_SUPPORTED_GROUP_SIZES,
    apply_gptq_marlin_linear,
    check_marlin_supports_shape,
    marlin_act_int8_process_scales,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace_new,
    marlin_permute_bias,
    marlin_permute_scales,
    marlin_sort_g_idx,
    marlin_zero_points,
    query_marlin_supported_quant_types,
    unpack_cols,
)
from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig


class MarlinLinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        # Marlin uses inline PTX, so it is only compatible with Nvidia GPUs
        if not current_platform.is_cuda():
            return False, "Marlin only supported on CUDA"

        quant_types = query_marlin_supported_quant_types(c.zero_points)
        if c.weight_type not in quant_types:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by"
                f" Marlin, supported types are: {quant_types}",
            )

        if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "Marlin, supported group sizes are: "
                f"{MARLIN_SUPPORTED_GROUP_SIZES}",
            )

        return check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

    # Note: this assumes that
    # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    # `weight_scale` is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        device = getattr(layer, self.w_q_name).device
        c = self.config
        is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1

        if is_a_8bit:
            assert c.weight_type == scalar_types.uint4b8, (
                "W8A8 is not supported by the marlin kernel."
            )

        if c.act_type == torch.float8_e4m3fn:
            ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
            getattr(layer, self.w_s_name).data = (
                getattr(layer, self.w_s_name).data * 512
            )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace_new(device)

        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        if self.w_gidx_name is None:
            self.w_gidx_name = "g_idx"
        if self.w_zp_name is None:
            self.w_zp_name = "w_zp"

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = ops.gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
                is_a_8bit=is_a_8bit,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
                is_a_8bit=is_a_8bit,
            )

            if c.group_size == -1:
                num_groups = 1
            else:
                num_groups = c.partition_weight_shape[0] // c.group_size

            if c.act_type == torch.int8 and num_groups > 1:
                x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
                layer.register_parameter(
                    "input_global_scale",
                    torch.nn.Parameter(input_global_scale, requires_grad=False),
                )
            else:
                layer.input_global_scale = None
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            self._transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                    is_a_8bit=is_a_8bit,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

        if hasattr(layer, "bias") and layer.bias is not None:
            layer.bias.data = marlin_permute_bias(layer.bias)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        c = self.config
        w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            input_global_scale=getattr(layer, "input_global_scale", None),
            bias=bias,
            input_dtype=c.act_type,
        )
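`marlin_is_k_full` above tells the kernel whether each shard sees the entire K (input-feature) dimension; with act-order (`g_idx`) combined with a row-parallel split, it does not. The real helper lives in `marlin_utils`; this sketch only illustrates the predicate it is believed to compute:

def is_k_full(has_g_idx: bool, row_parallel: bool) -> bool:
    # K is only "partial" when act-order reordering is combined with a
    # row-parallel split of the input features across devices.
    return not (has_g_idx and row_parallel)

assert is_k_full(has_g_idx=False, row_parallel=True)
assert not is_k_full(has_g_idx=True, row_parallel=True)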
88
vllm/model_executor/kernels/linear/mixed_precision/xpu.py
Normal file
@@ -0,0 +1,88 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
from torch.nn.parameter import Parameter

from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig

_XPUWNA16_SUPPORTED_QUANT_TYPES = (scalar_types.uint4, scalar_types.uint4b8)


class XPUwNa16LinearKernel(MPLinearKernel):
    @classmethod
    def get_min_capability(cls) -> int:
        return -1

    @classmethod
    def can_implement(cls, c: MPLinearLayerConfig) -> tuple[bool, str | None]:
        if not current_platform.is_xpu():
            return False, "XPUwNa16 only supported on XPU"

        if c.act_type != torch.bfloat16 and c.act_type != torch.float16:
            return False, "XPUwNa16 only supports BF16/FP16 activations"

        if c.weight_type not in _XPUWNA16_SUPPORTED_QUANT_TYPES:
            return (
                False,
                f"Quant type ({c.weight_type}) not supported by "
                "XPUwNa16, supported types are: "
                f"{_XPUWNA16_SUPPORTED_QUANT_TYPES}",
            )
        if c.group_size != -1 and c.group_size % 32 != 0:
            return (
                False,
                f"Group size ({c.group_size}) not supported by "
                "XPUwNa16, supported group sizes are multiples of 32",
            )

        if c.partition_weight_shape[0] % 32 != 0:
            return (
                False,
                f"Input size ({c.partition_weight_shape[0]}) not supported by "
                "XPUwNa16, supported sizes are multiples of 32",
            )

        if c.partition_weight_shape[1] % 32 != 0:
            return (
                False,
                f"Output size ({c.partition_weight_shape[1]}) not supported by "
                "XPUwNa16, supported sizes are multiples of 32",
            )

        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module):
        layer.weight_scale.data = layer.weight_scale.t().contiguous()

        if self.config.zero_points:
            layer.weight_zero_point.data = layer.weight_zero_point.t().contiguous()
        else:
            weight_zero_point = torch.tensor([8], dtype=torch.int8, device="xpu")
            layer.weight_zero_point = Parameter(weight_zero_point, requires_grad=False)
        if self.config.has_g_idx:
            layer.g_idx.data = layer.g_idx.t().contiguous()
        else:
            layer.g_idx = None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        reshaped_x = x.reshape(-1, x.shape[-1])
        out = torch.ops._xpu_C.int4_gemm_w4a16(
            reshaped_x,
            layer.weight_packed.t(),
            bias,
            layer.weight_scale,
            layer.weight_zero_point,
            self.config.group_size,
            layer.g_idx,
        )
        return out
187
vllm/model_executor/kernels/linear/scaled_mm/ScaledMMLinearKernel.py
Normal file
@@ -0,0 +1,187 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Generic, TypeVar

import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey
from vllm.platforms import current_platform


@dataclass
class ScaledMMLinearLayerConfig:
    pass


@dataclass
class Int8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
    # TODO: Change to QuantKey like FP8ScaledMMLinearLayerConfig
    is_static_input_scheme: bool
    is_channelwise: bool
    input_symmetric: bool


@dataclass
class FP8ScaledMMLinearLayerConfig(ScaledMMLinearLayerConfig):
    weight_quant_key: QuantKey
    activation_quant_key: QuantKey
    out_dtype: torch.dtype | None


_FP8ParamsT = tuple[
    torch.Tensor,  # weight
    torch.Tensor,  # weight_scale
    torch.Tensor | None,  # input_scale
    torch.Tensor | None,  # input_scale_ub
]
_Int8ParamsT = tuple[
    torch.Tensor,  # weight
    torch.Tensor,  # weight_scale
    torch.Tensor | None,  # input_scale
    torch.Tensor | None,  # input_zp
    torch.Tensor | None,  # azp_adj
]

_ParamsT = TypeVar("_ParamsT", _Int8ParamsT, _FP8ParamsT)
_ConfigT = TypeVar("_ConfigT", bound=ScaledMMLinearLayerConfig)


class ScaledMMLinearKernel(Generic[_ConfigT, _ParamsT], ABC):
    @classmethod
    @abstractmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls, c: _ConfigT) -> tuple[bool, str | None]:
        raise NotImplementedError

    def __init__(self, c: _ConfigT, layer_param_names: Sequence[str]) -> None:
        assert self.can_implement(c)[0]
        assert self.is_supported()[0]
        self.config = c
        self.layer_param_names = layer_param_names

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        raise NotImplementedError

    # return a covariant type in the subclass
    @abstractmethod
    def _get_layer_params(self, layer) -> _ParamsT:
        raise NotImplementedError


class FP8ScaledMMLinearKernel(
    ScaledMMLinearKernel[FP8ScaledMMLinearLayerConfig, _FP8ParamsT], ABC
):
    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        act_scale_descriptor = c.activation_quant_key.scale
        self.quant_fp8 = QuantFP8(
            static=act_scale_descriptor.static,
            group_shape=act_scale_descriptor.group_shape,
            num_token_padding=self.get_output_padding(),
        )
        self.fp8_dtype = current_platform.fp8_dtype()
        super().__init__(c, layer_param_names)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        pass

    def _get_layer_params(self, layer) -> _FP8ParamsT:
        w, w_s, x_s, x_s_ub = self.layer_param_names
        return (
            getattr(layer, w),
            getattr(layer, w_s),
            getattr(layer, x_s, None),
            getattr(layer, x_s_ub, None),
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        fp8_dtype = self.fp8_dtype
        maybe_out_dtype = self.config.out_dtype
        w, w_s, x_s, x_s_ub = self._get_layer_params(layer)

        # ops.scaled_fp8_quant supports both dynamic and static quant:
        # if dynamic, layer.input_scale is None and x_s is computed from x;
        # if static, layer.input_scale is a scalar and x_s is input_scale.
        # View the input as a 2D matrix for the fp8 methods
        x_2d = x.view(-1, x.shape[-1])
        output_shape = [*x.shape[:-1], w.shape[1]]
        out_dtype = x.dtype if maybe_out_dtype is None else maybe_out_dtype

        # If the input is not already quantized
        # TODO(luka) remove this path if not used anymore
        x_2d_q = x_2d
        if x.dtype != fp8_dtype:
            x_2d_q, x_s = self.quant_fp8(
                x_2d,
                x_s,
                x_s_ub,
            )
        return self.apply_scaled_mm(
            A=x_2d_q,
            B=w,
            out_dtype=out_dtype,
            As=x_s,
            Bs=w_s,
            bias=bias,
            output_shape=output_shape,
        )

    @abstractmethod
    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        raise NotImplementedError

    def get_output_padding(self) -> int | None:
        return None


class Int8ScaledMMLinearKernel(
    ScaledMMLinearKernel[Int8ScaledMMLinearLayerConfig, _Int8ParamsT], ABC
):
    def _get_layer_params(self, layer) -> _Int8ParamsT:
        w_q, w_s, i_s, i_zp, azp_adj = self.layer_param_names
        return (
            getattr(layer, w_q),
            getattr(layer, w_s),
            getattr(layer, i_s, None),
            getattr(layer, i_zp, None),
            getattr(layer, azp_adj, None),
        )
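Concrete backends hang off this hierarchy by choosing a config/params pair and filling in the abstract hooks; on the FP8 side only `is_supported`, `can_implement`, and `apply_scaled_mm` remain to implement. A minimal hypothetical subclass showing the contract (`torch._scaled_mm` as a stand-in backend; illustrative only, not part of the diff):

class ToyFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(cls, compute_capability: int | None = None):
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig):
        return True, None

    def apply_scaled_mm(self, *, A, B, out_dtype, As, Bs, bias, output_shape):
        out = torch._scaled_mm(
            A, B, scale_a=As, scale_b=Bs, out_dtype=out_dtype, bias=bias
        )
        return out.view(*output_shape)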
54
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
Normal file
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.kernels.linear.scaled_mm.aiter import (
    AiterInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cpu import (
    CPUInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
    CutlassFP8ScaledMMLinearKernel,
    CutlassInt8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.flashinfer import (
    FlashInferFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.pytorch import (
    ChannelWiseTorchFP8ScaledMMLinearKernel,
    PerTensorTorchFP8ScaledMMLinearKernel,
    RowWiseTorchFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.rocm import (
    ROCmFP8ScaledMMLinearKernel,
)
from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
    Int8ScaledMMLinearKernel,
    Int8ScaledMMLinearLayerConfig,
    ScaledMMLinearKernel,
    ScaledMMLinearLayerConfig,
)
from vllm.model_executor.kernels.linear.scaled_mm.triton import (
    TritonInt8ScaledMMLinearKernel,
)

__all__ = [
    "FP8ScaledMMLinearKernel",
    "FP8ScaledMMLinearLayerConfig",
    "Int8ScaledMMLinearKernel",
    "Int8ScaledMMLinearLayerConfig",
    "ScaledMMLinearKernel",
    "ScaledMMLinearLayerConfig",
    "AiterInt8ScaledMMLinearKernel",
    "CPUInt8ScaledMMLinearKernel",
    "CutlassFP8ScaledMMLinearKernel",
    "CutlassInt8ScaledMMLinearKernel",
    "FlashInferFP8ScaledMMLinearKernel",
    "ChannelWiseTorchFP8ScaledMMLinearKernel",
    "PerTensorTorchFP8ScaledMMLinearKernel",
    "RowWiseTorchFP8ScaledMMLinearKernel",
    "ROCmFP8ScaledMMLinearKernel",
    "TritonInt8ScaledMMLinearKernel",
]
109
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
Normal file
@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.platforms import current_platform

from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import Int8ScaledMMLinearLayerConfig


class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
            return False, "Requires ROCm."

        if compute_capability is not None and compute_capability < 90:
            return False, "requires compute capability 90 and above."

        try:
            import aiter  # noqa: F401 # deliberately attempt to import aiter
        except Exception:
            return False, "requires `aiter` to be installed."

        if not rocm_aiter_ops.is_linear_enabled():
            return (
                False,
                "requires setting `VLLM_ROCM_USE_AITER=1` "
                "and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
                "`VLLM_ROCM_USE_AITER_LINEAR` defaults to True.",
            )
        return True, None

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if not c.input_symmetric:
            return False, "supports symmetric quantization only."
        return True, None

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        `AiterInt8ScaledMMLinearKernel` implements a fused version of
        `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`,
        where `scale_a * a` and `scale_b * b` are computed with numpy-style
        broadcasting.
        It currently supports only per-tensor-per-tensor and
        per-token-per-channel GEMM through the AITER w8a8 scaled GEMM.
        `AiterInt8ScaledMMLinearKernel` does not support AITER
        block-scaled GEMM or mixed-precision GEMM.
        """
        w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic: i_s is None and x_s is computed from x.
        # * static: i_s is a scalar and x_s is i_s.
        symmetric = azp_adj is None
        assert symmetric, (
            "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
        )
        x_q, x_s, x_zp = ops.scaled_int8_quant(x, i_s, i_zp, symmetric=symmetric)

        assert x_zp is None, (
            "AiterInt8ScaledMMLinearKernel only supports symmetric quantization."
        )
        out_dtype = x.dtype

        assert w_q.shape[0] % 16 == 0 and w_q.shape[1] % 16 == 0
        assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
        assert bias is None or (
            bias.shape[0] == w_q.shape[1] and bias.dtype == out_dtype
        )

        m = x_q.shape[0]  # a
        n = w_q.shape[1]  # b

        per_tensor_scale_a = x_s.numel() == 1
        per_tensor_scale_b = w_s.numel() == 1
        per_token_scale_a = x_s.numel() == m
        per_channel_scale_b = w_s.numel() == n

        # TODO: Maybe broadcast the per-tensor scale into a per-channel scale
        # if one of the scales is a per-channel scale.
        # For now, only the following are supported:
        # - per-tensor-per-tensor a8w8 scaled GEMM, and
        # - per-token-per-channel a8w8 scaled GEMM
        assert (per_tensor_scale_a and per_tensor_scale_b) or (
            per_token_scale_a and per_channel_scale_b
        ), (
            "Currently only per-tensor-per-tensor GEMM "
            "and per-token-per-channel GEMM are supported through the "
            "AITER w8a8 scaled GEMM. `AiterInt8ScaledMMLinearKernel` "
            "does not support AITER block-scaled GEMM."
        )

        # gemm_a8w8_CK(a, b, scale_a, scale_b, bias) expects
        # a to be [M, K]
        # b to be [N, K]
        # CutlassInt8ScaledMMLinearKernel prepares weight `w_q` in [K, N] format
        return rocm_aiter_ops.gemm_a8w8(x_q, w_q.t(), x_s, w_s, bias, out_dtype)
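The docstring's identity is easy to sanity-check in plain PyTorch: scaling rows of `a` and columns of `b` before the matmul is the same as scaling the product afterwards. A small float reference (illustrative only, not the AITER kernel):

import torch

m, k, n = 4, 8, 3
a = torch.randint(-128, 127, (m, k), dtype=torch.int8)
b = torch.randint(-128, 127, (k, n), dtype=torch.int8)
scale_a = torch.rand(m, 1)  # per-token
scale_b = torch.rand(1, n)  # per-channel

# The fused form the kernel computes...
fused = (scale_a * a.float()) @ (scale_b * b.float())
# ...equals scaling the integer accumulator afterwards.
unfused = scale_a * (a.float() @ b.float()) * scale_b
assert torch.allclose(fused, unfused, atol=1e-4)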
217
vllm/model_executor/kernels/linear/scaled_mm/cpu.py
Normal file
@@ -0,0 +1,217 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm import envs
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise,
)
from vllm.model_executor.layers.utils import check_cpu_sgl_kernel
from vllm.platforms import current_platform
from vllm.platforms.interface import CpuArchEnum

from .ScaledMMLinearKernel import (
    Int8ScaledMMLinearKernel,
    Int8ScaledMMLinearLayerConfig,
)


class CPUInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cpu():
            return False, "requires CPU."
        return True, None

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w_q_name, _, _, _, _ = self.layer_param_names
        weight = getattr(layer, w_q_name)
        dtype = weight.dtype
        N, K = weight.size()
        if (
            current_platform.get_cpu_architecture() == CpuArchEnum.X86
            and envs.VLLM_CPU_SGL_KERNEL
            and self.config.input_symmetric
            and check_cpu_sgl_kernel(N, K, dtype)
        ):
            self.linear_method = self._apply_weights_sgl
            self.process_weights_for_sgl(layer)
        else:
            self.linear_method = self._apply_weights_onednn
            self.process_weights_for_onednn(layer)

    def process_weights_for_onednn(self, layer: torch.nn.Module) -> None:
        # WEIGHT
        # Transpose to [K, N] for convenience
        w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
        weight = getattr(layer, w_q_name)
        replace_parameter(
            layer,
            w_q_name,
            torch.nn.Parameter(weight.t().data, requires_grad=False),
        )

        # WEIGHT SCALE
        # oneDNN kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per-tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # INPUT SCALE
        if self.config.is_static_input_scheme:
            input_scale = getattr(layer, i_s_name)

            if self.config.input_symmetric:
                replace_parameter(
                    layer,
                    i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
            else:
                input_zero_point = getattr(layer, i_zp_name)

                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )

                azp = (
                    (int8_traits.min - range_min / scale).round().to(dtype=torch.int32)
                )
                replace_parameter(
                    layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )

        # Unlike cutlass, oneDNN kernels only need the AZP adjustment
        # term for dynamic quantization, and s_b should be folded into the
        # term. That is:
        # s_a * s_b * [(A - zp_a)B] + bias =
        # s_a * (s_b * AB) - s_a * s_b * zp_a * B + bias =
        # s_a * GEMM_output - s_a * zp_a * adj + bias
        if not (self.config.input_symmetric and self.config.is_static_input_scheme):
            weight = getattr(layer, w_q_name)
            weight_scale = getattr(layer, w_s_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.float32)
            azp_adj = azp_adj * weight_scale.squeeze()
            setattr(
                layer,
                azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )

        weight = getattr(layer, w_q_name)
        self.dnnl_handler = ops.create_onednn_scaled_mm(
            weight,
            getattr(layer, w_s_name),
            torch.get_default_dtype(),
            getattr(layer, i_s_name) is None,
            not self.config.input_symmetric,
            32,
        )
        # The weight is prepacked and maintained by the dnnl_handler,
        # so release the original weight
        setattr(layer, w_q_name, None)
        del weight

    def process_weights_for_sgl(self, layer: torch.nn.Module) -> None:
        w_q_name, w_s_name, _, _, _ = self.layer_param_names
        # WEIGHT
        weight = getattr(layer, w_q_name)
        packed_weight = torch.ops._C.convert_weight_packed(weight)
        replace_parameter(
            layer, w_q_name, torch.nn.Parameter(packed_weight, requires_grad=False)
        )

        if layer.bias is not None:
            bias = layer.bias
            layer.register_parameter(
                "bias_fp32", torch.nn.Parameter(bias.float().data, requires_grad=False)
            )

        # WEIGHT SCALE
        # CPU SGL kernels only support per-channel.
        # For per-tensor quant, convert to the per-channel case.
        weight_scale = getattr(layer, w_s_name)
        if not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return self.linear_method(
            layer,
            x,
            bias,
        )

    def _apply_weights_onednn(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x_shape = x.shape
        x = x.reshape(-1, x_shape[-1]) if len(x_shape) > 2 else x
        w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic: i_s is None and x_s is computed from x.
        # * static: i_s is a scalar and x_s is i_s.
        x_q, x_s, x_zp = ops.onednn_scaled_int8_quant(
            x, i_s, i_zp, self.config.input_symmetric
        )

        m = x.size(0)
        n = self.dnnl_handler.n
        out = torch.empty((m, n), dtype=x.dtype)
        ops.onednn_scaled_mm(self.dnnl_handler, x_q, out, x_s, x_zp, azp_adj, bias)
        out = out.reshape(x_shape[:-1] + (n,)) if len(x_shape) > 2 else out
        return out

    def _apply_weights_sgl(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        w_q, w_s, _, _, _ = self._get_layer_params(layer)
        return torch.ops._C.int8_scaled_mm_with_quant(
            x,
            w_q,
            w_s,
            layer.bias_fp32 if bias is not None else None,
            x.dtype,
            True,
        )
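The per-tensor collapse of asymmetric static scales above works by going back to real-valued ranges: each (scale, zero-point) pair defines an interval [s * (qmin - zp), s * (qmax - zp)], and a single pair is re-derived from the union of those intervals. A small numeric check of that round trip (made-up scales and zero points):

import torch

int8_traits = torch.iinfo(torch.int8)
input_scale = torch.tensor([0.02, 0.05])               # hypothetical scales
input_zero_point = torch.tensor([10, -3], dtype=torch.int8)

azps = input_zero_point.to(torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()  # 0.05 * 130 = 6.5
range_min = (input_scale * (int8_traits.min - azps)).min()  # 0.05 * -125 = -6.25

scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).round().to(torch.int32)

# The collapsed pair reproduces the widest of the original ranges.
assert torch.isclose(scale, torch.tensor(0.05))
assert azp.item() == -3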
176
vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
Normal file
@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise,
)
from vllm.platforms import current_platform

from .ScaledMMLinearKernel import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
    Int8ScaledMMLinearKernel,
    Int8ScaledMMLinearLayerConfig,
)


class CutlassInt8ScaledMMLinearKernel(Int8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "requires CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
        config = self.config
        # WEIGHT
        # Cutlass kernels need the transposed weight
        # (unless the W4A8 path keeps the original layout).
        weight = getattr(layer, w_q_name)
        replace_parameter(
            layer,
            w_q_name,
            torch.nn.Parameter(
                weight.data if envs.VLLM_W8A8_LINEAR_USE_W4A8 else weight.t().data,
                requires_grad=False,
            ),
        )

        # WEIGHT SCALE
        # Cutlass kernels support only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per-tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, w_s_name)
        if is_fused_module and not config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # INPUT SCALE
        if config.is_static_input_scheme:
            input_scale = getattr(layer, i_s_name)

            if config.input_symmetric:
                replace_parameter(
                    layer,
                    i_s_name,
                    torch.nn.Parameter(input_scale.max(), requires_grad=False),
                )
                setattr(layer, i_zp_name, None)
            else:
                input_zero_point = getattr(layer, i_zp_name)

                # reconstruct the ranges
                int8_traits = torch.iinfo(torch.int8)
                azps = input_zero_point.to(dtype=torch.int32)
                range_max = (input_scale * (int8_traits.max - azps)).max()
                range_min = (input_scale * (int8_traits.min - azps)).min()

                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
                replace_parameter(
                    layer, i_s_name, torch.nn.Parameter(scale, requires_grad=False)
                )

                # AZP loaded as int8 but used as int32
                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
                replace_parameter(
                    layer, i_zp_name, torch.nn.Parameter(azp, requires_grad=False)
                )

        # azp_adj is the AZP adjustment term, used to account for weights.
        # It does not depend on scales or azp, so it is the same for
        # static and dynamic quantization.
        # For more details, see csrc/quantization/w8a8/cutlass/Epilogues.md
        # https://github.com/vllm-project/vllm/blob/main/csrc/quantization/w8a8/cutlass/Epilogues.md
        if not config.input_symmetric:
            weight = getattr(layer, w_q_name)
            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
            if config.is_static_input_scheme:
                # cutlass_w8a8 requires azp to be folded into azp_adj
                # in the per-tensor case
                azp_adj = getattr(layer, i_zp_name) * azp_adj
            setattr(
                layer,
                azp_adj_name,
                torch.nn.Parameter(azp_adj, requires_grad=False),
            )

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        w_q, w_s, i_s, i_zp, azp_adj = self._get_layer_params(layer)

        # ops.scaled_int8_quant supports both dynamic and static quant:
        # * dynamic: i_s is None and x_s is computed from x.
        # * static: i_s is a scalar and x_s is i_s.
        symmetric = azp_adj is None
        x_q, x_s, x_zp = ops.scaled_int8_quant(
            x.contiguous(), i_s, i_zp, symmetric=symmetric
        )

        if x_zp is not None:
            # Currently, static is always per-tensor and dynamic is per-token
            static = i_zp is not None
            azp = None if static else x_zp
            return ops.cutlass_scaled_mm_azp(
                x_q,
                w_q,
                scale_a=x_s,
                scale_b=w_s,
                out_dtype=x.dtype,
                azp_adj=azp_adj,
                azp=azp,
                bias=bias,
            )
        return ops.cutlass_scaled_mm(
            x_q,
            w_q,
            scale_a=x_s,
            scale_b=w_s,
            out_dtype=x.dtype,
            bias=bias,
            format="NN" if envs.VLLM_W8A8_LINEAR_USE_W4A8 else "TN",
        )


class CutlassFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "requires CUDA."
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        # Fused GEMM_DQ
        output = ops.cutlass_scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )
        return output.view(*output_shape)
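The azp_adj term above falls out of expanding the asymmetric quantized GEMM: (A - zp_a) @ B = A @ B - zp_a * colsum(B), so the per-output-channel column sums of the weight are all the kernel needs to correct for the activation zero point. A quick numeric confirmation (toy tensors, illustrative only):

import torch

torch.manual_seed(0)
A = torch.randint(-128, 127, (4, 8), dtype=torch.int32)
B = torch.randint(-128, 127, (8, 3), dtype=torch.int32)
zp_a = 7  # hypothetical per-tensor activation zero point

azp_adj = B.sum(dim=0, keepdim=True)  # colsum(B), shape [1, n]
direct = (A - zp_a) @ B
corrected = A @ B - zp_a * azp_adj
assert torch.equal(direct, corrected)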
57
vllm/model_executor/kernels/linear/scaled_mm/flashinfer.py
Normal file
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer

from .ScaledMMLinearKernel import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
)


class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_cuda():
            return False, "requires CUDA."

        if not has_flashinfer():
            return False, "requires FlashInfer to be installed."

        if compute_capability is not None and compute_capability < 100:
            return False, "requires compute capability 100 and above."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return False, "requires per tensor activation and weight scales."

        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        return flashinfer_scaled_fp8_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        ).view(*output_shape)
217
vllm/model_executor/kernels/linear/scaled_mm/pytorch.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import CompilationMode, get_current_vllm_config
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .ScaledMMLinearKernel import (
|
||||
FP8ScaledMMLinearKernel,
|
||||
FP8ScaledMMLinearLayerConfig,
|
||||
)
|
||||
|
||||
|
||||
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
|
||||
"""
|
||||
Base class for FP8 linear kernels using Torch.
|
||||
Each subclass represents a kernel variant for
|
||||
specific device capabilities and torch versions.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def is_supported(
|
||||
cls, compute_capability: int | None = None
|
||||
) -> tuple[bool, str | None]:
|
||||
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
|
||||
            return False, "requires ROCm, CUDA or CPU."

        if compute_capability is not None and compute_capability < 89:
            return False, "requires compute capability 89 and above."

        return True, None

    def get_output_padding(self) -> int | None:
        # Note: we pad the input because torch._scaled_mm is more performant
        # for matrices with batch dimension > 16.
        # This could change in the future.
        # We also don't pad when using torch.compile,
        # as it breaks with dynamic shapes.
        #
        # The perf gain is still relevant as of 16/1/2026,
        # torch version == 2.9.0. More details in the link below:
        # https://github.com/vllm-project/vllm/issues/32269
        vllm_config = get_current_vllm_config().compilation_config
        pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
        return 17 if pad_output else None


class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return False, "requires per tensor activation and weight scales."
        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        output = torch._scaled_mm(
            A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
        )
        # A fix for the discrepancy in scaled_mm, which returns a tuple
        # for torch < 2.5 and a single value for torch >= 2.5.
        if type(output) is tuple and len(output) == 2:
            output = output[0]

        return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)


class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if compute_capability is not None and compute_capability < 94:
            return False, "requires compute capability 94 and above."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if c.out_dtype == torch.float16:
            # hipblaslt rowwise _scaled_mm only supports BFloat16
            return False, "supports BFloat16 output data type only."

        if per_tensor_activation_scales or per_tensor_weight_scales:
            return False, "cannot be used with per tensor activation and weight scales."

        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        # Note:
        # For now this has only been validated on the ROCm platform.
        # fp8 rowwise scaling in torch._scaled_mm was introduced in
        # https://github.com/pytorch/pytorch/pull/144432 using
        # hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
        #
        # On the CUDA platform, please validate that torch._scaled_mm supports
        # rowwise scaled GEMM before using it.

        # Fused GEMM_DQ Rowwise GEMM
        output = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs.t(),
            bias=bias,
        )

        return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)


class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if per_tensor_activation_scales and per_tensor_weight_scales:
            return False, "cannot be used with per tensor activation and weight scales."

        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        # Use unfused DQ due to limitations with scaled_mm.

        # Symmetric quantized GEMM by definition computes the following:
        #     C = (s_x * X) (s_w * W) + bias
        # This is equivalent to dequantizing the weights and activations
        # before applying a GEMM.
        #
        # In order to compute quantized operands, a quantized kernel
        # will rewrite the above like so:
        #     C = s_w * s_x * (X * W) + bias
        #
        # For the scaled_mm fallback case, we break this down, since it
        # does not support s_w being a vector.

        # Input scaling factors are no longer optional in _scaled_mm starting
        # from pytorch 2.5. Allocating a dummy tensor to pass as the scales.
        dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)

        # GEMM
        # This computes C = (X * W).
        # Output in fp32 to allow subsequent ops to happen in-place.
        output = torch._scaled_mm(
            A,
            B,
            scale_a=dummy_tensor,
            scale_b=dummy_tensor,
            out_dtype=torch.float32,
        )
        # A fix for the discrepancy in scaled_mm, which returns a tuple
        # for torch < 2.5 and a single value for torch >= 2.5.
        if type(output) is tuple and len(output) == 2:
            output = output[0]

        # Unpad (undo num_token_padding)
        output = torch.narrow(output, 0, 0, output_shape[0])
        x_scale = torch.narrow(As, 0, 0, output_shape[0])

        # DQ
        # C = sw * sx * (X * W) + bias
        output = output * x_scale * Bs.t()
        if bias is not None:
            output = output + bias
        return output.to(out_dtype).view(*output_shape)
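
# Worked example (sketch, not part of the kernels above): the channelwise
# fallback relies on the symmetric-quantization identity
#     C = (s_x * X) (s_w * W) + bias = s_x * s_w * (X * W) + bias,
# which a plain-torch reference would compute directly. Names and shapes
# below are illustrative only.
def _channelwise_dequant_reference(
    A: torch.Tensor,  # [M, K] quantized activations
    B: torch.Tensor,  # [K, N] quantized weights
    As: torch.Tensor,  # [M, 1] per-token activation scales
    Bs: torch.Tensor,  # [N, 1] per-channel weight scales
    bias: torch.Tensor | None = None,
) -> torch.Tensor:
    out = (A.float() @ B.float()) * As * Bs.t()
    return out if bias is None else out + bias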
118
vllm/model_executor/kernels/linear/scaled_mm/rocm.py
Normal file
@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import direct_register_custom_op

from .ScaledMMLinearKernel import (
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
)


def rocm_per_tensor_float_w8a8_scaled_mm_impl(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    if (
        A.shape[0] <= 4
        and B.shape[0] % 16 == 0  # M TODO: needed?
        and B.shape[1] % 16 == 0  # K
        and ((bias is None) or (bias.dtype == out_dtype))
    ):
        output = ops.wvSplitKQ(
            B.t(),
            A,
            out_dtype,
            As,
            Bs,
            num_compute_units(),
            bias,
        )
    # Fallback
    else:
        output = torch._scaled_mm(
            A,
            B,
            out_dtype=out_dtype,
            scale_a=As,
            scale_b=Bs,
            bias=bias,
        )
    return output


def rocm_per_tensor_float_w8a8_scaled_mm_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: torch.dtype,
    As: torch.Tensor,
    Bs: torch.Tensor,
    bias: torch.Tensor,
) -> torch.Tensor:
    return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)


if current_platform.is_rocm():
    direct_register_custom_op(
        op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
        op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
        fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
    )


class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_rocm():
            return False, "requires ROCm."

        from vllm.platforms.rocm import on_mi3xx

        if not on_mi3xx():
            return False, "requires MI3xx."

        if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
            return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."

        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        per_tensor_activation_scales = (
            c.activation_quant_key.scale.group_shape.is_per_tensor()
        )
        per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()

        if not (per_tensor_activation_scales and per_tensor_weight_scales):
            return False, "requires per tensor activation and weight scales."

        return True, None

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
            A, B, out_dtype, As, Bs, bias
        )
        return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
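
# Dispatch sketch (comment only): the custom op above takes the wvSplitKQ
# "skinny GEMM" path only for very small batches (A.shape[0] <= 4) with
# 16-aligned weight dimensions and a bias whose dtype matches out_dtype;
# every other shape falls back to torch._scaled_mm with per-tensor scales.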
93
vllm/model_executor/kernels/linear/scaled_mm/triton.py
Normal file
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa: E501
    triton_scaled_mm,
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise,
)
from vllm.platforms import current_platform

from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import (
    Int8ScaledMMLinearLayerConfig,
)


class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if current_platform.is_cuda_alike():
            return True, None
        return False, "requires ROCm or CUDA."

    @classmethod
    def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if not c.input_symmetric:
            return False, "supports symmetric input only."
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        w_q, _, i_s, _, _ = self._get_layer_params(layer)
        w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names

        replace_parameter(
            layer,
            w_q_name,
            torch.nn.Parameter(w_q.t().data, requires_grad=False),
        )

        # WEIGHT SCALE
        # Triton kernel supports only per-tensor and per-channel.
        # If we have a fused module (QKV, MLP) with per tensor scales (thus N
        # scales being passed to the kernel), convert to the per-channel case.
        is_fused_module = len(layer.logical_widths) > 1
        weight_scale = getattr(layer, w_s_name)
        if is_fused_module and not self.config.is_channelwise:
            weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
        replace_parameter(
            layer,
            w_s_name,
            torch.nn.Parameter(weight_scale.data, requires_grad=False),
        )

        # INPUT SCALE
        if self.config.is_static_input_scheme:
            assert i_s is not None
            replace_parameter(
                layer,
                i_s_name,
                torch.nn.Parameter(i_s.max(), requires_grad=False),
            )
            setattr(layer, i_zp_name, None)
        else:
            setattr(layer, i_s_name, None)
            setattr(layer, i_zp_name, None)

        setattr(layer, azp_adj_name, None)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)

        x_q, x_s, x_zp = ops.scaled_int8_quant(
            x.contiguous(), i_s, i_zp, symmetric=True
        )

        assert x_zp is None, "Triton kernel only supports symmetric quantization"

        return triton_scaled_mm(
            x_q, w_q, scale_a=x_s, scale_b=w_s, out_dtype=x.dtype, bias=bias
        )
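
# Sketch of what the convert_to_channelwise call above effectively does for a
# fused module (e.g. QKV) carrying one per-tensor scale per logical partition.
# This is an illustrative reference under that assumption, not the vLLM
# implementation:
def _expand_fused_scales_reference(
    scales: torch.Tensor, logical_widths: list[int]
) -> torch.Tensor:
    # Repeat each partition's scalar scale across that partition's output
    # channels, yielding one scale per output channel.
    return torch.cat(
        [s.expand(w) for s, w in zip(scales.flatten(), logical_widths)]
    ).unsqueeze(-1)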
59
vllm/model_executor/kernels/linear/scaled_mm/xpu.py
Normal file
@@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Sequence

import torch

from vllm.model_executor.kernels.linear import (  # noqa: E501
    FP8ScaledMMLinearKernel,
    FP8ScaledMMLinearLayerConfig,
)
from vllm.platforms import current_platform


class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        if not current_platform.is_xpu():
            return False, "XPUFP8ScaledMM is only supported on XPU."
        return True, None

    @classmethod
    def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
        if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
            return False, "XPUFP8ScaledMM only supports FP8 weight dtypes."
        return True, None

    def __init__(
        self, c: FP8ScaledMMLinearLayerConfig, layer_param_names: Sequence[str]
    ) -> None:
        assert self.can_implement(c)[0]
        assert self.is_supported()[0]
        self.config = c
        self.layer_param_names = layer_param_names

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        weight = layer.weight
        weight_scale = layer.weight_scale
        return torch.ops._xpu_C.fp8_gemm_w8a16(x, weight, weight_scale, bias)

    def apply_scaled_mm(
        self,
        *,
        A: torch.Tensor,
        B: torch.Tensor,
        out_dtype: torch.dtype,
        As: torch.Tensor,
        Bs: torch.Tensor,
        bias: torch.Tensor | None,
        output_shape: list,
    ) -> torch.Tensor:
        # Unused: apply_weights above dispatches directly to the fused
        # w8a16 XPU GEMM, so this kernel never goes through the generic
        # scaled_mm path.
        pass
0
vllm/model_executor/layers/__init__.py
Normal file
708
vllm/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,708 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom activation functions."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.distributed import (
    divide,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.collection_utils import LazyDict

logger = init_logger(__name__)


@triton.jit
def _swiglustep_and_mul_kernel(
    o_ptr,
    o_stride,
    x_ptr,
    x_stride,
    limit: tl.constexpr,
    d: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
) -> None:
    i = tl.program_id(axis=0).to(tl.int64)
    j = tl.program_id(axis=1)
    o_row_ptr = o_ptr + o_stride * i
    x_row_ptr = x_ptr + x_stride * i
    offsets = j * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < d

    gate = tl.load(x_row_ptr + offsets, mask=mask).to(tl.float32)
    up = tl.load(x_row_ptr + offsets + d, mask=mask).to(tl.float32)

    gate_silu = tl.sigmoid(gate) * gate
    gate_clamped = tl.minimum(gate_silu, limit)
    up_clamped = tl.minimum(tl.maximum(up, -limit), limit)

    result = gate_clamped * up_clamped
    result = result.to(x_ptr.dtype.element_ty)
    tl.store(o_row_ptr + offsets, result, mask=mask)


def swiglustep_and_mul_triton(
    output: torch.Tensor, input: torch.Tensor, limit: float = 7.0
):
    b, n = input.shape
    assert input.ndim == 2
    assert n % 2 == 0
    d = n // 2

    def grid(meta):
        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))

    _swiglustep_and_mul_kernel[grid](
        output,
        output.stride(0),
        input,
        input.stride(0),
        limit=limit,
        d=d,
        BLOCK_SIZE=1024,
    )
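
# Usage sketch (comment only): the launcher writes into a caller-allocated
# buffer that is half the input width, e.g.
#
#     x = torch.randn(32, 2048, device="cuda", dtype=torch.float16)
#     out = torch.empty(32, 1024, device="cuda", dtype=torch.float16)
#     swiglustep_and_mul_triton(out, x, limit=7.0)
#
# The eager-mode equivalent is SwigluStepAndMul.forward_native below.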

# --8<-- [start:fatrelu_and_mul]
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
    """An activation function for FATReLU.

    The function computes x -> FATReLU(x[:d]) * x[d:] where
    d = x.shape[-1] // 2.
    This is used in openbmb/MiniCPM-S-1B-sft.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:fatrelu_and_mul]

    def __init__(self, threshold: float = 0.0):
        super().__init__()
        self.threshold = threshold
        if current_platform.is_cuda_alike():
            from vllm import _custom_ops as ops

            self.op = ops.fatrelu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        x1 = x[..., :d]
        x2 = x[..., d:]
        x1 = F.threshold(x1, self.threshold, 0.0)
        return x1 * x2

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x, self.threshold)
        return out


# --8<-- [start:silu_and_mul]
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:silu_and_mul]

    def __init__(self, *, compile_native: bool = True):
        super().__init__(compile_native=compile_native)
        if current_platform.is_cuda_alike():
            from vllm import _custom_ops as ops

            self.op = ops.silu_and_mul
        elif current_platform.is_xpu():
            from vllm._ipex_ops import ipex_ops

            self.op = ipex_ops.silu_and_mul
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    @staticmethod
    def forward_native(x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)
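
# Usage sketch (comment only): all *AndMul ops halve the last dimension, so a
# fused gate_up projection of width 2 * d feeds an MLP of width d, e.g.
#
#     act = SiluAndMul()
#     y = act(torch.randn(8, 2 * 1024))  # y.shape == (8, 1024)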

# --8<-- [start:mul_and_silu]
@CustomOp.register("mul_and_silu")
class MulAndSilu(CustomOp):
    """An activation function for SwiGLU.

    The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:mul_and_silu]

    def __init__(self):
        super().__init__()
        if current_platform.is_cuda_alike() or current_platform.is_xpu():
            # self.op = torch.ops._C.mul_and_silu
            from vllm import _custom_ops as ops

            self.op = ops.mul_and_silu
        elif current_platform.is_cpu():
            self._forward_method = self.forward_native

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return x[..., :d] * F.silu(x[..., d:])

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)


# --8<-- [start:gelu_and_mul_sparse]
@CustomOp.register("gelu_and_mul_sparse")
class GeluAndMulSparse(CustomOp):
    """An activation function for GeluAndMulSparse.

    This activation function is used in Gemma3n. It computes:
        up_proj = self.up_proj(x)
        gate_proj = self.gate_proj(x)
        gate_proj = self._gaussian_topk(gate_proj)  # sparsity
        activations = self.act_fn(gate_proj)  # gelu
        down_proj = self.down_proj(activations * up_proj)

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    # --8<-- [end:gelu_and_mul_sparse]

    def __init__(self, activation_sparsity: float, approximate: str = "none"):
        super().__init__()
        # Gelu.
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if current_platform.is_rocm() and approximate == "tanh":
            # TODO: [ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is currently "
                "unstable and produces garbage. Falling back to the 'none' "
                "approximation."
            )
            self.approximate = "none"

        # Sparsity.
        if activation_sparsity == 0.0:
            raise ValueError("activation_sparsity is 0.0. Please use GeluAndMul.")
        target_sparsity_tensor = torch.tensor(activation_sparsity, dtype=torch.float32)
        normal_dist = torch.distributions.normal.Normal(0, 1)
        self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)

    def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
        """Zero out everything below the activation_sparsity percentile of the
        Gaussian distribution, keeping only the top entries."""
        # NOTE(rob): for TP>1, we could all-gather to get the means/std.
        # But we do not do this because in expectation they are the same
        # and in practice the eval scores are good without gathering.
        mean = torch.mean(x, dim=-1, keepdim=True)
        std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
        cutoff_x = mean + std * self.std_multiplier
        return nn.functional.relu(x - cutoff_x)

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        out = self._gaussian_topk(x[..., :d])
        out = F.gelu(out, approximate=self.approximate)
        return out * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)
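
# Numeric sketch (comment only): for activation_sparsity = 0.95,
# std_multiplier = Normal(0, 1).icdf(0.95) ~= 1.645, so _gaussian_topk zeroes
# everything below mean + 1.645 * std, keeping roughly the top 5% of entries
# in each row.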

# --8<-- [start:gelu_and_mul]
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    # --8<-- [end:gelu_and_mul]

    def __init__(self, approximate: str = "none"):
        super().__init__()
        self.approximate = approximate
        if approximate not in ("none", "tanh"):
            raise ValueError(f"Unknown approximate mode: {approximate}")
        if (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            if approximate == "none":
                from vllm import _custom_ops as ops

                self.op = ops.gelu_and_mul
            elif approximate == "tanh":
                from vllm import _custom_ops as ops

                self.op = ops.gelu_tanh_and_mul
            if current_platform.is_rocm() and approximate == "tanh":
                logger.warning_once(
                    "[ROCm] PyTorch's native GELU with tanh approximation is "
                    "unstable with torch.compile. The native implementation "
                    "falls back to the 'none' approximation; the custom kernel "
                    "implementation is unaffected."
                )

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        # TODO: [ROCm] PyTorch's native GELU with tanh is unstable with torch.compile
        approximate = self.approximate
        if current_platform.is_rocm() and approximate == "tanh":
            approximate = "none"
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)

    def extra_repr(self) -> str:
        return f"approximate={repr(self.approximate)}"


# --8<-- [start:swigluoai_and_mul]
@CustomOp.register("swigluoai_and_mul")
class SwigluOAIAndMul(CustomOp):
    # https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
    # --8<-- [end:swigluoai_and_mul]

    def __init__(self, alpha: float = 1.702, limit: float = 7.0):
        super().__init__()
        self.alpha = alpha
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""

        gate, up = x[..., ::2], x[..., 1::2]
        gate = gate.clamp(min=None, max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        gated_output = (up + 1) * glu
        return gated_output

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
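
# Layout note (comment only): unlike the other *AndMul ops, which split the
# last dimension into two halves, forward_native above reads gate/up
# interleaved ([..., ::2] and [..., 1::2]) to match the GPT-OSS checkpoint
# layout; the torch.ops._C.swigluoai_and_mul kernel consumes the same layout.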

# --8<-- [start:swiglustep_and_mul]
@CustomOp.register("swiglustep_and_mul")
class SwigluStepAndMul(CustomOp):
    """An activation function for SwiGLU with clamping.

    Computes x -> silu(x[:d]).clamp(max=limit) * x[d:].clamp(-limit, limit)
    where d = x.shape[-1] // 2.

    Shapes:
        x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
        return: (num_tokens, d) or (batch_size, seq_len, d)
    """

    def __init__(self, limit: float = 7.0):
        super().__init__()
        if limit is None:
            raise ValueError("SwigluStepAndMul requires limit to be set.")
        self.limit = limit

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        gate, up = x.chunk(2, dim=-1)
        gate = F.silu(gate)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        return gate * up

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = x.shape[:-1] + (d,)
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        swiglustep_and_mul_triton(out, x, self.limit)
        return out

    def extra_repr(self) -> str:
        return f"limit={repr(self.limit)}"


# --8<-- [start:gelu_new]
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
    # --8<-- [end:gelu_new]

    def __init__(self):
        super().__init__()
        if (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            # self.op = torch.ops._C.gelu_new
            from vllm import _custom_ops as ops

            self.op = ops.gelu_new

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c * (x + 0.044715 * torch.pow(x, 3.0))))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)


# --8<-- [start:gelu_fast]
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
    # --8<-- [end:gelu_fast]

    def __init__(self):
        super().__init__()
        if (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            self.op = torch.ops._C.gelu_fast

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)


# --8<-- [start:quick_gelu]
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
    # https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
    # --8<-- [end:quick_gelu]

    def __init__(self):
        super().__init__()
        if (
            current_platform.is_cuda_alike()
            or current_platform.is_cpu()
            or current_platform.is_xpu()
        ):
            # self.op = torch.ops._C.gelu_quick
            from vllm import _custom_ops as ops

            self.op = ops.gelu_quick

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return x * torch.sigmoid(1.702 * x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        self.op(out, x)
        return out

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(x)


# --8<-- [start:relu2]
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
    """
    Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
    """

    # --8<-- [end:relu2]

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return torch.square(F.relu(x))

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: implement CUDA kernels
        return self.forward_native(x)


# --8<-- [start:xielu]
@CustomOp.register("xielu")
class XIELU(CustomOp):
    """
    Applies the xIELU activation function introduced in
    https://arxiv.org/abs/2411.13010.
    If the user has installed nickjbrowning/XIELU, we use the xIELU CUDA
    kernel; otherwise, we emit a single warning and use the Python version.
    """

    # --8<-- [end:xielu]

    def __init__(
        self,
        alpha_p_init: float = 0.8,
        alpha_n_init: float = 0.8,
        beta: float = 0.5,
        eps: float = -1e-6,
        dtype: torch.dtype = torch.bfloat16,
        with_vector_loads: bool = False,
    ):
        super().__init__()
        self.alpha_p = nn.Parameter(
            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
                0
            )
        )
        self.alpha_n = nn.Parameter(
            torch.log(
                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
            ).unsqueeze(0)
        )
        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
        self.with_vector_loads = with_vector_loads
        # Temporary until xIELU CUDA is fully implemented
        self._beta_scalar = float(self.beta.detach().cpu().float().item())
        self._eps_scalar = float(self.eps.detach().cpu().float().item())

        self._xielu_cuda_obj = None
        try:
            import xielu.ops  # noqa: F401

            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
            msg = "Using experimental xIELU CUDA."
            try:
                from torch._dynamo import allow_in_graph

                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
                msg += " Enabled torch._dynamo for xIELU CUDA."
            except Exception as err:
                msg += (
                    f" Could not enable torch._dynamo for xIELU ({err}) - "
                    "this may result in slower performance."
                )
                self._xielu_cuda_fn = self._xielu_cuda
            logger.warning_once(msg)
        except Exception as err:
            logger.warning_once(
                "CUDA-fused xIELU not available (%s) –"
                " falling back to a Python version.\n"
                "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
                str(err),
            )

    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
        alpha_p = nn.functional.softplus(self.alpha_p)
        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
        return torch.where(
            x > 0,
            alpha_p * x * x + self.beta * x,
            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
        )

    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
        """Firewall function to prevent torch.compile from seeing .item()."""
        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
        original_shape = x.shape
        # The CUDA kernel expects 3D tensors; reshape if needed.
        while x.dim() < 3:
            x = x.unsqueeze(0)
        if x.dim() > 3:
            x = x.view(-1, 1, x.size(-1))
        if original_shape != x.shape:
            logger.warning_once(
                "Warning: xIELU input tensor expects 3 dimensions"
                " but got (shape: %s). Reshaping to (shape: %s).",
                original_shape,
                x.shape,
            )
        result = self._xielu_cuda_obj.forward(
            x,
            self.alpha_p,
            self.alpha_n,
            # Temporary until xIELU CUDA is fully implemented ->
            # self.{beta,eps}.item()
            self._beta_scalar,
            self._eps_scalar,
            self.with_vector_loads,
        )
        return result.view(original_shape)

    def forward_native(self, input: torch.Tensor) -> torch.Tensor:
        if self._xielu_cuda_obj is not None and input.is_cuda:
            if not torch._dynamo.is_compiling():
                return self._xielu_cuda_fn(input)
            else:
                logger.warning_once(
                    "torch._dynamo is compiling; using the Python version of xIELU."
                )
        return self._xielu_python(input)

    def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward_native(input)
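
# Reference (comment only) for the piecewise xIELU computed in _xielu_python:
#     f(x) = alpha_p * x^2 + beta * x                          for x > 0
#     f(x) = alpha_n * (expm1(min(x, eps)) - x) + beta * x     otherwise
# with alpha_p = softplus(alpha_p_param) and
#      alpha_n = beta + softplus(alpha_n_param).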

class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: torch.dtype | None = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size, tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
        )
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


_ACTIVATION_REGISTRY = LazyDict(
    {
        "gelu": lambda: nn.GELU(),
        "gelu_fast": lambda: FastGELU(),
        "gelu_new": lambda: NewGELU(),
        "gelu_pytorch_tanh": lambda: (
            # TODO: [ROCm] PyTorch native GELU with tanh is unstable with torch.compile
            logger.warning_once(
                "[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
                "Falling back to GELU(approximate='none')."
            ),
            nn.GELU(approximate="none"),
        )[1]
        if current_platform.is_rocm()
        else nn.GELU(approximate="tanh"),
        "relu": lambda: nn.ReLU(),
        "relu2": lambda: ReLUSquaredActivation(),
        "silu": lambda: nn.SiLU(),
        "quick_gelu": lambda: QuickGELU(),
        "tanh": lambda: nn.Tanh(),
        "sigmoid": lambda: nn.Sigmoid(),
        "xielu": lambda: XIELU(),
    }
)


def get_act_fn(act_fn_name: str) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()

    if act_fn_name.startswith("torch.nn.modules."):
        activation_name = act_fn_name.split(".")[-1]
        if activation_name == "identity":
            return nn.Identity()
        act_fn_name = activation_name

    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_REGISTRY[act_fn_name]


_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
    {
        "gelu": lambda: GeluAndMul(),
        "silu": lambda: SiluAndMul(),
        "geglu": lambda: GeluAndMul(),
        "swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
    }
)


def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
    """Get an activation-and-mul (i.e. SiluAndMul) function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")

    return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
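
# Usage sketch (comment only):
#
#     act = get_act_fn("gelu_new")              # -> NewGELU()
#     act_and_mul = get_act_and_mul_fn("silu")  # -> SiluAndMul()
#
# The LazyDict lambdas defer construction until first lookup, so merely
# importing this module should not instantiate any op.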
26
vllm/model_executor/layers/attention/__init__.py
Normal file
@@ -0,0 +1,26 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.model_executor.layers.attention.attention import Attention
from vllm.model_executor.layers.attention.chunked_local_attention import (
    ChunkedLocalAttention,
)
from vllm.model_executor.layers.attention.cross_attention import CrossAttention
from vllm.model_executor.layers.attention.encoder_only_attention import (
    EncoderOnlyAttention,
)
from vllm.model_executor.layers.attention.mla_attention import MLAAttention
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.attention.static_sink_attention import (
    StaticSinkAttention,
)

__all__ = [
    "Attention",
    "ChunkedLocalAttention",
    "CrossAttention",
    "EncoderOnlyAttention",
    "MLAAttention",
    "MMEncoderAttention",
    "StaticSinkAttention",
]
733
vllm/model_executor/layers/attention/attention.py
Normal file
@@ -0,0 +1,733 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.vllm import VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.kv_transfer_utils import (
    maybe_transfer_kv_layer,
)
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.linear import (
    UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    kv_cache_dtype_str_to_dtype,
)
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionType,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
    FullAttentionSpec,
    KVCacheSpec,
    SlidingWindowSpec,
)

if TYPE_CHECKING:
    from vllm.model_executor.layers.attention import MLAAttention

logger = init_logger(__name__)


def validate_kv_sharing_target(
    current_layer_name, target_layer_name, static_forward_context
):
    error_msg = (
        f"Specified KV sharing target layer for {current_layer_name} "
        f"is not valid: target layer {target_layer_name} "
    )

    if current_layer_name == target_layer_name:
        raise ValueError(error_msg + "cannot be the same as the current layer.")

    if target_layer_name not in static_forward_context:
        from vllm.model_executor.models.utils import extract_layer_index

        # If the target layer name is not in the static fwd context, it means either
        # a) the target layer does not come BEFORE the current layer, or
        # b) the target layer is not an Attention layer that exists in the model
        current_layer_idx = extract_layer_index(current_layer_name)
        target_layer_idx = extract_layer_index(target_layer_name)
        if current_layer_idx <= target_layer_idx:
            raise ValueError(error_msg + "must come before the current layer.")
        else:
            raise ValueError(error_msg + "is not a valid Attention layer in the model.")

    # Currently KV sharing is only supported between layers of the same type
    target_layer_attn_type = static_forward_context[target_layer_name].attn_type
    expected = static_forward_context[current_layer_name].attn_type
    if target_layer_attn_type != expected:
        raise ValueError(
            error_msg + f"must be the same type as the current layer ({expected})."
        )
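
# Example (comment only): "model.layers.4.self_attn.attn" may share the KV
# cache of "model.layers.2.self_attn.attn" (an earlier layer of the same
# attention type), but not of itself, a later layer, or a non-attention layer.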

def should_load_quant_weights(quant_method: QuantizeMethodBase | None) -> bool:
    """Returns whether the quantization method should load quantized weights."""
    return quant_method is not None and not isinstance(
        quant_method, UnquantizedLinearMethod
    )


def set_default_quant_scales(layer: nn.Module, register_buffer: bool = False) -> None:
    """Sets default quantization scales for the layer."""
    if register_buffer:
        layer.register_buffer("_k_scale", torch.tensor(1.0, dtype=torch.float32))
        layer.register_buffer("_v_scale", torch.tensor(1.0, dtype=torch.float32))
        layer.register_buffer("_q_scale", torch.tensor(1.0, dtype=torch.float32))
        layer.register_buffer("_prob_scale", torch.tensor(1.0, dtype=torch.float32))
    else:
        layer._k_scale.fill_(1.0)
        layer._v_scale.fill_(1.0)
        layer._q_scale.fill_(1.0)
        layer._prob_scale.fill_(1.0)

    # We also keep q/k/v_scale on host (cpu) memory for attention
    # backends that require the scales to be on host instead of on device,
    # e.g. FlashInfer.
    layer._q_scale_float = 1.0
    layer._k_scale_float = 1.0
    layer._v_scale_float = 1.0
    layer._prob_scale_float = 1.0

    # Initialize q/k/v range constants used by calc_kv_scales
    layer.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
    layer.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
    layer.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)


def _init_kv_cache_quant(
    layer: nn.Module,
    quant_config: QuantizationConfig | None,
    prefix: str,
) -> None:
    """Initializes KV cache scaling factors and quantization method.

    This helper function sets up the KV cache quantization attributes that are
    shared between Attention and MLAAttention layers. It initializes scale
    tensors for query, key, value, and probability, and configures the
    quantization method if applicable.

    Args:
        layer: The attention layer instance to initialize.
        quant_config: Optional quantization configuration.
        prefix: Layer name prefix for quantization method lookup.
    """
    # Note [Register q/k/v/prob scales in state dict]
    # When calling model.to(device), only parameters/buffers in the state dict
    # are moved. If the q/k/v/prob scales were not registered in the state
    # dict, there would be an IMA error when a cuda kernel (e.g., quant_fp8)
    # accesses the tensor on cpu.
    # Registering in the state dict means it interacts with weight loading.
    # One edge case is when quant_method is None, or quant_method is
    # UnquantizedLinearMethod (i.e., should_load_quant_weights(quant_method)
    # == False). In this case, the checkpoint does not have the scales. We
    # need to initialize the scales to 1.0 and update the scales after weight
    # loading. This is especially important when we load dummy weights first
    # (providing wrong scales) and then load real weights (which miss the
    # scales and would keep the wrong scales from the dummy load).
    set_default_quant_scales(layer, register_buffer=True)

    # The output scale on host memory. This should be the input scale of
    # the quant op after this attention layer.
    layer._o_scale_float = None

    quant_method = (
        quant_config.get_quant_method(layer, prefix=prefix) if quant_config else None
    )

    # See [Note: Register q/k/v/prob scales in state dict]
    if should_load_quant_weights(quant_method):
        assert isinstance(quant_method, BaseKVCacheMethod)
        # TODO (mgoin): kv cache dtype should be specified in the FP8
        # checkpoint config and become the "auto" behavior
        if layer.kv_cache_dtype == "fp8_e5m2":
            raise ValueError("fp8_e5m2 kv-cache is not supported with fp8 checkpoints.")
        # If quantization is enabled, we make "k_scale" and "v_scale"
        # parameters so that they can be loaded from the model checkpoint.
        # The k/v_scale will then be converted back to native float32
        # values after weight loading.
        layer.quant_method = quant_method
        layer.quant_method.create_weights(layer)
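
# Sketch (comment only) of why the scales are registered as buffers: buffers
# live in the state dict, so model.to(device) moves them with the parameters.
#
#     layer = nn.Module()
#     layer.register_buffer("_k_scale", torch.tensor(1.0))
#     layer.to("cuda")  # _k_scale now lives on the GPU as well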

class Attention(nn.Module, AttentionLayerBase):
    """Attention layer.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        use_alibi_sqrt: bool | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        logits_soft_cap: float | None = None,
        per_layer_sliding_window: int | None = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        attn_backend: type[AttentionBackend] | None = None,
        head_size_v: int | None = None,
        **extra_impl_args,
    ) -> None:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.
        """
        super().__init__()
        if per_layer_sliding_window is not None:
            # per-layer sliding window
            sliding_window = per_layer_sliding_window
        elif cache_config is not None:
            # model-level sliding window
            sliding_window = cache_config.sliding_window
        else:
            sliding_window = None

        vllm_config = get_current_vllm_config()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            calculate_kv_scales = cache_config.calculate_kv_scales
        else:
            kv_cache_dtype = "auto"
            block_size = 16
            calculate_kv_scales = False

        # llm-compressor models need to set cache_dtype to "fp8" manually.
        kv_cache_scheme = getattr(quant_config, "kv_cache_scheme", None)
        if kv_cache_scheme is not None:
            kv_cache_dtype = "fp8"
            calculate_kv_scales = False
            if cache_config is not None:
                cache_config.cache_dtype = "fp8"
                cache_config.calculate_kv_scales = False

        # Check if per-head quant scales are required based on kv_cache_scheme
        use_per_head_quant_scales = (
            kv_cache_scheme is not None
            and kv_cache_scheme.get("strategy") == "attn_head"
        )

        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
            kv_cache_dtype, vllm_config.model_config
        )
        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        if num_kv_heads is None:
            num_kv_heads = num_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not divisible by num_kv_heads ({num_kv_heads})"
        )
        self.quant_config = quant_config
        self.layer_name = prefix

        self.num_heads = num_heads
        self.head_size = head_size
        self.head_size_v = self.head_size if head_size_v is None else head_size_v
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        # NOTE: model_config may be None during certain tests
        model_config = vllm_config.model_config
        self.use_mm_prefix = model_config is not None and model_config.is_mm_prefix_lm

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is None:
            self.attn_backend = get_attn_backend(
                head_size,
                dtype,
                kv_cache_dtype,
                block_size,
                use_mla=False,
                has_sink=self.has_sink,
                use_mm_prefix=self.use_mm_prefix,
                use_per_head_quant_scales=use_per_head_quant_scales,
                attn_type=attn_type,
            )
        else:
            self.attn_backend = attn_backend
        backend_supports_alibi_sqrt = self.attn_backend.supports_alibi_sqrt()
        use_alibi_sqrt = use_alibi_sqrt if use_alibi_sqrt else False
        if use_alibi_sqrt and not backend_supports_alibi_sqrt:
            raise ValueError(
                f"use_alibi_sqrt is not supported by backend "
                f"{self.attn_backend.get_name()}."
            )
        self.use_alibi_sqrt = bool(use_alibi_sqrt)
        if backend_supports_alibi_sqrt:
            extra_impl_args["use_alibi_sqrt"] = self.use_alibi_sqrt
        # prefix caching + batch invariance is currently not supported for
        # FLASHINFER and TRITON_MLA.
        if (
            cache_config is not None
            and cache_config.enable_prefix_caching
            and vllm_is_batch_invariant()
            and (
                self.attn_backend.get_name() == "FLASHINFER"
                or self.attn_backend.get_name() == "TRITON_MLA"
            )
        ):
            logger.warning_once(
                "Disabling prefix caching for FLASHINFER/TRITON_MLA "
                "with batch invariance, as it is not yet supported.",
                scope="local",
            )
            cache_config.enable_prefix_caching = False

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **extra_impl_args,
        )
        self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
        self.dtype = dtype

        # For cuda-alike (CUDA and ROCm) and cpu platforms, we control how
        # torch.compile works by registering the attention as one giant
        # opaque custom op. For other platforms, we directly call them
        # and let torch.compile handle them.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer
        compilation_config = vllm_config.compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # Use a placeholder kv cache tensor during init, which will be replaced
        # by bind_kv_cache.
        # This variable will not be accessed if use_direct_call is True.
        self.kv_cache = [
            torch.tensor([])
            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
        ]

        # Initialize KV cache quantization attributes
        _init_kv_cache_quant(self, quant_config, prefix)

        # for attn backends supporting query quantization
        self.query_quant = None
        if self.impl.supports_quant_query_input and self.kv_cache_dtype.startswith(
            "fp8"
        ):
            is_per_head = (
                hasattr(self, "q_scale") and self.q_scale.numel() == self.num_kv_heads
            )
            block_size = self.head_size * self.num_heads // self.num_kv_heads
            self.query_quant = QuantFP8(
                static=True,
                group_shape=GroupShape(-1, block_size)
                if is_per_head
                else GroupShape.PER_TENSOR,
            )

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        # For some alternate attention backends like MLA the attention output
        # shape does not match the query shape, so we optionally let the model
        # definition specify the output tensor shape.
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        """
        The KV cache is stored inside this class and is accessed via
        `self.kv_cache`.

        Attention metadata (`attn_metadata`) is set using a context manager in
        the model runner's `execute_model` method. It is accessed via forward
        context using
        `vllm.forward_context.get_forward_context().attn_metadata`.
        """
        if self.calculate_kv_scales:
            torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
        output_dtype = query.dtype
        if self.query_quant is not None:
            # Quantizing with a simple torch operation enables
            # torch.compile to fuse this into previous ops,
            # which reduces overheads during decoding.
            # Otherwise queries are quantized using custom ops,
            # which causes decoding overheads.
            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}

            # check if query quantization is supported
            if self.impl.supports_quant_query_input:
                query, _ = self.query_quant(query, self._q_scale)

        if self.use_output:
            if output_shape is None:
                # Handle both 2D [num_tokens, hidden] and
                # 3D [num_tokens, heads, head_dim] query
                num_tokens = query.shape[0]
                output_shape = torch.Size(
                    (num_tokens, self.num_heads * self.head_size_v)
                )
            output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
            hidden_size = output_shape[-1]
            # Reshape the query, key, and value tensors.
            # NOTE(woosuk): We do this outside the custom op to minimize the
            # CPU overheads from the non-CUDA-graph regions.
            query = query.view(-1, self.num_heads, self.head_size)
            output = output.view(-1, self.num_heads, self.head_size_v)
            if key is not None:
                key = key.view(-1, self.num_kv_heads, self.head_size)
            if value is not None:
                value = value.view(-1, self.num_kv_heads, self.head_size_v)
            kv_cache_dummy_dep = None
            if self.use_direct_call:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            else:
                # Skip this if sharing KV cache with an earlier attention layer.
                if (
                    not self.attn_backend.forward_includes_kv_cache_update
                    and self.kv_sharing_target_layer_name is None
                    and key is not None
                    and value is not None
                ):
                    kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
                        key, value, self.layer_name
                    )
                torch.ops.vllm.unified_attention_with_output(
                    query,
                    key,
                    value,
                    output,
                    self.layer_name,
                    kv_cache_dummy_dep=kv_cache_dummy_dep,
                )
            return output.view(-1, hidden_size)
        else:
            assert self.attn_backend.forward_includes_kv_cache_update, (
                "Split KV cache update not supported when output tensor not provided."
            )
            if self.use_direct_call:
                return unified_attention(query, key, value, self.layer_name)
            else:
                return torch.ops.vllm.unified_attention(
                    query, key, value, self.layer_name
                )

    def calc_kv_scales(self, query, key, value):
        self._q_scale.copy_(torch.abs(query).max() / self.q_range)
        self._k_scale.copy_(torch.abs(key).max() / self.k_range)
        self._v_scale.copy_(torch.abs(value).max() / self.v_range)
        self._q_scale_float = self._q_scale.item()
        self._k_scale_float = self._k_scale.item()
        self._v_scale_float = self._v_scale.item()
        # We only calculate the scales once
        self.calculate_kv_scales = False
|
||||
def extra_repr(self) -> str:
|
||||
s = f"head_size={self.impl.head_size}" # type: ignore
|
||||
s += f", num_heads={self.impl.num_heads}" # type: ignore
|
||||
s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore
|
||||
s += f", scale={self.impl.scale}" # type: ignore
|
||||
s += f", backend={self.impl.__class__.__name__}"
|
||||
return s
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
self.impl.process_weights_after_loading(act_dtype)
|
||||
|
||||
# If we should not load quant weights, we initialize the scales to 1.0
|
||||
# as the default value. See [Note: Register q/k/v/prob scales in state dict]
|
||||
# for more details.
|
||||
quant_method = (
|
||||
self.quant_config.get_quant_method(self, prefix=self.layer_name)
|
||||
if self.quant_config
|
||||
else None
|
||||
)
|
||||
if not should_load_quant_weights(quant_method):
|
||||
set_default_quant_scales(self, register_buffer=False)
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
return self.attn_backend
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
|
||||
# Block size may get updated after model loading, refresh it
|
||||
block_size = vllm_config.cache_config.block_size
|
||||
# Should not be called for enc-dec or encoder-only attention.
|
||||
assert self.attn_type == AttentionType.DECODER
|
||||
if self.sliding_window is not None:
|
||||
assert not vllm_config.model_config.use_mla, (
|
||||
"MLA is not supported for slidingwindow"
|
||||
)
|
||||
return SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
else:
|
||||
return FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
head_size=self.head_size,
|
||||
head_size_v=self.head_size_v,
|
||||
dtype=self.kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
def maybe_calc_kv_scales(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
|
||||
# Only calculate if the layer's calculate_kv_scales flag is True
|
||||
# This flag gets set to False after the first forward pass
|
||||
if not self.calculate_kv_scales:
|
||||
return
|
||||
|
||||
self.calc_kv_scales(query, key, value)
|
||||
|
||||
|
||||
def maybe_calc_kv_scales_fake(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="maybe_calc_kv_scales",
|
||||
op_func=maybe_calc_kv_scales,
|
||||
mutates_args=["query", "key", "value"],
|
||||
fake_impl=maybe_calc_kv_scales_fake,
|
||||
)
|
||||
|
||||
|
||||
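
# Illustrative sketch, not part of the original file: the scales recorded by
# calc_kv_scales above are plain absmax scales. Assuming the q/k/v ranges hold
# the fp8 e4m3 maximum (448.0), the rule reduces to the hypothetical helper
# below.
def _example_absmax_scale(x: torch.Tensor, fp8_max: float = 448.0) -> torch.Tensor:
    # Scale such that x / scale fits the fp8 dynamic range, mirroring
    # self._q_scale.copy_(torch.abs(query).max() / self.q_range).
    return torch.abs(x).max() / fp8_max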
def get_attention_context(
    layer_name: str,
) -> tuple[Any, "Attention | MLAAttention", torch.Tensor, torch.Tensor]:
    """Extract attention context for a given layer.

    This helper function extracts the attention metadata, attention layer
    instance, KV cache tensor, and slot mapping for a specific layer.

    Args:
        layer_name: The name/identifier of the attention layer.

    Returns:
        A tuple containing:
        - attn_metadata: Attention metadata for this specific layer, or None if
          no metadata available
        - attn_layer: The attention layer instance (Attention or MLAAttention)
        - kv_cache: The KV cache tensor for the current virtual engine
        - slot_mapping: The slot mapping for this specific layer

    Note: attn_metadata may be None, but attn_layer and kv_cache are always
    extracted from the forward context.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata
    if isinstance(attn_metadata, dict):
        attn_metadata = attn_metadata[layer_name]
    attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
    kv_cache = attn_layer.kv_cache[forward_context.virtual_engine]
    slot_mapping = forward_context.slot_mapping
    assert isinstance(slot_mapping, dict), (
        f"Expected slot_mapping to be a dict, got {type(slot_mapping)}."
    )
    layer_slot_mapping = slot_mapping.get(layer_name)
    return attn_metadata, attn_layer, kv_cache, layer_slot_mapping
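
# Illustrative sketch, not part of the original file: the unified ops below
# look their layer state up through get_attention_context at call time instead
# of capturing it in the traced graph. A hypothetical lookup:
def _example_context_lookup(layer_name: str):
    # attn_metadata may be None outside an execution context; the layer and
    # its KV cache for the current virtual engine are always resolved.
    attn_metadata, attn_layer, kv_cache, slots = get_attention_context(layer_name)
    return attn_layer.impl, kv_cache, slots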
@maybe_transfer_kv_layer
def unified_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)
    output = self.impl.forward(self, query, key, value, kv_cache, attn_metadata)

    return output


def unified_attention_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty_like(query).contiguous()


direct_register_custom_op(
    op_name="unified_attention",
    op_func=unified_attention,
    fake_impl=unified_attention_fake,
)


def unified_kv_cache_update(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    """
    Returns a dummy tensor that is passed to unified_attention to signal a
    side effect and the data dependency between the two ops, ensuring that
    torch.compile preserves their ordering.
    """
    _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name)
    if layer_slot_mapping is not None:
        assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
            f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
        )
        attn_layer.impl.do_kv_cache_update(
            attn_layer,
            key,
            value,
            kv_cache,
            layer_slot_mapping,
        )

    return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)


def unified_kv_cache_update_fake(
    key: torch.Tensor,
    value: torch.Tensor,
    layer_name: str,
) -> torch.Tensor:
    return torch.empty(0, device=key.device, dtype=key.dtype)


direct_register_custom_op(
    op_name="unified_kv_cache_update",
    op_func=unified_kv_cache_update,
    fake_impl=unified_kv_cache_update_fake,
    mutates_args=[],
)
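
# Illustrative sketch, not part of the original file: the zero-element tensor
# returned above exists only to thread a data dependency through the graph;
# without it, torch.compile could reorder two ops that communicate via side
# effects. The general pattern, with hypothetical names:
def _example_ordered_pair(x: torch.Tensor) -> torch.Tensor:
    # Producer performs a side effect and returns a 0-element token tensor.
    token = torch.empty(0, device=x.device, dtype=x.dtype)

    def consumer(y: torch.Tensor, dep: torch.Tensor | None = None) -> torch.Tensor:
        # The consumer ignores the token; the argument edge alone keeps
        # producer-before-consumer ordering.
        del dep
        return y

    return consumer(x, dep=token)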
@maybe_transfer_kv_layer
def unified_attention_with_output(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    # kv_cache_dummy_dep is not used, but accepting it creates a data dependency
    # that ensures torch.compile preserves ordering between the KV cache update
    # and the attention forward.
    del kv_cache_dummy_dep
    attn_metadata, self, kv_cache, _ = get_attention_context(layer_name)

    self.impl.forward(
        self,
        query,
        key,
        value,
        kv_cache,
        attn_metadata,
        output=output,
        output_scale=output_scale,
        output_block_scale=output_block_scale,
    )


def unified_attention_with_output_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    output_scale: torch.Tensor | None = None,
    output_block_scale: torch.Tensor | None = None,
    kv_cache_dummy_dep: torch.Tensor | None = None,
) -> None:
    return


direct_register_custom_op(
    op_name="unified_attention_with_output",
    op_func=unified_attention_with_output,
    mutates_args=["output", "output_block_scale"],
    fake_impl=unified_attention_with_output_fake,
)
130
vllm/model_executor/layers/attention/chunked_local_attention.py
Normal file
@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import torch

from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionCGSupport,
    AttentionMetadataBuilder,
    CommonAttentionMetadata,
    subclass_attention_backend,
)
from vllm.v1.attention.backends.utils import (
    make_local_attention_virtual_batches,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    ChunkedLocalAttentionSpec,
    KVCacheSpec,
)


@functools.lru_cache
def create_chunked_local_attention_backend(
    underlying_attn_backend: AttentionBackend,
    attention_chunk_size: int,
    block_size: int,
) -> type[AttentionBackend]:
    prefix = f"ChunkedLocalAttention_{attention_chunk_size}_{block_size}_"

    underlying_builder = underlying_attn_backend.get_builder_cls()
    assert issubclass(underlying_builder, AttentionMetadataBuilder)

    class ChunkedLocalAttentionBuilder(underlying_builder):  # type: ignore
        @classmethod
        def get_cudagraph_support(
            cls: type["AttentionMetadataBuilder"],
            vllm_config: VllmConfig,
            kv_cache_spec: AttentionSpec,
        ) -> AttentionCGSupport:
            # Explicit override in case the underlying builder specialized this
            # getter. @override is omitted only because of a mypy limitation
            # with type variables.
            return AttentionCGSupport.NEVER

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ):
            cm, make_virtual_batches_block_table = make_local_attention_virtual_batches(
                attention_chunk_size, common_attn_metadata, block_size
            )
            metadata = super().build(common_prefix_len, cm, fast_build)
            metadata.make_virtual_batches_block_table = make_virtual_batches_block_table
            return metadata

        def update_block_table(
            self, metadata, blk_table: torch.Tensor, slot_mapping: torch.Tensor
        ):
            blk_table = metadata.make_virtual_batches_block_table(blk_table)
            return super().update_block_table(metadata, blk_table, slot_mapping)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=ChunkedLocalAttentionBuilder,
    )

    return attn_backend


class ChunkedLocalAttention(Attention):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        attention_chunk_size: int,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        kv_sharing_target_layer_name: str | None = None,
        prefix: str = "",
    ):
        self.attention_chunk_size = attention_chunk_size
        dtype = torch.get_default_dtype()
        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size, dtype, kv_cache_dtype, block_size
        )
        attn_backend = create_chunked_local_attention_backend(
            underlying_attn_backend, attention_chunk_size, block_size
        )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            num_kv_heads=num_kv_heads,
            alibi_slopes=alibi_slopes,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
            kv_sharing_target_layer_name=kv_sharing_target_layer_name,
            attn_backend=attn_backend,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        assert self.attention_chunk_size
        return ChunkedLocalAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
            attention_chunk_size=self.attention_chunk_size,
        )
226
vllm/model_executor/layers/attention/cross_attention.py
Normal file
@@ -0,0 +1,226 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy

import numpy as np
import torch

from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
    CommonAttentionMetadata,
    subclass_attention_backend_with_overrides,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec

logger = init_logger(__name__)


def _get_cross_slot_mapping(
    encoder_seq_lens: np.ndarray,
    block_table_tensor: torch.Tensor,
    kv_cache_spec: CrossAttentionSpec,
    device: torch.device,
) -> torch.Tensor:
    """Get cross-attention slot mappings."""

    block_size = kv_cache_spec.block_size
    slot_mappings = []

    # Find indices with non-zero encoder sequence lengths.
    # The majority of parallel requests will be running the
    # decoder, so this list should be relatively small.
    active_indices = np.nonzero(encoder_seq_lens)[0]

    for req_index in active_indices:
        encoder_seq_len = encoder_seq_lens[req_index].item()

        # Calculate the number of blocks needed for this request
        num_blocks_needed = cdiv(encoder_seq_len, block_size)

        # Get the block IDs for this request from the tensor
        req_block_ids = block_table_tensor[req_index]

        # Get only the blocks we need (first num_blocks_needed blocks)
        needed_block_ids = req_block_ids[:num_blocks_needed]

        # All needed blocks are allocated
        i_values = torch.arange(encoder_seq_len, dtype=torch.int64, device=device)
        block_indices = i_values // block_size
        block_offsets = i_values % block_size
        block_numbers = needed_block_ids[block_indices]
        slot_mapping = block_numbers * block_size + block_offsets

        slot_mappings.append(slot_mapping)

    if slot_mappings:
        return torch.cat(slot_mappings)
    else:
        return torch.empty(0, dtype=torch.int64, device=device)
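
# Illustrative sketch, not part of the original file: with block_size = 4 and
# needed_block_ids = [7, 2], token i maps to slot
# needed_block_ids[i // 4] * 4 + i % 4, so an encoder length of 6 yields
# slots [28, 29, 30, 31, 8, 9]. A hypothetical standalone check:
def _example_cross_slots() -> torch.Tensor:
    block_size = 4
    needed_block_ids = torch.tensor([7, 2], dtype=torch.int64)
    i = torch.arange(6, dtype=torch.int64)
    return needed_block_ids[i // block_size] * block_size + i % block_size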
@functools.lru_cache
def create_cross_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    prefix = "CrossAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()
    underlying_impl = underlying_attn_backend.get_impl_cls()

    class CrossAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_metadata = copy(common_attn_metadata)
            new_metadata.causal = False
            max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
            new_metadata.max_seq_len = max_encoder_len
            # Any computed tokens indicate a decode step > 1 (no chunked prefill).
            num_cache_decodes = (
                (common_attn_metadata.num_computed_tokens_cpu > 0).sum().item()
            )
            if num_cache_decodes > 0:
                # The cross-attention KV cache has already been populated on the
                # first decoder step, so skip the slot_mapping calculation for
                # requests that do not need reshape_and_cache.
                num_tokens = common_attn_metadata.num_computed_tokens_cpu.numpy()
                new_metadata.encoder_seq_lens_cpu = np.where(
                    num_tokens > 0, 0, new_metadata.encoder_seq_lens_cpu
                )

            # seq_lens is provided by the model runner: the initial encoder
            # input length is needed here to know how many tokens to attend to
            # from the cached cross-attention KV cache.
            new_metadata.seq_lens = common_attn_metadata.encoder_seq_lens
            new_metadata._seq_lens_cpu = torch.from_numpy(
                common_attn_metadata.encoder_seq_lens_cpu
            )

            # NOTE (NickLucche) use `new_metadata` instead of `common_*` (initial) here
            slot_mapping = _get_cross_slot_mapping(
                new_metadata.encoder_seq_lens_cpu,
                new_metadata.block_table_tensor,
                self.kv_cache_spec,
                self.device,
            )
            attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
            attn_metadata.slot_mapping = slot_mapping
            return attn_metadata

    # NOTE(Lucas): we need a custom impl so we can use the slot mapping computed
    # by `CrossAttentionBuilder` instead of the one computed by `BlockTable`
    # (gpu_model_runner)
    class CrossAttentionImpl(underlying_impl):  # type: ignore[valid-type,misc]
        def forward(
            self,
            layer: torch.nn.Module,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            kv_cache: torch.Tensor,
            attn_metadata: AttentionMetadata,
            output: torch.Tensor | None = None,
            output_scale: torch.Tensor | None = None,
            output_block_scale: torch.Tensor | None = None,
        ) -> torch.Tensor:
            if (
                not underlying_attn_backend.forward_includes_kv_cache_update
                and attn_metadata is not None
                and layer.kv_sharing_target_layer_name is None
                and key is not None
                and value is not None
            ):
                self.do_kv_cache_update(
                    layer, key, value, kv_cache, attn_metadata.slot_mapping
                )

            return super().forward(
                layer,
                query,
                key,
                value,
                kv_cache,
                attn_metadata,
                output,
                output_scale,
                output_block_scale,
            )

    attn_backend = subclass_attention_backend_with_overrides(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        overrides={
            "get_builder_cls": lambda: CrossAttentionBuilder,
            "get_impl_cls": lambda: CrossAttentionImpl,
            "forward_includes_kv_cache_update": True,
        },
    )

    return attn_backend


class CrossAttention(Attention):
    """
    Cross-attention for encoder-decoder models.
    Handles attention between decoder queries and encoder keys/values.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_DECODER, (
                "CrossAttention only supports AttentionType.ENCODER_DECODER"
            )

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_DECODER,
        )
        attn_backend = create_cross_attention_backend(underlying_attn_backend)

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_DECODER,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        return CrossAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.kv_cache_torch_dtype,
        )
101
vllm/model_executor/layers/attention/encoder_only_attention.py
Normal file
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from copy import copy

import torch

from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.attention import Attention
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
    CommonAttentionMetadata,
    subclass_attention_backend,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import KVCacheSpec


@functools.lru_cache
def create_encoder_only_attention_backend(
    underlying_attn_backend: AttentionBackend,
) -> type[AttentionBackend]:
    prefix = "EncoderOnlyAttention_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class EncoderOnlyAttentionBuilder(underlying_builder):  # type: ignore
        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            new_common_attn_metadata = copy(common_attn_metadata)
            new_common_attn_metadata.causal = False
            return super().build(
                common_prefix_len, new_common_attn_metadata, fast_build
            )

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=EncoderOnlyAttentionBuilder,
    )

    return attn_backend


class EncoderOnlyAttention(Attention):
    """
    Encoder attention is a special case that doesn't need a KV cache.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        cache_config: CacheConfig | None = None,
        attn_type: str | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        underlying_attn_backend = get_attn_backend(
            head_size,
            dtype,
            kv_cache_dtype,
            block_size,
            attn_type=AttentionType.ENCODER_ONLY,
        )

        attn_backend = create_encoder_only_attention_backend(underlying_attn_backend)

        if attn_type is not None:
            assert attn_type == AttentionType.ENCODER_ONLY, (
                "EncoderOnlyAttention only supports AttentionType.ENCODER_ONLY"
            )

        super().__init__(
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            attn_type=AttentionType.ENCODER_ONLY,
            **kwargs,
        )

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Does not need a KV cache
        return None
60
vllm/model_executor/layers/attention/kv_transfer_utils.py
Normal file
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps

from vllm.distributed.kv_transfer import (
    get_kv_transfer_group,
    has_kv_transfer_group,
    is_v1_kv_transfer_group,
)


def maybe_transfer_kv_layer(func: Callable) -> Callable:
    """Decorator that handles KV layer transfer before and after execution of
    an attention layer, if enabled. Otherwise, the wrapper is a no-op.

    On entry: waits for the KV layer from the connector.
    On exit: saves the KV layer to the connector.
    """
    # Import at runtime to avoid circular dependency
    from vllm.model_executor.layers.attention.attention import get_attention_context

    # Inspect the signature ONCE when the decorator is applied.
    sig = inspect.signature(func)
    param_names = list(sig.parameters.keys())

    # Find the index of the 'layer_name' parameter.
    try:
        layer_name_index = param_names.index("layer_name")
    except ValueError as e:
        raise TypeError(
            f"Function {func.__name__} must have a 'layer_name' parameter"
        ) from e

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
            return func(*args, **kwargs)

        layer_name: str = args[layer_name_index]

        # Extract attention context (metadata, layer, kv_cache, layer_slot_mapping)
        attn_metadata, _, kv_cache, _ = get_attention_context(layer_name)
        connector = get_kv_transfer_group()
        if attn_metadata is None or not connector.has_connector_metadata():
            return func(*args, **kwargs)

        # Wait for the KV layer on entry
        connector.wait_for_layer_load(layer_name)

        # Execute the function
        result = func(*args, **kwargs)

        # Save the KV cache layer on exit
        connector.save_kv_layer(layer_name, kv_cache, attn_metadata)

        return result

    return wrapper
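
# Illustrative sketch, not part of the original file: any op whose signature
# carries a positional `layer_name` can opt into KV transfer this way. The
# function below is hypothetical.
@maybe_transfer_kv_layer
def _example_attention_op(query, key, value, layer_name: str):
    # When a v1 KV-transfer group is configured, the wrapper waits for this
    # layer's KV before running the body and saves it afterwards; otherwise
    # the call passes straight through.
    return query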
3006
vllm/model_executor/layers/attention/mla_attention.py
Normal file
File diff suppressed because it is too large
262
vllm/model_executor/layers/attention/mm_encoder_attention.py
Normal file
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.vit_attn_wrappers import (
    vit_flash_attn_wrapper,
    vit_torch_sdpa_wrapper,
    vit_triton_attn_wrapper,
)

logger = init_logger(__name__)


# --8<-- [start:mm_encoder_attn]
@CustomOp.register("mm_encoder_attn")
class MMEncoderAttention(CustomOp):
    """Multi-headed attention without any cache, used for multimodal encoder."""

    # --8<-- [end:mm_encoder_attn]

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float | None = None,
        num_kv_heads: int | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            num_heads: number of attention heads per partition.
            head_size: hidden size per attention head.
            scale: scale factor.
            num_kv_heads: number of KV heads.
            prefix: This has no effect; it is only here to make it easier to
                swap between Attention and MultiHeadAttention.
        """
        super().__init__()

        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.layer_name = prefix

        assert self.num_heads % self.num_kv_heads == 0, (
            f"num_heads ({self.num_heads}) is not "
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        )
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        # During model initialization, the default dtype is set as the model
        # weight and activation dtype.
        dtype = torch.get_default_dtype()

        # Get device-specific vision attention backend.
        self.attn_backend = get_vit_attn_backend(
            head_size=head_size,
            dtype=dtype,
        )

        self.is_flash_attn_backend = self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }

        self._fa_version = (
            get_flash_attn_version() if self.is_flash_attn_backend else None
        )

        logger.info_once(f"Using {self.attn_backend} for MMEncoderAttention.")

    @classmethod
    def enabled(cls) -> bool:
        return True

    def view_qkv_to_4d(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        bsz: int,
        q_len: int,
        kv_len: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Reshape query, key, value to 4D tensors:
        (batch_size, seq_len, num_heads, head_size)
        """
        query = query.view(bsz, q_len, self.num_heads, self.head_size)
        key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
        value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)

        return query, key, value

    def _forward_sdpa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_torch_sdpa_wrapper(
            q=query,
            k=key,
            v=value,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            enable_gqa=self.num_heads > self.num_kv_heads,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_fa(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_flash_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            is_rocm_aiter=(self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA),
            fa_version=self._fa_version,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def _forward_triton(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        assert (cu_seqlens is not None and max_seqlen is not None) or (
            cu_seqlens is None and max_seqlen is None
        ), "cu_seqlens and max_seqlen should be both set or both None."

        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)
        is_reshaped = query.dim() != 4

        query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)

        output = vit_triton_attn_wrapper(
            q=query,
            k=key,
            v=value,
            batch_size=bsz,
            scale=self.scale,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        if is_reshaped:
            output = output.reshape(bsz, q_len, -1)
        return output

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.is_flash_attn_backend:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for CUDA: "
                f"{self.attn_backend}."
            )

    def forward_cpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        return self._forward_sdpa(query, key, value, cu_seqlens)

    def forward_xpu(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        max_seqlen: torch.Tensor | None = None,  # Only used for Flash Attention
    ) -> torch.Tensor:
        if self.attn_backend == AttentionBackendEnum.FLASH_ATTN:
            return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TRITON_ATTN:
            return self._forward_triton(query, key, value, cu_seqlens, max_seqlen)
        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return self._forward_sdpa(query, key, value, cu_seqlens)
        else:
            raise ValueError(
                f"Unsupported multi-modal encoder attention backend for XPU: "
                f"{self.attn_backend}."
            )
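
# Illustrative sketch, not part of the original file: MMEncoderAttention is
# cache-free, so a ViT block can call it directly on
# (batch, seq, heads, head_size) tensors; the numbers below are arbitrary and
# assume the CustomOp machinery can initialize on this platform.
def _example_mm_encoder_attn() -> torch.Tensor:
    attn = MMEncoderAttention(num_heads=8, head_size=64, scale=64**-0.5)
    q = torch.randn(2, 196, 8, 64)
    k = torch.randn(2, 196, 8, 64)
    v = torch.randn(2, 196, 8, 64)
    # Dispatches to SDPA / FlashAttention / Triton depending on the backend.
    return attn(q, k, v)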
252
vllm/model_executor/layers/attention/static_sink_attention.py
Normal file
@@ -0,0 +1,252 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools

import torch

from vllm.config import CacheConfig, VllmConfig
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.attention import Attention
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import (
    AttentionBackend,
    AttentionMetadata,
    AttentionType,
    CommonAttentionMetadata,
    subclass_attention_backend,
)
from vllm.v1.attention.ops.triton_reshape_and_cache_flash import (
    triton_reshape_and_cache_flash_diffkv,
)
from vllm.v1.attention.selector import get_attn_backend
from vllm.v1.kv_cache_interface import (
    AttentionSpec,
    KVCacheSpec,
    SinkFullAttentionSpec,
)

logger = init_logger(__name__)


@functools.lru_cache
def create_static_sink_attention_backend(
    underlying_attn_backend: type[AttentionBackend],
    sink_len: int = 0,
) -> type[AttentionBackend]:
    prefix = "StaticSink_"
    underlying_builder = underlying_attn_backend.get_builder_cls()

    class StaticSinkAttentionBuilder(underlying_builder):  # type: ignore
        def __init__(
            self,
            kv_cache_spec: AttentionSpec,
            layer_names: list[str],
            vllm_config: VllmConfig,
            device: torch.device,
        ):
            super().__init__(kv_cache_spec, layer_names, vllm_config, device)
            model_config = vllm_config.model_config
            scheduler_config = vllm_config.scheduler_config
            self.sink_len = sink_len
            self.block_size = vllm_config.cache_config.block_size
            self.num_sink_blocks = self.sink_len // vllm_config.cache_config.block_size
            self.max_num_blocks = cdiv(
                model_config.max_model_len, vllm_config.cache_config.block_size
            )
            self.block_table_with_sink = torch.zeros(
                (
                    scheduler_config.max_num_seqs,
                    self.max_num_blocks + self.num_sink_blocks,
                ),
                device=device,
                dtype=torch.int32,
            )
            self.block_table_with_sink[:, : self.num_sink_blocks] = torch.arange(
                1,
                self.num_sink_blocks + 1,
                device=device,
                dtype=torch.int32,
            )

        def build(
            self,
            common_prefix_len: int,
            common_attn_metadata: CommonAttentionMetadata,
            fast_build: bool = False,
        ) -> AttentionMetadata:
            common_attn_metadata.seq_lens[:] = (
                common_attn_metadata.seq_lens + self.sink_len
            )
            common_attn_metadata.seq_lens[
                common_attn_metadata.seq_lens == self.sink_len
            ] = 0
            common_attn_metadata.max_seq_len = (
                common_attn_metadata.max_seq_len + self.sink_len
            )
            max_num_blocks = cdiv(common_attn_metadata.max_seq_len, self.block_size)
            num_reqs = common_attn_metadata.num_reqs
            self.block_table_with_sink[
                :num_reqs, self.num_sink_blocks : self.num_sink_blocks + max_num_blocks
            ] = common_attn_metadata.block_table_tensor[:, :max_num_blocks]
            common_attn_metadata.block_table_tensor = self.block_table_with_sink[
                :num_reqs
            ]

            return super().build(common_prefix_len, common_attn_metadata, fast_build)

    attn_backend = subclass_attention_backend(
        name_prefix=prefix,
        attention_backend_cls=underlying_attn_backend,
        builder_cls=StaticSinkAttentionBuilder,
    )

    return attn_backend


@CustomOp.register("static_sink_attention")
class StaticSinkAttention(Attention, CustomOp):
    """
    Attention with static sink tokens.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        sink_len: int,
        attn_backend: type[AttentionBackend] | None = None,
        cache_config: CacheConfig | None = None,
        **kwargs,
    ):
        dtype = torch.get_default_dtype()

        if cache_config is not None:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
        else:
            kv_cache_dtype = "auto"
            block_size = 16

        if attn_backend is not None:
            underlying_attn_backend = attn_backend
        else:
            underlying_attn_backend = get_attn_backend(
                head_size, dtype, kv_cache_dtype, block_size
            )
        attn_backend = create_static_sink_attention_backend(
            underlying_attn_backend,  # type: ignore[arg-type]
            sink_len=sink_len,
        )
        Attention.__init__(
            self=self,
            num_heads=num_heads,
            head_size=head_size,
            scale=scale,
            cache_config=cache_config,
            attn_backend=attn_backend,
            **kwargs,
        )
        CustomOp.__init__(self)

        self.sink_len = sink_len
        self.block_size = block_size
        self.sink_populated = False
        self.sink_key = None
        self.sink_value = None

    def update_sink_kv(self, sink_key, sink_value) -> None:
        self.sink_key = sink_key
        self.sink_value = sink_value

    def forward_native(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        assert self.sink_key is not None and self.sink_value is not None, (
            "sink_key and sink_value have not been prepared"
        )
        if not self.sink_populated:
            forward_context: ForwardContext = get_forward_context()
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            torch.ops.vllm.maybe_populate_sink(self_kv_cache, self.layer_name)

        return super().forward(query, key, value, output_shape)

    def forward_cuda(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        output_shape: torch.Size | None = None,
    ) -> torch.Tensor:
        return self.forward_native(query, key, value, output_shape)

    def forward(self, *args, **kwargs):
        return self._forward_method(*args, **kwargs)

    def populate_sink_kv(self, self_kv_cache):
        sink_kv_slot_mapping = torch.arange(
            self.block_size,
            self.sink_len + self.block_size,
            device=torch.cuda.current_device(),
            dtype=torch.long,
        )
        triton_reshape_and_cache_flash_diffkv(
            self.sink_key,
            self.sink_value,
            self_kv_cache,
            sink_kv_slot_mapping,
            self.kv_cache_dtype,
            self._k_scale,
            self._v_scale,
        )
        # We only populate the sink_key and sink_value once
        self.sink_populated = True

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        # Block size may get updated after model loading, refresh it
        block_size = vllm_config.cache_config.block_size
        # Should not be called for enc-dec or encoder-only attention.
        assert self.attn_type == AttentionType.DECODER

        return SinkFullAttentionSpec(
            block_size=block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            head_size_v=self.head_size_v,
            sink_len=self.sink_len,
            dtype=self.kv_cache_torch_dtype,
        )


def maybe_populate_sink(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    if self.sink_populated or self_kv_cache.numel() == 0:
        return
    self.populate_sink_kv(self_kv_cache)


def maybe_populate_sink_fake(
    self_kv_cache: torch.Tensor,
    layer_name: str,
) -> None:
    return


direct_register_custom_op(
    op_name="maybe_populate_sink",
    op_func=maybe_populate_sink,
    mutates_args=["self_kv_cache"],
    fake_impl=maybe_populate_sink_fake,
)
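
# Illustrative sketch, not part of the original file: with block_size = 16 and
# sink_len = 32, the builder reserves num_sink_blocks = 2 sink blocks (IDs 1
# and 2) at the front of every request's block table, and each sequence length
# is reported sink_len tokens longer so kernels attend to the sink KV first.
# Hypothetical standalone arithmetic:
def _example_sink_block_table(real_blocks: torch.Tensor) -> torch.Tensor:
    num_sink_blocks = 32 // 16
    sink_ids = torch.arange(1, num_sink_blocks + 1, dtype=real_blocks.dtype)
    # [sink block IDs | the request's own blocks]
    return torch.cat([sink_ids, real_blocks])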
34
vllm/model_executor/layers/attention_layer_base.py
Normal file
@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Base class for attention-like layers."""

from abc import ABC, abstractmethod

from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionBackend, AttentionImpl
from vllm.v1.kv_cache_interface import KVCacheSpec


class AttentionLayerBase(ABC):
    """
    Base class for attention-like layers (Attention, Mamba, etc.)
    that support the v1 engine.

    This provides a common interface for getting attention backends
    from different layer types.
    """

    impl: "AttentionImpl"

    @abstractmethod
    def get_attn_backend(self) -> type[AttentionBackend]:
        """Get the attention backend class for this layer."""
        pass

    @abstractmethod
    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """
        Get the KV cache spec for this layer.
        May be None if the layer does not need a KV cache.
        """
        pass
1069
vllm/model_executor/layers/batch_invariant.py
Normal file
File diff suppressed because it is too large
262
vllm/model_executor/layers/conv.py
Normal file
@@ -0,0 +1,262 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Conv Layer Class."""

import math
from typing import Literal

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.custom_op import CustomOp
from vllm.utils.torch_utils import is_torch_equal


class ConvLayerBase(CustomOp):
    """Conv layer base class."""

    num_dim: int

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, ...],
        stride: int | tuple[int, ...] = 1,
        padding: int | tuple[int, ...] | Literal["same", "valid"] = 0,
        dilation: int | tuple[int, ...] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
        *,
        params_dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        valid_padding_strings = {"same", "valid"}
        if isinstance(padding, str) and padding not in valid_padding_strings:
            raise ValueError(
                f"Invalid padding string '{padding}'. "
                f"Expected one of {valid_padding_strings}."
            )

        if padding == "same":
            # Check before converting: 'same' padding is only well defined
            # for unit strides.
            strides = (stride,) * self.num_dim if isinstance(stride, int) else stride
            if any(s != 1 for s in strides):
                raise ValueError(
                    "padding='same' is not supported for strided convolutions"
                )
            padding = (
                kernel_size // 2
                if isinstance(kernel_size, int)
                else tuple(k // 2 for k in kernel_size)
            )
        elif padding == "valid":
            padding = 0

        kernel_size = (
            (kernel_size,) * self.num_dim
            if isinstance(kernel_size, int)
            else kernel_size
        )
        stride = (stride,) * self.num_dim if isinstance(stride, int) else stride
        padding = (padding,) * self.num_dim if isinstance(padding, int) else padding
        dilation = (dilation,) * self.num_dim if isinstance(dilation, int) else dilation

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode

        self.enable_linear = (
            (self.kernel_size == self.stride)
            and not any(self.padding)
            and self.groups == 1
        )
        self.input_size = in_channels * math.prod(self.kernel_size)

        self.weight = nn.Parameter(
            torch.empty(
                out_channels,
                in_channels // groups,
                *kernel_size,
                dtype=params_dtype,
            ),
        )

        if bias:
            self.bias = nn.Parameter(torch.empty(self.out_channels, dtype=params_dtype))
        else:
            self.register_parameter("bias", None)

    def extra_repr(self) -> str:
        s = f"in_channels={self.in_channels}, "
        s += f"out_channels={self.out_channels}, "
        s += f"kernel_size={self.kernel_size}, "
        s += f"stride={self.stride}, "
        s += f"padding={self.padding}, "
        s += f"bias={self.bias is not None}"
        return s


# --8<-- [start:conv2d]
@CustomOp.register("conv2d")
class Conv2dLayer(ConvLayerBase):
    """Conv layer with Conv2d."""

    # --8<-- [end:conv2d]

    num_dim = 2

    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 4
        B, C, H, W = x.shape
        K1, K2 = self.kernel_size
        H, W = H // K1, W // K2
        x = x.unfold(2, K1, K1).unfold(3, K2, K2)
        x = x.permute(0, 2, 3, 1, 4, 5).reshape(-1, self.input_size)
        x = F.linear(
            x,
            self.weight.view(self.out_channels, self.input_size),
            self.bias,
        )
        x = x.view(B, H, W, self.out_channels).permute(0, 3, 1, 2)
        return x

    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 4
        x = F.conv2d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return x

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Expected input shape: (batch_size, in_channels, height, width)"""
        assert x.dim() == 4
        if self.enable_linear:
            return self._forward_mulmat(x)
        else:
            return self._forward_conv(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # By default, we use cuDNN's convolution ops with optimization.
        return self._forward_conv(x)
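
# Illustrative sketch, not part of the original file: when kernel_size equals
# stride and there is no padding (the `enable_linear` case above), convolution
# degenerates into patchify + matmul, which is what _forward_mulmat exploits.
# A hypothetical equivalence check:
def _example_patchify_equivalence() -> bool:
    layer = Conv2dLayer(3, 8, kernel_size=4, stride=4, bias=False,
                        params_dtype=torch.float32)
    nn.init.normal_(layer.weight)
    x = torch.randn(2, 3, 16, 16)
    # Both paths should agree up to floating-point error.
    return torch.allclose(layer._forward_mulmat(x), layer._forward_conv(x),
                          atol=1e-5)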
class CausalConv2dLayer(Conv2dLayer):
|
||||
"""
|
||||
A causal version of nn.Conv2d where each location in the 2D matrix would
|
||||
have no access to locations on its right or down
|
||||
All arguments are the same as nn.Conv2d except padding which should be
|
||||
set as None
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int,
|
||||
padding: int = 0,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = "zeros",
|
||||
*,
|
||||
params_dtype: torch.dtype | None = None,
|
||||
) -> None:
|
||||
if padding is not None:
|
||||
raise ValueError(
|
||||
"Argument padding should be set to None for CausalConv2dLayer."
|
||||
)
|
||||
self._left_padding: int = kernel_size - 1
|
||||
self._right_padding: int = stride - 1
|
||||
padding = 0
|
||||
|
||||
super().__init__(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias,
|
||||
padding_mode,
|
||||
params_dtype=params_dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
x = F.pad(x, pad=(self._left_padding, self._right_padding, 0, 0))
|
||||
x = super().forward(x)
|
||||
return x
|
||||
|
||||
|
||||
# --8<-- [start:conv3d]
@CustomOp.register("conv3d")
class Conv3dLayer(ConvLayerBase):
    """Conv layer with Conv3d."""

    # --8<-- [end:conv3d]

    num_dim = 3

    def _forward_mulmat(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 5
        B, C, T, H, W = x.shape
        K1, K2, K3 = self.kernel_size
        T, H, W = T // K1, H // K2, W // K3
        x = x.unfold(2, K1, K1).unfold(3, K2, K2).unfold(4, K3, K3)
        x = x.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, self.input_size)
        x = F.linear(
            x,
            self.weight.view(self.out_channels, self.input_size),
            self.bias,
        )
        x = x.view(B, T, H, W, self.out_channels).permute(0, 4, 1, 2, 3)
        return x

    def _forward_conv(self, x: torch.Tensor) -> torch.Tensor:
        assert x.dim() == 5
        x = F.conv3d(
            x,
            self.weight,
            self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return x

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        """Expected input shape: (batch_size, in_channels, time, height, width)"""
        if self.enable_linear:
            return self._forward_mulmat(x)
        else:
            return self._forward_conv(x)

    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        # PyTorch 2.9.0 disabled CUDNN's Conv3d, which caused a
        # significant performance regression.
        # See: https://github.com/vllm-project/vllm/issues/27406
        # and https://github.com/pytorch/pytorch/issues/166122
        # By default, we use CUDNN's convolution ops with optimization.
        if self.enable_linear and (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")):
            return self._forward_mulmat(x)
        return self._forward_conv(x)
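
The `_forward_mulmat` path above relies on the stride equaling the kernel size, so every output voxel covers exactly one non-overlapping patch. A small self-contained sketch (made-up shapes) checking that the unfold-plus-matmul path matches `F.conv3d` under that assumption:

import torch
import torch.nn.functional as F

# Hypothetical shapes; any (T, H, W) divisible by the kernel size works.
B, C_in, C_out, K = 2, 3, 8, 2
x = torch.randn(B, C_in, 4, 6, 6)
weight = torch.randn(C_out, C_in, K, K, K)
bias = torch.randn(C_out)

# Reference path: plain conv3d with stride == kernel size (non-overlapping patches).
ref = F.conv3d(x, weight, bias, stride=K)

# Matmul path: unfold into patches, flatten, and apply a single linear map,
# mirroring _forward_mulmat above.
_, _, T, H, W = x.shape
p = x.unfold(2, K, K).unfold(3, K, K).unfold(4, K, K)
p = p.permute(0, 2, 3, 4, 1, 5, 6, 7).reshape(-1, C_in * K * K * K)
out = F.linear(p, weight.view(C_out, -1), bias)
out = out.view(B, T // K, H // K, W // K, C_out).permute(0, 4, 1, 2, 3)

assert torch.allclose(ref, out, atol=1e-4)
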
8
vllm/model_executor/layers/fla/__init__.py
Normal file
@@ -0,0 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
17
vllm/model_executor/layers/fla/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import fused_recurrent_gated_delta_rule
from .layernorm_guard import RMSNormGated

__all__ = [
    "RMSNormGated",
    "chunk_gated_delta_rule",
    "fused_recurrent_gated_delta_rule",
]
219
vllm/model_executor/layers/fla/ops/chunk.py
Normal file
@@ -0,0 +1,219 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings

import torch

from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
from .chunk_o import chunk_fwd_o
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd


def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
    # obtain WY representation. u is actually the new v.
    A = chunk_scaled_dot_kkt_fwd(
        k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )
    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    else:
        return g, o, A, final_state, w, h, v_new


class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    @staticmethod
    @input_guard
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        if use_qk_l2norm_in_kernel:
            q = l2norm_fwd(q)
            k = l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state


@torch.compiler.disable
def chunk_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    cu_seqlens: torch.LongTensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
):
    r"""
    Args:
        q (torch.Tensor):
            Queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            Keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            Values of shape `[B, T, H, V]`.
        g (torch.Tensor):
            (forget) Gating tensor (in log space!) of shape `[B, T, H]`.
        beta (torch.Tensor):
            Betas of shape `[B, T, H]`.
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, V, K]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (Optional[bool]):
            Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, H, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
        >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
        >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
        >>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
        >>> o, ht = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = chunk_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            output_final_state=True,
            cu_seqlens=cu_seqlens
        )
    """
    assert q.dtype == k.dtype == v.dtype
    assert q.dtype != torch.float32, (
        "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
    )
    assert len(beta.shape) == 3, "beta must be of shape [B, T, H]."
    if q.shape[1] < q.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2,
        )
    if cu_seqlens is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
                f"Please flatten variable-length inputs before processing."
            )
        if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
            )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    o, final_state = ChunkGatedDeltaRuleFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        output_final_state,
        cu_seqlens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
344
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
Normal file
@@ -0,0 +1,344 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp
from .utils import use_cuda_graph

NUM_WARPS = [2, 4, 8, 16]


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4]
        for num_stages in [2, 3, 4]
        for BV in [32, 64]
    ],
    key=["H", "K", "V", "BT"],
    use_cuda_graph=use_cuda_graph,
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BV, BK]
    b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([BV, 64], dtype=tl.float32)

    # calculate offset
    h += ((boh * H + i_h) * V * K).to(tl.int64)
    v += ((bos * H + i_h) * V).to(tl.int64)
    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
    w += ((bos * H + i_h) * K).to(tl.int64)
    if SAVE_NEW_VALUE:
        v_new += ((bos * H + i_h) * V).to(tl.int64)
    stride_v = H * V
    stride_h = H * V * K
    stride_k = Hg * K
    stride_w = H * K
    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * V * K
    if STORE_FINAL_STATE:
        ht = ht + i_nh * V * K

    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(
                h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        p_h1 = tl.make_block_ptr(
            h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
        )
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(
                h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        p_w = tl.make_block_ptr(
            w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
        )
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(
                w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
            )
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
        p_v = tl.make_block_ptr(
            v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
        )
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v

        if SAVE_NEW_VALUE:
            p_v = tl.make_block_ptr(
                v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
            )
            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            m_t = (i_t * BT + tl.arange(0, BT)) < T
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(
                g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
            )
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 *= b_g_last
            if K > 64:
                b_h2 *= b_g_last
            if K > 128:
                b_h3 *= b_g_last
            if K > 192:
                b_h4 *= b_g_last

        if USE_GK:
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(
                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
                mask=(o_k1 < K),
                other=0.0,
            )
            b_h1 *= exp(b_gk_last1)[None, :]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
                    mask=(o_k2 < K),
                    other=0.0,
                )
                b_h2 *= exp(b_gk_last2)[None, :]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
                    mask=(o_k3 < K),
                    other=0.0,
                )
                b_h3 *= exp(b_gk_last3)[None, :]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
                    mask=(o_k4 < K),
                    other=0.0,
                )
                b_h4 *= exp(b_gk_last4)[None, :]
        b_v = b_v.to(k.dtype.element_ty)

        p_k = tl.make_block_ptr(
            k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.trans(tl.dot(b_k, b_v))
        if K > 64:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.trans(tl.dot(b_k, b_v))
        if K > 128:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.trans(tl.dot(b_k, b_v))
        if K > 192:
            p_k = tl.make_block_ptr(
                k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
            )
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.trans(tl.dot(b_k, b_v))
    # epilogue
    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(
                ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
            )
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))


def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    # This kernel is slightly different from fla to support Q/K with different head numbers.
    # In fla, Q/K always have the same head number, so Hg is always equal to H.
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    h = k.new_empty(B, NT, H, V, K)
    final_state = (
        k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
    )

    v_new = torch.empty_like(u) if save_new_value else None

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
    )
    return h, v_new, final_state
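
As a reading aid for the kernel above, here is a slow PyTorch sketch of the same chunkwise state recurrence, assuming Hg == H, no varlen batching, T divisible by chunk_size, g already chunk-locally cumsummed, and USE_GK disabled; the function name is ours, not part of this commit:

import torch

def chunk_gated_delta_rule_fwd_h_ref(k, w, u, g, h0, chunk_size=64):
    # k, w: [B, T, H, K]; u: [B, T, H, V]; g: [B, T, H]; h0: [B, H, V, K]
    B, T, H, K = k.shape
    V = u.shape[-1]
    BT = chunk_size
    h = k.new_empty(B, T // BT, H, V, K, dtype=torch.float32)
    v_new = torch.empty_like(u)
    S = h0.float()                                          # running state [B, H, V, K]
    for i_t in range(T // BT):
        s = slice(i_t * BT, (i_t + 1) * BT)
        h[:, i_t] = S                                       # state entering this chunk
        b_k, b_w, b_g = k[:, s].float(), w[:, s].float(), g[:, s].float()
        b_v = u[:, s].float() - torch.einsum('bthk,bhvk->bthv', b_w, S)
        v_new[:, s] = b_v.to(u.dtype)
        g_last = b_g[:, -1]                                 # [B, H], total decay of the chunk
        b_v = b_v * (g_last[:, None] - b_g).exp()[..., None]
        S = S * g_last.exp()[..., None, None]
        S = S + torch.einsum('bthv,bthk->bhvk', b_v, b_k)
    return h, v_new, S
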
183
vllm/model_executor/layers/fla/ops/chunk_o.py
Normal file
@@ -0,0 +1,183 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501


import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper

BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages)
        for BK in BKV_LIST
        for BV in BKV_LIST
        for num_warps in NUM_WARPS
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_fwd_kernel_o(
    q,
    k,
    v,
    h,
    g,
    o,
    cu_seqlens,
    chunk_indices,
    scale,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H

    if IS_VARLEN:
        i_tg = i_t
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
    else:
        NT = tl.cdiv(T, BT)
        i_tg = i_b * NT + i_t
        bos, eos = i_b * T, i_b * T + T

    # offset calculation
    q += (bos * Hg + i_h // (H // Hg)) * K
    k += (bos * Hg + i_h // (H // Hg)) * K
    v += (bos * H + i_h) * V
    o += (bos * H + i_h) * V
    h += (i_tg * H + i_h).to(tl.int64) * V * K

    b_o = tl.zeros([BT, BV], dtype=tl.float32)
    b_A = tl.zeros([BT, BT], dtype=tl.float32)

    for i_k in range(tl.cdiv(K, BK)):
        p_q = tl.make_block_ptr(
            q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)
        )
        p_k = tl.make_block_ptr(
            k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
        )
        p_h = tl.make_block_ptr(
            h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
        )
        # [BT, BK]
        b_q = tl.load(p_q, boundary_check=(0, 1))
        # [BK, BT]
        b_k = tl.load(p_k, boundary_check=(0, 1))
        # [BV, BK]
        b_h = tl.load(p_h, boundary_check=(0, 1))

        # [BT, BK] @ [BK, BV] -> [BT, BV]
        b_o += tl.dot(b_q, tl.trans(b_h))
        # [BT, BK] @ [BK, BT] -> [BT, BT]
        b_A += tl.dot(b_q, b_k)

    if USE_G:
        g += bos * H + i_h
        p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_o = b_o * exp(b_g)[:, None]
        b_A = b_A * exp(b_g[:, None] - b_g[None, :])

    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T
    m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0)

    p_v = tl.make_block_ptr(
        v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    p_o = tl.make_block_ptr(
        o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
    )
    b_v = tl.load(p_v, boundary_check=(0, 1))

    # to fix mma -> mma layout conversion
    # already solved by triton v3.2 or higher
    b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_fwd_o(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    h: torch.Tensor,
    g: torch.Tensor | None = None,  # cumsum of log decay
    scale: float | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
) -> torch.Tensor:
    B, T, Hg, K, V = *q.shape, v.shape[-1]
    H = v.shape[-2]
    BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o = torch.empty_like(v)

    def grid(meta):
        return (triton.cdiv(V, meta["BV"]), NT, B * H)

    chunk_fwd_kernel_o[grid](
        q,
        k,
        v,
        h,
        g,
        o,
        cu_seqlens,
        chunk_indices,
        scale,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
    )
    return o
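
The kernel above splits each chunk's output into an inter-chunk term (queries against the carried state h) and an intra-chunk causal term (the masked q k^T attention matrix against v). A slow reference sketch under simplifying assumptions (Hg == H, no varlen, T divisible by the chunk size); the name `chunk_fwd_o_ref` is ours:

import torch

def chunk_fwd_o_ref(q, k, v, h, g=None, scale=None, chunk_size=64):
    # q, k: [B, T, H, K]; v: [B, T, H, V]; h: [B, NT, H, V, K]
    # g: chunk-local cumsum of log decay, [B, T, H]
    B, T, H, K = q.shape
    BT = chunk_size
    scale = K ** -0.5 if scale is None else scale
    o = torch.empty_like(v)
    for i_t in range(T // BT):
        s = slice(i_t * BT, (i_t + 1) * BT)
        b_q, b_k, b_v = q[:, s].float(), k[:, s].float(), v[:, s].float()
        b_h = h[:, i_t].float()                                  # [B, H, V, K]
        b_o = torch.einsum('bthk,bhvk->bthv', b_q, b_h)          # inter-chunk term
        b_A = torch.einsum('bthk,bshk->bhts', b_q, b_k)          # intra-chunk scores
        if g is not None:
            b_g = g[:, s].float().transpose(1, 2)                # [B, H, BT]
            b_o = b_o * b_g.transpose(1, 2).exp()[..., None]
            b_A = b_A * (b_g[..., None] - b_g[..., None, :]).exp()
        b_A = b_A.tril()                                         # causal within the chunk
        b_o = b_o + torch.einsum('bhts,bshv->bthv', b_A, b_v)
        o[:, s] = (b_o * scale).to(v.dtype)
    return o
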
154
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
Normal file
@@ -0,0 +1,154 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .op import exp


@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages)
        for BK in [32, 64, 128]
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_scaled_dot_kkt_fwd_kernel(
    k,
    beta,
    g,
    A,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    USE_G: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_t = i_t * BT + tl.arange(0, BT)
    m_t = o_t < T

    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))

    b_A = tl.zeros([BT, BT], dtype=tl.float32)
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = b_k * b_beta[:, None]
        b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))

    if USE_G:
        p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        b_g = tl.load(p_g, boundary_check=(0,))
        b_g_diff = b_g[:, None] - b_g[None, :]
        b_A = b_A * exp(b_g_diff)

    m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
    b_A = tl.where(m_A, b_A, 0)
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))


def chunk_scaled_dot_kkt_fwd(
    k: torch.Tensor,
    g: torch.Tensor | None = None,
    beta: torch.Tensor | None = None,
    cu_seqlens: torch.LongTensor | None = None,
    chunk_size: int = 64,
    output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    r"""
    Compute beta * K * K^T.

    Args:
        k (torch.Tensor):
            The key tensor of shape `[B, T, H, K]`.
        beta (torch.Tensor):
            The beta tensor of shape `[B, T, H]`.
        g (torch.Tensor):
            The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
        cu_seqlens (torch.LongTensor):
            The cumulative sequence lengths of the input tensor.
            Default: None
        chunk_size (int):
            The chunk size. Default: 64.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float32`

    Returns:
        beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
    """
    # This kernel is slightly different from fla to support Q/K with different head numbers.
    # In fla, Q/K always have the same head number, so Hg is always equal to H.
    B, T, Hg, K = k.shape
    H = beta.shape[-1]
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)

    A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
    chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
        k=k,
        g=g,
        beta=beta,
        A=A,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        BT=BT,
    )
    return A
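
A slow PyTorch sketch of the same computation, useful for reasoning about the masking and the gate factor; it assumes Hg == H, no varlen, and T divisible by chunk_size, and the name `chunk_scaled_dot_kkt_ref` is ours:

import torch

def chunk_scaled_dot_kkt_ref(k, beta, g=None, chunk_size=64):
    # k: [B, T, H, K]; beta, g: [B, T, H]
    B, T, H, K = k.shape
    BT = chunk_size
    k = k.float().view(B, T // BT, BT, H, K).transpose(2, 3)     # [B, NT, H, BT, K]
    beta = beta.float().view(B, T // BT, BT, H).transpose(2, 3)  # [B, NT, H, BT]
    A = (beta.unsqueeze(-1) * k) @ k.transpose(-1, -2)           # A[i, j] = beta_i <k_i, k_j>
    if g is not None:
        g = g.float().view(B, T // BT, BT, H).transpose(2, 3)
        A = A * (g.unsqueeze(-1) - g.unsqueeze(-2)).exp()        # decay between i and j
    A = A.tril(-1)                                               # strictly lower triangular
    return A.transpose(2, 3).reshape(B, T, H, BT)
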
280
vllm/model_executor/layers/fla/ops/cumsum.py
Normal file
@@ -0,0 +1,280 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import warnings

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .utils import check_shared_mem, input_guard

BS_LIST = [32, 64] if check_shared_mem() else [16, 32]


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]],
    key=["B", "H", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_scalar_kernel(
    s,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    BT: tl.constexpr,
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
        p_o = tl.make_block_ptr(
            o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)
        )
    else:
        p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
        p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    # [BT]
    b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
    b_o = tl.cumsum(b_s, axis=0)
    if REVERSE:
        b_z = tl.sum(b_s, axis=0)
        b_o = -b_o + b_z[None] + b_s
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({"BS": BS}, num_warps=num_warps)
        for BS in BS_LIST
        for num_warps in [2, 4, 8]
    ],
    key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"],
)
@triton.jit(do_not_specialize=["T"])
def chunk_local_cumsum_vector_kernel(
    s,
    o,
    cu_seqlens,
    chunk_indices,
    T,
    B: tl.constexpr,
    H: tl.constexpr,
    S: tl.constexpr,
    BT: tl.constexpr,
    BS: tl.constexpr,
    REVERSE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    HEAD_FIRST: tl.constexpr,
):
    i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, BT)
    if REVERSE:
        m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
    else:
        m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)

    if HEAD_FIRST:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h * T) * S,
            (T, S),
            (S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    else:
        p_s = tl.make_block_ptr(
            s + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
        p_o = tl.make_block_ptr(
            o + (bos * H + i_h) * S,
            (T, S),
            (H * S, 1),
            (i_t * BT, i_s * BS),
            (BT, BS),
            (1, 0),
        )
    # [BT, BS]
    b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
    b_o = tl.dot(m_s, b_s, allow_tf32=False)
    tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


def chunk_local_cumsum_scalar(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    if head_first:
        B, H, T = g.shape
    else:
        B, T, H = g.shape
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
        "chunk_size must be a power of 2"
    )
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
    grid = (NT, B * H)
    chunk_local_cumsum_scalar_kernel[grid](
        g_org,
        g,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g


def chunk_local_cumsum_vector(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
    if head_first:
        B, H, T, S = g.shape
    else:
        B, T, H, S = g.shape
    BT = chunk_size
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, chunk_size)
        if cu_seqlens is not None
        else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
        "chunk_size must be a power of 2"
    )

    g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)

    def grid(meta):
        return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H)

    # keep cumulative normalizer in fp32
    # this kernel is equivalent to
    # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
    chunk_local_cumsum_vector_kernel[grid](
        g_org,
        g,
        cu_seqlens,
        chunk_indices,
        T=T,
        B=B,
        H=H,
        S=S,
        BT=BT,
        HEAD_FIRST=head_first,
        REVERSE=reverse,
    )
    return g


@input_guard
def chunk_local_cumsum(
    g: torch.Tensor,
    chunk_size: int,
    reverse: bool = False,
    cu_seqlens: torch.Tensor | None = None,
    head_first: bool = False,
    output_dtype: torch.dtype | None = torch.float,
    **kwargs,
) -> torch.Tensor:
    if not head_first and g.shape[1] < g.shape[2]:
        warnings.warn(
            f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
            "This may indicate the inputs were passed in head-first format [B, H, T, ...] "
            "when head_first=False was specified. "
            "Please verify your input tensor format matches the expected shape [B, T, H, ...].",
            stacklevel=2,
        )
    if cu_seqlens is not None:
        assert g.shape[0] == 1, (
            "Only batch size 1 is supported when cu_seqlens are provided"
        )
    if len(g.shape) == 3:
        return chunk_local_cumsum_scalar(
            g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
        )
    elif len(g.shape) == 4:
        return chunk_local_cumsum_vector(
            g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
        )
    else:
        raise ValueError(
            f"Unsupported input shape {g.shape}, "
            f"which should be (B, T, H, D) if `head_first=False` "
            f"or (B, H, T, D) otherwise"
        )
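
For orientation, the scalar kernel computes an inclusive prefix sum inside each chunk (or an inclusive suffix sum when REVERSE is set). A tiny reference sketch, assuming head_first=False and T divisible by chunk_size; the name `chunk_local_cumsum_ref` is ours:

import torch

def chunk_local_cumsum_ref(g, chunk_size, reverse=False):
    # g: [B, T, H]
    B, T, H = g.shape
    g = g.float().view(B, T // chunk_size, chunk_size, H)
    # forward: inclusive prefix sum within each chunk;
    # reverse: inclusive suffix sum within each chunk.
    out = g.flip(2).cumsum(2).flip(2) if reverse else g.cumsum(2)
    return out.view(B, T, H)
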
393
vllm/model_executor/layers/fla/ops/fused_recurrent.py
Normal file
@@ -0,0 +1,393 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import torch

from vllm.triton_utils import tl, triton

from .op import exp


@triton.heuristics(
    {
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
        "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
        "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
    }
)
@triton.jit(do_not_specialize=["N", "T"])
def fused_recurrent_gated_delta_rule_fwd_kernel(
    q,
    k,
    v,
    g,
    beta,
    o,
    h0,
    ht,
    cu_seqlens,
    ssm_state_indices,
    num_accepted_tokens,
    scale,
    N: tl.int64,  # num of sequences
    T: tl.int64,  # num of tokens
    B: tl.constexpr,
    H: tl.constexpr,
    HV: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    stride_init_state_token: tl.constexpr,
    stride_final_state_token: tl.constexpr,
    stride_indices_seq: tl.constexpr,
    stride_indices_tok: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,  # whether to use initial state
    INPLACE_FINAL_STATE: tl.constexpr,  # whether to store final state inplace
    IS_BETA_HEADWISE: tl.constexpr,  # whether beta is a headwise vector or a scalar
    USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    IS_CONTINUOUS_BATCHING: tl.constexpr,
    IS_SPEC_DECODING: tl.constexpr,
    IS_KDA: tl.constexpr,
):
    i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_n, i_hv = i_nh // HV, i_nh % HV
    i_h = i_hv // (HV // H)
    if IS_VARLEN:
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int64),
            tl.load(cu_seqlens + i_n + 1).to(tl.int64),
        )
        all = T
        T = eos - bos
    else:
        bos, eos = i_n * T, i_n * T + T
        all = B * T

    if T == 0:
        # no tokens to process for this sequence
        return

    o_k = i_k * BK + tl.arange(0, BK)
    o_v = i_v * BV + tl.arange(0, BV)

    p_q = q + (bos * H + i_h) * K + o_k
    p_k = k + (bos * H + i_h) * K + o_k
    p_v = v + (bos * HV + i_hv) * V + o_v
    if IS_BETA_HEADWISE:
        p_beta = beta + (bos * HV + i_hv) * V + o_v
    else:
        p_beta = beta + bos * HV + i_hv

    if not IS_KDA:
        p_g = g + bos * HV + i_hv
    else:
        p_gk = g + (bos * HV + i_hv) * K + o_k

    p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v

    mask_k = o_k < K
    mask_v = o_v < V
    mask_h = mask_v[:, None] & mask_k[None, :]

    b_h = tl.zeros([BV, BK], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if IS_CONTINUOUS_BATCHING:
            if IS_SPEC_DECODING:
                i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
            else:
                i_t = 0
            # Load state index and check for PAD_SLOT_ID (-1)
            state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(
                tl.int64
            )
            # Skip if state index is invalid (PAD_SLOT_ID = -1)
            if state_idx < 0:
                return
            p_h0 = h0 + state_idx * stride_init_state_token
        else:
            p_h0 = h0 + bos * HV * V * K
        p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
        b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)

    for i_t in range(0, T):
        b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
        b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
        b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)

        if USE_QK_L2NORM_IN_KERNEL:
            b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
            b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
        b_q = b_q * scale
        # [BV, BK]
        if not IS_KDA:
            b_g = tl.load(p_g).to(tl.float32)
            b_h *= exp(b_g)
        else:
            b_gk = tl.load(p_gk).to(tl.float32)
            b_h *= exp(b_gk[None, :])
        # [BV]
        b_v -= tl.sum(b_h * b_k[None, :], 1)
        if IS_BETA_HEADWISE:
            b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
        else:
            b_beta = tl.load(p_beta).to(tl.float32)
        b_v *= b_beta
        # [BV, BK]
        b_h += b_v[:, None] * b_k[None, :]
        # [BV]
        b_o = tl.sum(b_h * b_q[None, :], 1)
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)

        # keep the states for multi-query tokens
        if INPLACE_FINAL_STATE:
            # Load state index and check for PAD_SLOT_ID (-1)
            final_state_idx = tl.load(
                ssm_state_indices + i_n * stride_indices_seq + i_t
            ).to(tl.int64)
            # Only store if state index is valid (not PAD_SLOT_ID)
            if final_state_idx >= 0:
                p_ht = ht + final_state_idx * stride_final_state_token
                p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
                tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
        else:
            p_ht = ht + (bos + i_t) * stride_final_state_token
            p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
            tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)

        p_q += H * K
        p_k += H * K
        p_o += HV * V
        p_v += HV * V
        if not IS_KDA:
            p_g += HV
        else:
            p_gk += HV * K
        p_beta += HV * (V if IS_BETA_HEADWISE else 1)


def fused_recurrent_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, H, K, V = *k.shape, v.shape[-1]
    HV = v.shape[2]
    N = B if cu_seqlens is None else len(cu_seqlens) - 1
    BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
    NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
    assert NK == 1, "NK > 1 is not supported yet"
    num_stages = 3
    num_warps = 1

    o = q.new_empty(NK, *v.shape)
    if inplace_final_state:
        final_state = initial_state
    else:
        final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)

    stride_init_state_token = initial_state.stride(0)
    stride_final_state_token = final_state.stride(0)

    if ssm_state_indices is None:
        stride_indices_seq, stride_indices_tok = 1, 1
    elif ssm_state_indices.ndim == 1:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
    else:
        stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()

    grid = (NK, NV, N * HV)
    fused_recurrent_gated_delta_rule_fwd_kernel[grid](
        q=q,
        k=k,
        v=v,
        g=g,
        beta=beta,
        o=o,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        ssm_state_indices=ssm_state_indices,
        num_accepted_tokens=num_accepted_tokens,
        scale=scale,
        N=N,
        T=T,
        B=B,
        H=H,
        HV=HV,
        K=K,
        V=V,
        BK=BK,
        BV=BV,
        stride_init_state_token=stride_init_state_token,
        stride_final_state_token=stride_final_state_token,
        stride_indices_seq=stride_indices_seq,
        stride_indices_tok=stride_indices_tok,
        IS_BETA_HEADWISE=beta.ndim == v.ndim,
        USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
        INPLACE_FINAL_STATE=inplace_final_state,
        IS_KDA=False,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    o = o.squeeze(0)
    return o, final_state


class FusedRecurrentFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        inplace_final_state: bool = True,
        cu_seqlens: torch.LongTensor | None = None,
        ssm_state_indices: torch.Tensor | None = None,
        num_accepted_tokens: torch.Tensor | None = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        o, final_state = fused_recurrent_gated_delta_rule_fwd(
            q=q.contiguous(),
            k=k.contiguous(),
            v=v.contiguous(),
            g=g.contiguous(),
            beta=beta.contiguous(),
            scale=scale,
            initial_state=initial_state,
            inplace_final_state=inplace_final_state,
            cu_seqlens=cu_seqlens,
            ssm_state_indices=ssm_state_indices,
            num_accepted_tokens=num_accepted_tokens,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

        return o, final_state


def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor | None = None,
    scale: float | None = None,
    initial_state: torch.Tensor | None = None,
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    ssm_state_indices: torch.Tensor | None = None,
    num_accepted_tokens: torch.Tensor | None = None,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    r"""
    Args:
        q (torch.Tensor):
            queries of shape `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, T, HV, V]`.
            GVA is applied if `HV > H`.
        g (torch.Tensor):
            g (decays) of shape `[B, T, HV]`.
        beta (torch.Tensor):
            betas of shape `[B, T, HV]`.
        scale (Optional[float]):
            Scale factor for the RetNet attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, HV, V, K]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        inplace_final_state (bool):
            Whether to store the final state in-place to save memory.
            Default: `True`.
        cu_seqlens (torch.LongTensor):
            Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
            consistent with the FlashAttention API.
        ssm_state_indices (Optional[torch.Tensor]):
            Indices to map the input sequences to the initial/final states.
        num_accepted_tokens (Optional[torch.Tensor]):
            Number of accepted tokens for each sequence during decoding.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, T, HV, V]`.
        final_state (torch.Tensor):
            Final state of shape `[N, HV, V, K]`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
        # inputs with equal lengths
        >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
        >>> v = torch.randn(B, T, HV, V, device='cuda')
        >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
        >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
        >>> h0 = torch.randn(B, HV, V, K, device='cuda')
        >>> o, ht = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
        )
        # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
        >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
        # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
        >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_gated_delta_rule(
            q, k, v, g, beta,
            initial_state=h0,
            cu_seqlens=cu_seqlens
        )
    """
    if cu_seqlens is not None and q.shape[0] != 1:
        raise ValueError(
            f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`. "
            f"Please flatten variable-length inputs before processing."
        )
    if scale is None:
        scale = k.shape[-1] ** -0.5
    else:
        assert scale > 0, "scale must be positive"
    if beta is None:
        beta = torch.ones_like(q[..., 0])
    o, final_state = FusedRecurrentFunction.apply(
        q,
        k,
        v,
        g,
        beta,
        scale,
        initial_state,
        inplace_final_state,
        cu_seqlens,
        ssm_state_indices,
        num_accepted_tokens,
        use_qk_l2norm_in_kernel,
    )
    return o, final_state
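
The kernel's per-token update is: decay the state by exp(g_t), compute the delta-rule error v_t - S k_t, write it back scaled by beta_t, then read the output with the (scaled) query. A slow token-by-token PyTorch sketch of that recurrence for the equal-length, scalar-beta case; the name `gated_delta_rule_ref` is ours and this is for validation only:

import torch

def gated_delta_rule_ref(q, k, v, g, beta, scale, h0):
    # q, k: [B, T, H, K]; v: [B, T, HV, V]; g, beta: [B, T, HV]; h0: [B, HV, V, K]
    B, T, H, K = q.shape
    HV = v.shape[2]
    rep = HV // H  # GVA: each q/k head serves HV // H value heads
    q = q.repeat_interleave(rep, dim=2).float() * scale
    k = k.repeat_interleave(rep, dim=2).float()
    S = h0.clone().float()                                   # state: [B, HV, V, K]
    o = torch.zeros_like(v, dtype=torch.float32)
    for t in range(T):
        S = S * g[:, t].float().exp()[..., None, None]       # gated decay
        err = v[:, t].float() - torch.einsum('bhvk,bhk->bhv', S, k[:, t])
        S = S + torch.einsum('bhv,bhk->bhvk', beta[:, t].float()[..., None] * err
                             if False else beta[:, t].float().unsqueeze(-1) * err,
                             k[:, t])
        o[:, t] = torch.einsum('bhvk,bhk->bhv', S, q[:, t])
    return o, S
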
41
vllm/model_executor/layers/fla/ops/index.py
Normal file
@@ -0,0 +1,41 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import torch

from vllm.triton_utils import triton

from .utils import tensor_cache


@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    return cu_seqlens[1:] - cu_seqlens[:-1]


@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    indices = torch.cat(
        [
            torch.arange(n)
            for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
        ]
    )
    return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)


@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    return torch.cat(
        [cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]
    ).cumsum(-1)
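
A small worked example of these helpers (values computed by hand from the definitions above):

import torch

# cu_seqlens for two sequences of lengths 100 and 156, chunk_size = 64.
cu_seqlens = torch.tensor([0, 100, 256])
# prepare_lens(cu_seqlens)      -> tensor([100, 156])
# prepare_chunk_indices(...)    -> one (sequence idx, chunk idx within sequence)
#   row per chunk: tensor([[0, 0], [0, 1], [1, 0], [1, 1], [1, 2]])
# prepare_chunk_offsets(...)    -> first global chunk id of each sequence:
#   tensor([0, 2, 5])
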
1351
vllm/model_executor/layers/fla/ops/kda.py
Normal file
File diff suppressed because it is too large
146
vllm/model_executor/layers/fla/ops/l2norm.py
Normal file
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import os

import torch

from vllm.triton_utils import tl, triton

BT_LIST = [8, 16, 32, 64, 128]

USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]
    ],
    key=["D"],
)
@triton.jit
def l2norm_fwd_kernel1(
    x,
    y,
    D,
    BD: tl.constexpr,
    eps,
):
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    # Compute the sum of squares for this row
    cols = tl.arange(0, BD)
    mask = cols < D
    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=0)
    b_rstd = 1 / tl.sqrt(b_var + eps)
    # tl.store(Rstd + i_t, rstd)
    # Normalize the row
    b_y = b_x * b_rstd
    tl.store(y + cols, b_y, mask=mask)


@triton.autotune(
    configs=[
        triton.Config({"BT": BT}, num_warps=num_warps)
        for num_warps in [1, 2, 4, 8, 16]
        for BT in BT_LIST
    ],
    key=["D"],
)
@triton.jit(do_not_specialize=["NB"])
def l2norm_fwd_kernel(
    x,
    y,
    eps,
    NB,
    T,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
):
    i_t = tl.program_id(0)
    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
    b_var = tl.sum(b_x * b_x, axis=1)
    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))


@triton.jit
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
    xoffset = tl.program_id(0) * MBLOCK
    row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
    xmask = row_idx < M
    rindex = tl.arange(0, N)[None, :]
    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
    square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
    rsqrt = tl.rsqrt(square_sum + eps)
    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)


def l2norm_fwd(
    x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None
):
    x_shape_og = x.shape
    x = x.view(-1, x.shape[-1])
    # allocate output
    if output_dtype is None:
        y = torch.empty_like(x)
    else:
        y = torch.empty_like(x, dtype=output_dtype)
    assert y.stride(-1) == 1
    T, D = x.shape[0], x.shape[-1]
    # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
    if D > BD:
        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")

    if not USE_DEFAULT_FLA_NORM:
        MBLOCK = 32
        # M, N = x.shape
        l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)](
            x,
            y,
            eps,
            T,
            D,
            MBLOCK,
        )
    else:
        if D <= 512:
            NB = triton.cdiv(T, 2048)

            def grid(meta):
                return (triton.cdiv(T, meta["BT"]),)

            l2norm_fwd_kernel[grid](
                x,
                y,
                eps,
                NB=NB,
                T=T,
                D=D,
                BD=BD,
            )
        else:
            l2norm_fwd_kernel1[(T,)](
                x,
                y,
                eps=eps,
                D=D,
                BD=BD,
            )

    return y.view(x_shape_og)
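
# A minimal sanity-check sketch (illustrative; assumes a CUDA device, with
# loose tolerances to absorb fp16 rounding). Note the kernels normalize by
# sqrt(sum of squares + eps), not the mean:
#
#     x = torch.randn(32, 128, device="cuda", dtype=torch.float16)
#     y = l2norm_fwd(x, eps=1e-6)
#     ref = x.float() / torch.sqrt((x.float() ** 2).sum(-1, keepdim=True) + 1e-6)
#     torch.testing.assert_close(y.float(), ref, rtol=1e-2, atol=1e-2)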
388
vllm/model_executor/layers/fla/ops/layernorm_guard.py
Normal file
@@ -0,0 +1,388 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Tri Dao
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2024, Tri Dao.

# ruff: noqa: E501
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv, next_power_of_2
from vllm.utils.platform_utils import num_compute_units

from .utils import input_guard


def rms_norm_ref(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    upcast=True,
):
    dtype = x.dtype
    weight = weight.float()
    bias = bias.float() if bias is not None else None
    if upcast:
        x = x.float()
        z = z.float() if z is not None else z
    if z is not None and not norm_before_gate:
        x = x * F.silu(z)
    if group_size is None:
        rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
        out = (x * rstd * weight) + bias if bias is not None else (x * rstd * weight)
    else:
        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
        rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) + eps)
        out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
        if bias is not None:
            out = out + bias
    if z is not None and norm_before_gate:
        out *= F.silu(z)
    return out.to(dtype)


@triton.heuristics(
    {
        "HAS_BIAS": lambda args: args["B"] is not None,
        "HAS_Z": lambda args: args["Z"] is not None,
    }
)
@triton.jit
def layer_norm_fwd_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N: tl.constexpr,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    ROWS_PER_BLOCK: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Map the program id to the starting row of X and Y it should compute.
    row_start = tl.program_id(0) * ROWS_PER_BLOCK
    group = tl.program_id(1)

    # Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N]
    rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
    cols = tl.arange(0, BLOCK_N)

    # Compute offsets for 2D tile
    row_offsets = rows[:, None] * stride_x_row
    col_offsets = cols[None, :] + group * N

    # Base pointers
    X_base = X + row_offsets + col_offsets
    Y_base = Y + rows[:, None] * stride_y_row + col_offsets

    # Create mask for valid rows and columns
    row_mask = rows[:, None] < M
    col_mask = cols[None, :] < N
    mask = row_mask & col_mask

    # Load input data with 2D tile
    x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)

    if HAS_Z and not NORM_BEFORE_GATE:
        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
        x *= z * tl.sigmoid(z)

    # Compute mean and variance per row (reduce along axis 1)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
        # Store mean for each row
        mean_offsets = group * M + rows
        mean_mask = rows < M
        tl.store(Mean + mean_offsets, mean, mask=mean_mask)
        # Broadcast mean back to 2D for subtraction
        xbar = tl.where(mask, x - mean[:, None], 0.0)
        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
    else:
        xbar = tl.where(mask, x, 0.0)
        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
        mean = 0.0  # Placeholder for RMS norm

    rstd = tl.rsqrt(var + eps)  # Shape: [ROWS_PER_BLOCK]

    # Store rstd for each row
    rstd_offsets = group * M + rows
    rstd_mask = rows < M
    tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask)

    # Load weights and biases (broadcast across rows)
    w_offsets = cols + group * N
    w_mask = cols < N
    w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32)

    if HAS_BIAS:
        b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32)

    # Normalize and apply linear transformation
    if not IS_RMS_NORM:
        x_hat = (x - mean[:, None]) * rstd[:, None]
    else:
        x_hat = x * rstd[:, None]

    y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :]

    if HAS_Z and NORM_BEFORE_GATE:
        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
        y *= z * tl.sigmoid(z)

    # Write output
    tl.store(Y_base, y, mask=mask)


def calc_rows_per_block(M: int, device: torch.device) -> int:
    sm_count = num_compute_units(device.index)
    rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
    rows_per_block = min(rows_per_block, 4)
    return rows_per_block
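
# Worked example of the heuristic above (hypothetical 132-SM device, e.g. an
# H100): with M = 4096 rows, cdiv(4096, 2 * 132) = 16 -> next_power_of_2 = 16
# -> capped to 4 rows per program. Small M degrades to 1 row per program so
# all SMs stay busy.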
def layer_norm_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    z: torch.Tensor = None,
    out: torch.Tensor = None,
    group_size: int = None,
    norm_before_gate: bool = True,
    is_rms_norm: bool = False,
):
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    mean = (
        torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
        if not is_rms_norm
        else None
    )
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
    if group_size > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    # heuristics for number of warps
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    # Calculate rows per block based on SM count
    rows_per_block = calc_rows_per_block(M, x.device)
    # Update grid to use rows_per_block
    grid = (cdiv(M, rows_per_block), ngroups)
    layer_norm_fwd_kernel[grid](
        x,
        out,
        weight,
        bias,
        z,
        mean,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0) if z is not None else 0,
        M,
        group_size,
        eps,
        BLOCK_N=BLOCK_N,
        ROWS_PER_BLOCK=rows_per_block,
        NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
        num_warps=num_warps,
    )
    return out, mean, rstd


class LayerNormFn(torch.autograd.Function):
    @staticmethod
    @input_guard
    def forward(
        ctx,
        x,
        weight,
        bias,
        z=None,
        eps=1e-6,
        group_size=None,
        norm_before_gate=True,
        is_rms_norm=False,
    ):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""

        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if x.stride(-1) != 1:
            x = x.contiguous()
        if z is not None:
            assert z.shape == x_shape_og
            z = z.reshape(-1, z.shape[-1])
            if z.stride(-1) != 1:
                z = z.contiguous()
        weight = weight.contiguous()
        if bias is not None:
            bias = bias.contiguous()
        y, mean, rstd = layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            z=z,
            group_size=group_size,
            norm_before_gate=norm_before_gate,
            is_rms_norm=is_rms_norm,
        )
        ctx.save_for_backward(x, weight, bias, mean, rstd, z)
        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.group_size = group_size
        ctx.norm_before_gate = norm_before_gate
        ctx.is_rms_norm = is_rms_norm
        return y.reshape(x_shape_og)


def layernorm_fn(
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
    )


def rmsnorm_fn(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, True
    )


class LayerNormGated(nn.Module):
    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = True,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """

        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)
        torch.nn.init.zeros_(self.bias)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return layernorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            group_size=self.group_size,
            eps=self.eps,
            norm_before_gate=self.norm_before_gate,
        )


class RMSNormGated(nn.Module):
    def __init__(
        self,
        hidden_size,
        eps: float = 1e-5,
        group_size: int | None = None,
        norm_before_gate: bool = False,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """If group_size is not None, we do GroupNorm with each group having group_size elements.
        group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
        self.register_parameter("bias", None)
        self.group_size = group_size
        self.norm_before_gate = norm_before_gate
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.ones_(self.weight)

    def forward(self, x, z=None):
        """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
        return rmsnorm_fn(
            x,
            self.weight,
            self.bias,
            z=z,
            eps=self.eps,
            group_size=self.group_size,
            norm_before_gate=self.norm_before_gate,
        )
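
# A minimal usage sketch (illustrative; assumes a CUDA device), checking the
# Triton path against rms_norm_ref above:
#
#     norm = RMSNormGated(
#         64, eps=1e-5, norm_before_gate=False, device="cuda", dtype=torch.float16
#     )
#     x = torch.randn(8, 64, device="cuda", dtype=torch.float16)
#     z = torch.randn_like(x)
#     out = norm(x, z)  # norm_before_gate=False -> rmsnorm(x * silu(z)) * weight
#     ref = rms_norm_ref(x, norm.weight, None, z=z, eps=1e-5, norm_before_gate=False)
#     torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)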
60
vllm/model_executor/layers/fla/ops/op.py
Normal file
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

import os

from vllm.triton_utils import tl, tldevice, triton

from .utils import is_gather_supported

if os.environ.get("FLA_USE_FAST_OPS", "0") == "1":
    exp = tldevice.fast_expf
    log = tldevice.fast_logf
    log2 = tldevice.fast_log2f
else:
    exp = tl.exp
    log = tl.log
    log2 = tl.log2


if not is_gather_supported:

    @triton.jit
    def gather(src, index, axis, _builder=None):
        """
        Gather operation that works when tl.gather is not supported.
        This is a fallback implementation that returns None.
        Just to make the triton compiler happy.
        """
        return None
else:
    gather = tl.gather

if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
    # For Triton 3.3.x
    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
elif hasattr(triton.language, "make_tensor_descriptor"):
    # For Triton 3.4.x and later
    make_tensor_descriptor = triton.language.make_tensor_descriptor
else:
    """
    Fallback implementation when TMA is not supported.
    Returns None to indicate TMA descriptors are unavailable.
    Just to make the triton compiler happy.
    """

    @triton.jit
    def make_tensor_descriptor(
        base,
        shape,
        strides,
        block_shape,
        _builder=None,
    ):
        return None
556
vllm/model_executor/layers/fla/ops/solve_tril.py
Normal file
@@ -0,0 +1,556 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501

import os

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices
from .op import make_tensor_descriptor
from .utils import input_guard, is_amd, is_tma_supported

FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"]
assert FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS, (
    f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}"
)


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["BT"],
)
@triton.jit(do_not_specialize=["T"])
def solve_tril_16x16_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]

    A = A + (bos * H + i_h) * BT
    Ai = Ai + (bos * H + i_h) * 16

    offset = (i_t * 16) % BT
    if not USE_TMA:
        p_A = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)
        )
        # [16, 16]
        b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
        b_A = desc.load([i_t * 16, offset]).to(tl.float32)
    b_A = -tl.where(m_A, b_A, 0)

    for i in range(2, min(16, T - i_t * 16)):
        # [16]
        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
        b_A = tl.where((o_i == i)[:, None], b_a, b_A)
    b_A += m_I
    if not USE_TMA:
        p_Ai = tl.make_block_ptr(
            Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)
        )
        tl.store(
            p_Ai,
            b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [1, 2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_32x32_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_A_22 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)

    # [16, 16]
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)

    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)

    b_Ai_11 += m_I
    b_Ai_22 += m_I

    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)

    b_Ai_21 = -tl.dot(
        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
        b_Ai_11,
        input_precision=DOT_PRECISION,
    )

    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_Ai_21 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_Ai_22 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        tl.store(
            p_Ai_11,
            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_22,
            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_21,
            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o.store(
            [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4, 5]
    ],
    key=["H", "BT", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def merge_16x16_to_64x64_inverse_kernel(
    A,
    Ai,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    BT: tl.constexpr,
    USE_TMA: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    DOT_PRECISION: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T

    o_i = tl.arange(0, 16)
    m_A = o_i[:, None] > o_i[None, :]
    m_I = o_i[:, None] == o_i[None, :]
    A += (bos * H + i_h) * BT
    Ai += (bos * H + i_h) * BT

    if not USE_TMA:
        p_A_11 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_A_22 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_A_33 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_A_44 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
        b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
    else:
        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
        b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
        b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)

    # [16, 16]
    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
    b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
    b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)

    for i in range(2, min(16, T - i_t * BT)):
        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
    for i in range(16 + 2, min(32, T - i_t * BT)):
        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
    for i in range(32 + 2, min(48, T - i_t * BT)):
        b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
        b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
        b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
    for i in range(48 + 2, min(64, T - i_t * BT)):
        b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
        b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
        b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
    b_Ai_11 += m_I
    b_Ai_22 += m_I
    b_Ai_33 += m_I
    b_Ai_44 += m_I

    if not USE_TMA:
        p_A_21 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_A_31 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_A_32 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_A_41 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_A_42 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_A_43 = tl.make_block_ptr(
            A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
        b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
        b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
        b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
        b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
        b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
    else:
        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
        b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
        b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
        b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
        b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
        b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)

    b_Ai_21 = -tl.dot(
        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
        b_Ai_11,
        input_precision=DOT_PRECISION,
    )
    b_Ai_32 = -tl.dot(
        tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
        b_Ai_22,
        input_precision=DOT_PRECISION,
    )
    b_Ai_43 = -tl.dot(
        tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
        b_Ai_33,
        input_precision=DOT_PRECISION,
    )

    b_Ai_31 = -tl.dot(
        b_Ai_33,
        tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_42 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )
    b_Ai_41 = -tl.dot(
        b_Ai_44,
        tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
        + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
        + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
        input_precision=DOT_PRECISION,
    )

    if not USE_TMA:
        p_Ai_11 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)
        )
        p_Ai_22 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)
        )
        p_Ai_33 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)
        )
        p_Ai_44 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)
        )
        p_Ai_21 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)
        )
        p_Ai_31 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)
        )
        p_Ai_32 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)
        )
        p_Ai_41 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)
        )
        p_Ai_42 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)
        )
        p_Ai_43 = tl.make_block_ptr(
            Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)
        )
        tl.store(
            p_Ai_11,
            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_22,
            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_33,
            b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_44,
            b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_21,
            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_31,
            b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_32,
            b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_41,
            b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_42,
            b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
        tl.store(
            p_Ai_43,
            b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
            boundary_check=(0, 1),
        )
    else:
        desc_o.store(
            [i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )
        desc_o.store(
            [i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")
        )


@input_guard
def solve_tril(
    A: torch.Tensor,
    cu_seqlens: torch.Tensor | None = None,
    output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
    """
    Compute the inverse of the matrix I + A.
    A should be strictly lower triangular, i.e., A.triu() == 0.

    Args:
        A (torch.Tensor):
            [B, T, H, BT], where BT must be 16, 32, or 64.
        cu_seqlens (torch.Tensor):
            The cumulative sequence lengths of the input tensor. Default: `None`.
        output_dtype (torch.dtype):
            The dtype of the output tensor. Default: `torch.float`.
            If `None`, the output dtype will be the same as the input dtype.

    Returns:
        (I + A)^-1 with the same shape as A.
    """
    assert A.shape[-1] in [16, 32, 64]
    output_dtype = A.dtype if output_dtype is None else output_dtype

    B, T, H, BT = A.shape
    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)

    Ai = torch.zeros_like(A, dtype=output_dtype)
    if BT == 16:
        merge_fn = solve_tril_16x16_kernel
    elif BT == 32:
        merge_fn = merge_16x16_to_32x32_inverse_kernel
    elif BT == 64:
        merge_fn = merge_16x16_to_64x64_inverse_kernel

    merge_fn[NT, B * H](
        A=A,
        Ai=Ai,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        BT=BT,
        USE_TMA=is_tma_supported,
        DOT_PRECISION=FLA_TRIL_PRECISION,
    )
    return Ai
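
# A minimal correctness sketch (illustrative; assumes a CUDA device). Each
# 16-row chunk of A holds one strictly lower-triangular block, so the result
# should match a dense inverse of (I + block):
#
#     B, T, H, BT = 1, 32, 1, 16
#     blocks = torch.randn(T // BT, BT, BT, device="cuda").tril(-1)
#     A = blocks.reshape(B, T, H, BT)
#     Ai = solve_tril(A)
#     for c in range(T // BT):
#         ref = torch.linalg.inv(torch.eye(BT, device="cuda") + blocks[c])
#         torch.testing.assert_close(
#             Ai[0, c * BT : (c + 1) * BT, 0], ref, rtol=1e-3, atol=1e-3
#         )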
194
vllm/model_executor/layers/fla/ops/utils.py
Normal file
@@ -0,0 +1,194 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# ruff: noqa: E501
import contextlib
import functools
import logging
import os
from collections.abc import Callable
from enum import Enum
from typing import Any, Literal

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import triton

logger = logging.getLogger(__name__)

COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"

SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))


def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator stores the outputs of the decorated function for the most recent sets of input tensors.
    The cache is limited to a fixed size (default is 8). When the cache is full, the oldest entry is evicted.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with LRU-style caching.
    """

    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 8

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            if (
                len(args) == len(last_args)
                and len(kwargs) == len(last_kwargs)
                and all(a is b for a, b in zip(args, last_args))
                and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                )
            ):
                # Move the hit entry to the back so eviction stays least-recently-used.
                cache_entries = (
                    cache_entries[:i]
                    + cache_entries[i + 1 :]
                    + [(args, kwargs, last_result)]
                )
                return last_result

        result = fn(*args, **kwargs)

        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
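
# Identity-based caching example: the decorator keys on `is`, not value
# equality, so a hit requires passing the *same* tensor object again
# (using prepare_lens from index.py as an illustration):
#
#     lens_a = prepare_lens(cu_seqlens)          # computed and cached
#     lens_b = prepare_lens(cu_seqlens)          # cache hit: lens_b is lens_a
#     lens_c = prepare_lens(cu_seqlens.clone())  # new object: recomputed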
def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        contiguous_args = (
            i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args
        )
        contiguous_kwargs = {
            k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
            for k, v in kwargs.items()
        }

        tensor = None
        for arg in args:
            if isinstance(arg, torch.Tensor):
                tensor = arg
                break
        if tensor is None:
            for value in kwargs.values():
                if isinstance(value, torch.Tensor):
                    tensor = value
                    break

        if tensor is not None:
            ctx = torch.cuda.device(tensor.device.index)
        else:
            ctx = contextlib.nullcontext()

        with ctx:
            return fn(*contiguous_args, **contiguous_kwargs)

    return wrapper


@functools.cache
def get_available_device() -> str:
    try:
        return triton.runtime.driver.active.get_current_target().backend
    except (RuntimeError, AttributeError):
        return "cpu"


@functools.cache
def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
    device = get_available_device()
    mapping = {
        "cuda": "nvidia",
        "hip": "amd",
        "xpu": "intel",
    }
    # return the mapped value, or the original if not found
    return mapping.get(device, device)


# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
device = "cuda" if current_platform.is_cuda_alike() else get_available_device()
device_torch_lib = getattr(torch, device, None)
device_platform = _check_platform()

is_amd = device_platform == "amd"
is_intel = device_platform == "intel"
is_nvidia = device_platform == "nvidia"
is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
is_nvidia_hopper = is_nvidia and (
    "NVIDIA H" in torch.cuda.get_device_name(0)
    or torch.cuda.get_device_capability()[0] >= 9
)
use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
is_gather_supported = hasattr(triton.language, "gather")
is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
    hasattr(triton.language, "_experimental_make_tensor_descriptor")
    or hasattr(triton.language, "make_tensor_descriptor")
)


def get_all_max_shared_mem():
    try:
        return [
            triton.runtime.driver.active.utils.get_device_properties(i)[
                "max_shared_mem"
            ]
            for i in range(device_torch_lib.device_count())
        ]
    except BaseException:
        return [-1]


class Backend(Enum):
    ADA = 101376  # RTX 4090
    AMPERE = 166912  # A100
    HOPPER = 232448  # H100
    DEFAULT = 102400  # Default

    @classmethod
    def get_shared_memory(cls, arch: str) -> int:
        try:
            return cls[arch.upper()].value
        except KeyError:
            return cls.DEFAULT.value


@functools.cache
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
    try:
        device_shared_mem_list = get_all_max_shared_mem()
        max_shared_memory = device_shared_mem_list[tensor_idx]
        return max_shared_memory >= Backend.get_shared_memory(arch)
    except Exception:
        return False
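
# Usage sketch: gate a tile-size choice on available shared memory; the arch
# names map to the Backend thresholds above (illustrative only):
#
#     BT = 128 if check_shared_mem("hopper") else 64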
158
vllm/model_executor/layers/fla/ops/wy_fast.py
Normal file
@@ -0,0 +1,158 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
#
# This file contains code copied from the flash-linear-attention project.
# The original source code was licensed under the MIT license and included
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang

# ruff: noqa: E501

import torch

from vllm.triton_utils import tl, triton

from .index import prepare_chunk_indices


@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps, num_stages=num_stages)
        for num_warps in [2, 4, 8]
        for num_stages in [2, 3, 4]
    ],
    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
)
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,
    v,
    beta,
    w,
    u,
    A,
    g,
    cu_seqlens,
    chunk_indices,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BK: tl.constexpr,
    BV: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(
        beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)
    )
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(
        A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)
    )
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))


def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: torch.LongTensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    chunk_indices = (
        prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    )
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
140
vllm/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,140 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from contextlib import contextmanager
from typing import Any

from vllm.model_executor.layers.fused_moe.activation import (
    MoEActivation,
    activation_without_mul,
    apply_moe_activation,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    RoutingMethodType,
)
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
    FusedMoEMethodBase,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEActivationFormat,
    FusedMoEPermuteExpertsUnpermute,
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.router.fused_moe_router import (
    FusedMoERouter,
)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.fused_moe.zero_expert_fused_moe import (
    ZeroExpertFusedMoE,
)
from vllm.triton_utils import HAS_TRITON

_config: dict[str, Any] | None = None


@contextmanager
def override_config(config):
    global _config
    old_config = _config
    _config = config
    yield
    _config = old_config


def get_config() -> dict[str, Any] | None:
    return _config
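
# Usage sketch for the override above (illustrative; assumes no other override
# is active and uses a hypothetical config key):
#
#     with override_config({"BLOCK_SIZE_M": 64}):
#         assert get_config() == {"BLOCK_SIZE_M": 64}
#     assert get_config() is None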
__all__ = [
    "FusedMoE",
    "FusedMoERouter",
    "FusedMoEConfig",
    "FusedMoEMethodBase",
    "MoEActivation",
    "UnquantizedFusedMoEMethod",
    "FusedMoeWeightScaleSupported",
    "FusedMoEPermuteExpertsUnpermute",
    "FusedMoEActivationFormat",
    "FusedMoEPrepareAndFinalize",
    "RoutingMethodType",
    "SharedFusedMoE",
    "ZeroExpertFusedMoE",
    "activation_without_mul",
    "apply_moe_activation",
    "override_config",
    "get_config",
]

if HAS_TRITON:
    # import to register the custom ops
    from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
        BatchedDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.cutlass_moe import (
        CutlassBatchedExpertsFp8,
        CutlassExpertsFp8,
        CutlassExpertsW4A8Fp8,
        cutlass_moe_w4a8_fp8,
    )
    from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
        BatchedTritonExperts,
    )
    from vllm.model_executor.layers.fused_moe.fused_moe import (
        TritonExperts,
        TritonWNA16Experts,
        fused_experts,
        get_config_file_name,
    )
    from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
        AiterExperts,
    )
    from vllm.model_executor.layers.fused_moe.router.fused_topk_router import (
        fused_topk,
    )
    from vllm.model_executor.layers.fused_moe.router.grouped_topk_router import (
        GroupedTopk,
    )
    from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
        TritonOrDeepGemmExperts,
    )
    from vllm.model_executor.layers.fused_moe.xpu_fused_moe import (
        XPUExperts,
        XPUExpertsFp8,
    )

    __all__ += [
        "AiterExperts",
        "fused_topk",
        "fused_experts",
        "get_config_file_name",
        "GroupedTopk",
        "cutlass_moe_w4a8_fp8",
        "CutlassExpertsFp8",
        "CutlassBatchedExpertsFp8",
        "CutlassExpertsW4A8Fp8",
        "TritonExperts",
        "TritonWNA16Experts",
        "BatchedTritonExperts",
        "DeepGemmExperts",
        "BatchedDeepGemmExperts",
        "TritonOrDeepGemmExperts",
        "XPUExperts",
        "XPUExpertsFp8",
    ]
else:
    # Some model classes directly use the custom ops. Add placeholders
    # to avoid import errors.
    def _raise_exception(method: str):
        raise NotImplementedError(f"{method} is not implemented due to the lack of triton.")

    fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
    fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
141
vllm/model_executor/layers/fused_moe/activation.py
Normal file
@@ -0,0 +1,141 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MoE activation function enum and utilities."""

from enum import Enum

import torch
import torch.nn.functional as F

from vllm._custom_ops import gelu_and_mul, silu_and_mul, swigluoai_and_mul


class MoEActivation(Enum):
    """Activation functions for MoE layers."""

    # Gated activations (gate * activation(up)) expect input of shape [..., 2*d]
    # and produce output of shape [..., d]
    SILU = "silu"
    GELU = "gelu"
    RELU2 = "relu2"
    SWIGLUOAI = "swigluoai"
    SWIGLUSTEP = "swiglustep"

    # Non-gated activations (no mul with gate) expect input of shape [..., d]
    # and produce output of shape [..., d].
    # NOTE: Non-gated activations require the "_no_mul" suffix to be present.
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"

    @property
    def is_gated(self) -> bool:
        """Returns True if the activation expects the gate * activation(up) pattern.

        Gated activations expect an input tensor with 2x the output size,
        where the first half is the gate and the second half is the up projection.
        """
        return not self.value.endswith("_no_mul")

    @property
    def custom_op_name(self) -> str:
        """Maps to the CustomOp name of activations
        in vllm/model_executor/layers/activation.py."""
        return _CUSTOM_OP_NAMES[self]

    def without_mul(self) -> "MoEActivation":
        """Get the non-gated variant of this activation.

        For activations that have a _no_mul variant, returns that variant.
        For activations without a _no_mul variant (or already _no_mul),
        returns self.
        """
        return _WITHOUT_MUL.get(self, self)

    @classmethod
    def from_str(cls, s: str) -> "MoEActivation":
        """Parse from string for backward compatibility."""
        for member in cls:
            if member.value == s:
                return member
        valid = [m.value for m in cls]
        raise ValueError(f"Unknown MoE activation: {s!r}. Valid activations: {valid}")


# Module-level lookup tables used by MoEActivation methods.
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
    MoEActivation.SILU: "silu_and_mul",
    MoEActivation.GELU: "gelu_and_mul",
    MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
    MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
    MoEActivation.RELU2: "relu2",
    MoEActivation.SILU_NO_MUL: "silu_and_mul",
    MoEActivation.GELU_NO_MUL: "gelu_and_mul",
    MoEActivation.RELU2_NO_MUL: "relu2",
}

_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
    MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
    MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
    MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}


def activation_without_mul(activation: str) -> str:
    """Get the non-gated variant of an activation function.

    Args:
        activation: The activation function name (e.g., "silu", "gelu")

    Returns:
        The non-gated activation name (e.g., "silu_no_mul", "gelu_no_mul")
    """
    return MoEActivation.from_str(activation).without_mul().value
def apply_moe_activation(
|
||||
activation: MoEActivation,
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Apply MoE activation function."""
|
||||
assert input.dim() == 2, "Input must be 2D"
|
||||
assert output.dim() == 2, "Output must be 2D"
|
||||
if activation.is_gated:
|
||||
assert output.size(-1) * 2 == input.size(-1), (
|
||||
f"{activation.value} expects 2x ratio: "
|
||||
f"{output.size(-1) * 2} vs {input.size(-1)}"
|
||||
)
|
||||
else:
|
||||
assert output.size(-1) == input.size(-1), (
|
||||
f"{activation.value} expects equal sizes: "
|
||||
f"{output.size(-1)} vs {input.size(-1)}"
|
||||
)
|
||||
|
||||
# Activations with gated multiplication (gate × activation(up))
|
||||
if activation == MoEActivation.SILU:
|
||||
# torch.ops._C.silu_and_mul(output, input)
|
||||
silu_and_mul(output, input)
|
||||
elif activation == MoEActivation.GELU:
|
||||
# torch.ops._C.gelu_and_mul(output, input)
|
||||
gelu_and_mul(output, input)
|
||||
elif activation == MoEActivation.SWIGLUOAI:
|
||||
# torch.ops._C.swigluoai_and_mul(output, input)
|
||||
swigluoai_and_mul(output, input)
|
||||
elif activation == MoEActivation.SWIGLUSTEP:
|
||||
from vllm.model_executor.layers.activation import swiglustep_and_mul_triton
|
||||
|
||||
swiglustep_and_mul_triton(output, input)
|
||||
|
||||
# Activations without gated multiplication
|
||||
elif activation == MoEActivation.SILU_NO_MUL:
|
||||
output.copy_(F.silu(input))
|
||||
elif activation == MoEActivation.GELU_NO_MUL:
|
||||
output.copy_(F.gelu(input))
|
||||
elif activation == MoEActivation.RELU2_NO_MUL:
|
||||
F.relu(input, inplace=True)
|
||||
torch.square(input, out=output)
|
||||
else:
|
||||
raise ValueError(f"Unsupported FusedMoe activation: {activation}")
|
||||
|
||||
return output
|
||||
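For reference, a minimal usage sketch of the activation helper above (the tensor sizes and the CPU-friendly non-gated variant are illustrative choices, not taken from the diff; the gated variants call the CUDA custom ops and need a GPU build):

import torch
from vllm.model_executor.layers.fused_moe.activation import (
    MoEActivation,
    activation_without_mul,
    apply_moe_activation,
)

# Non-gated variant: input and output share the last dim.
x = torch.randn(8, 128)
out = torch.empty(8, 128)
apply_moe_activation(MoEActivation.SILU_NO_MUL, out, x)  # out = F.silu(x)

# Gated variants expect [..., 2*d] input and [..., d] output.
assert MoEActivation.SILU.is_gated
assert activation_without_mul("silu") == "silu_no_mul"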
254
vllm/model_executor/layers/fused_moe/all2all_utils.py
Normal file
@@ -0,0 +1,254 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

from vllm.distributed import (
    get_ep_group,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.flashinfer_a2a_prepare_finalize import (
    FlashInferA2APrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
    FusedMoEPrepareAndFinalize,
)
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNaiveEP,
    MoEPrepareAndFinalizeNoEP,
)
from vllm.platforms import current_platform
from vllm.utils.import_utils import has_deep_ep, has_mori, has_pplx

logger = init_logger(__name__)

if current_platform.is_cuda_alike():
    if has_pplx():
        from .pplx_prepare_finalize import (
            PplxPrepareAndFinalize,
            pplx_hidden_dim_scale_bytes,
        )
    if has_deep_ep():
        from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize
        from .deepep_ll_prepare_finalize import (
            DEEPEP_QUANT_BLOCK_SHAPE,
            DeepEPLLPrepareAndFinalize,
        )
    if has_mori():
        from .mori_prepare_finalize import MoriPrepareAndFinalize


def maybe_roundup_layer_hidden_size(
    hidden_size: int,
    act_dtype: torch.dtype,
    moe_parallel_config: FusedMoEParallelConfig,
) -> int:
    """
    Given the layer hidden size and MoE configurations, round up hidden_size
    if necessary.

    Args:
        hidden_size: Layer hidden size.
        act_dtype: Data type of the layer activations.
        moe_parallel_config: Fused MoE parallelization strategy configuration.

    Return:
        The rounded-up hidden_size if rounding up is required based on the
        configs and all2all backend; the original hidden size otherwise.
    """
    if moe_parallel_config.use_deepep_ht_kernels:
        hidden_size = DeepEPHTPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size, act_dtype
        )

    if moe_parallel_config.use_deepep_ll_kernels:
        hidden_size = DeepEPLLPrepareAndFinalize.maybe_roundup_layer_hidden_size(
            hidden_size
        )

    return hidden_size
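# Illustrative sketch (not part of the original file): the DeepEP backends
# round the hidden size up to an alignment they require. Schematically,
# assuming a hypothetical byte alignment:
#
#   def _roundup_sketch(hidden_size: int, align_bytes: int, elem_size: int) -> int:
#       align_elems = align_bytes // elem_size
#       return -(-hidden_size // align_elems) * align_elems  # ceil-div * align
#
#   _roundup_sketch(1000, 512, 2)  # -> 1024 for fp16 rows aligned to 512 B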

def maybe_make_prepare_finalize(
    moe: FusedMoEConfig,
    quant_config: FusedMoEQuantConfig | None,
    routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    allow_new_interface: bool = False,
) -> FusedMoEPrepareAndFinalize | None:
    # NOTE(rob): we are migrating each quant_method to hold the MK
    # in all cases. The allow_new_interface=False flag allows us to fall
    # back to the old method for methods that have not yet been migrated.
    #
    # In the old method:
    #   * maybe_init_modular_kernel() calls this function. If we are using
    #     no DP/EP or naive all2all, this function returns None and no
    #     ModularKernelMethod is created. If non-naive all2all is used, this
    #     returns a PrepareAndFinalize object and a ModularKernelMethod is
    #     created.
    # In the new method:
    #   * maybe_make_prepare_finalize() is called from the oracle. We always
    #     return a PrepareAndFinalize object and the quant method holds the
    #     ModularKernel.
    if not moe.moe_parallel_config.use_all2all_kernels:
        if not allow_new_interface:
            return None

        # For the DP/TP case, fall back to the naive P/F.
        if moe.moe_parallel_config.dp_size > 1:
            logger.info_once(
                "Detected DP deployment with no --enable-expert-parallel. "
                "Falling back to AllGather+ReduceScatter dispatch/combine."
            )
            return MoEPrepareAndFinalizeNaiveEP(
                is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
                num_dispatchers=(
                    get_ep_group().device_communicator.all2all_manager.world_size
                ),
            )
        else:
            return MoEPrepareAndFinalizeNoEP()

    all2all_manager = get_ep_group().device_communicator.all2all_manager
    assert all2all_manager is not None

    prepare_finalize: FusedMoEPrepareAndFinalize | None = None

    if moe.use_pplx_kernels:
        assert quant_config is not None

        hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes(
            moe.max_num_tokens,
            moe.hidden_dim,
            moe.in_dtype,
            quant_config.quant_dtype,
            per_act_token_quant=quant_config.per_act_token_quant,
            block_shape=quant_config.block_shape,
        )

        all_to_all_args = dict(
            max_num_tokens=moe.max_num_tokens,
            num_experts=moe.num_experts,
            experts_per_token=moe.experts_per_token,  # topk
            rank=all2all_manager.rank,
            world_size=all2all_manager.world_size,
            # dp_size actually means tp_size, bug in pplx kernels
            dp_size=all2all_manager.tp_group.world_size,
            hidden_dim=moe.hidden_dim,
            hidden_dim_bytes=hidden_dim_bytes,
            hidden_dim_scale_bytes=hidden_scale_bytes,
        )

        num_dispatchers = (
            all2all_manager.world_size // all2all_manager.tp_group.world_size
        )

        # Intranode pplx a2a takes a group name while internode does not.
        if not all2all_manager.internode:
            all_to_all_args["group_name"] = all2all_manager.cpu_group.group_name

        handle = all2all_manager.get_handle(all_to_all_args)

        prepare_finalize = PplxPrepareAndFinalize(
            handle,
            max_num_tokens=moe.max_num_tokens,
            num_local_experts=moe.num_local_experts,
            num_dispatchers=num_dispatchers,
        )
    elif moe.use_deepep_ht_kernels:
        assert moe.dp_size == all2all_manager.dp_world_size

        all_to_all_args = dict()
        handle = all2all_manager.get_handle(all_to_all_args)
        prepare_finalize = DeepEPHTPrepareAndFinalize(
            handle,
            num_dispatchers=all2all_manager.world_size,
            dp_size=all2all_manager.dp_world_size,
            rank_expert_offset=all2all_manager.rank * moe.num_local_experts,
        )

    elif moe.use_deepep_ll_kernels:
        assert quant_config is not None
        global_to_physical = physical_to_global = local_expert_global_ids = None
        if routing_tables is not None:
            (
                global_to_physical,
                physical_to_global,
                local_expert_global_ids,
            ) = routing_tables
        all_to_all_args = dict(
            max_num_tokens_per_dp_rank=moe.max_num_tokens,
            token_hidden_size=moe.hidden_dim,
            num_ep_ranks=all2all_manager.world_size,
            num_global_experts=moe.num_experts,
            num_local_experts=moe.num_experts // all2all_manager.world_size,
        )
        handle = all2all_manager.get_handle(all_to_all_args)

        # Note: We may want to use FP8 dispatch just to reduce
        # data movement.
        use_fp8_dispatch = (
            quant_config.quant_dtype == current_platform.fp8_dtype()
            and quant_config.block_shape == DEEPEP_QUANT_BLOCK_SHAPE
        )

        prepare_finalize = DeepEPLLPrepareAndFinalize(
            handle,
            max_tokens_per_rank=moe.max_num_tokens,
            num_dispatchers=all2all_manager.world_size,
            use_fp8_dispatch=use_fp8_dispatch,
            global_to_physical=global_to_physical,
            physical_to_global=physical_to_global,
            local_expert_global_ids=local_expert_global_ids,
        )
    elif moe.use_mori_kernels:
        assert quant_config is not None

        # Note: We may want to use FP8 dispatch just to reduce
        # data movement.
        use_fp8_dispatch = (
            quant_config.is_per_act_token or quant_config.is_block_quantized
        )
        # For PTPC (per token per channel) quant, the scale dim for each token is 1.
        # For 1x128 quant, the scale dim for each token is hidden_dim // 128.
        scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128
        all_to_all_args = dict(
            rank=all2all_manager.rank,
            num_ep_ranks=all2all_manager.world_size,
            quant_dtype=quant_config.quant_dtype,
            token_hidden_size=moe.hidden_dim,
            scale_dim=scale_dim,
            scale_type_size=torch.float32.itemsize,
            max_num_tokens_per_dp_rank=moe.max_num_tokens,
            input_dtype=moe.in_dtype,
            num_local_experts=moe.num_experts // all2all_manager.world_size,
            num_experts_per_token=moe.experts_per_token,
        )
        handle = all2all_manager.get_handle(all_to_all_args)

        prepare_finalize = MoriPrepareAndFinalize(
            handle,
            max_tokens_per_rank=moe.max_num_tokens,
            num_dispatchers=all2all_manager.world_size,
            use_fp8_dispatch=use_fp8_dispatch,
        )

    elif moe.use_fi_all2allv_kernels:
        assert quant_config is not None
        prepare_finalize = FlashInferA2APrepareAndFinalize(
            num_dispatchers=all2all_manager.world_size,
        )

    elif moe.use_naive_all2all_kernels and allow_new_interface:
        prepare_finalize = MoEPrepareAndFinalizeNaiveEP(
            is_sequence_parallel=(moe.moe_parallel_config.is_sequence_parallel),
            num_dispatchers=all2all_manager.world_size,
        )

    return prepare_finalize
447
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
Normal file
@@ -0,0 +1,447 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.forward_context import get_forward_context, is_forward_context_available
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEConfig,
    FusedMoEParallelConfig,
    FusedMoEQuantConfig,
)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
    TopKWeightAndReduceDelegate,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    QuantKey,
    kFp8Dynamic128Sym,
    kFp8Static128BlockSym,
)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import (
    DeepGemmQuantScaleFMT,
    fp8_m_grouped_gemm_nt_masked,
    get_mk_alignment_for_contiguous_layout,
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
from vllm.utils.math_utils import cdiv, round_up

logger = init_logger(__name__)


def scales_shape_stride_dtype(
    E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
    shape = (E, T, G)
    strides = (T * G, 1, T)
    if quant_scale_fmt in [
        DeepGemmQuantScaleFMT.FLOAT32,
        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
    ]:
        return shape, strides, torch.float32

    assert quant_scale_fmt == DeepGemmQuantScaleFMT.UE8M0
    shape = (E, T, cdiv(G, 4))
    strides = (T * cdiv(G, 4), 1, T)
    return shape, strides, torch.int32
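# Illustrative note (not part of the original file): with shape (E, T, G) and
# strides (T*G, 1, T), element (e, t, g) lives at offset e*T*G + g*T + t, so
# the scales are physically laid out as (E, G, T) -- tokens contiguous within
# each group -- while still being indexed as (E, T, G):
#
#   y_s = torch.empty_strided((E, T, G), (T * G, 1, T), dtype=torch.float32)
#   assert y_s.permute(0, 2, 1).is_contiguous()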
@triton.jit
def _silu_mul_fp8_quant_deep_gemm(
    # Pointers ------------------------------------------------------------
    input_ptr,  # 16-bit activations (E, T, 2*H)
    y_q_ptr,  # fp8 quantized activations (E, T, H)
    y_s_ptr,  # 16-bit scales (E, T, G)
    counts_ptr,  # int32 num tokens per expert (E)
    # Sizes ---------------------------------------------------------------
    H: tl.constexpr,  # hidden dimension (per output)
    GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
    # Strides for input (elements) ---------------------------------------
    stride_i_e,
    stride_i_t,
    stride_i_h,
    # Strides for y_q (elements) -----------------------------------------
    stride_yq_e,
    stride_yq_t,
    stride_yq_h,
    # Strides for y_s (elements) -----------------------------------------
    stride_ys_e,
    stride_ys_t,
    stride_ys_g,
    # Stride for counts (elements)
    stride_counts_e,
    # Numeric params ------------------------------------------------------
    eps: tl.constexpr,
    fp8_min: tl.constexpr,
    fp8_max: tl.constexpr,
    ceil_ue8m0: tl.constexpr,
    # Meta ---------------------------------------------------------------
    BLOCK: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    G = H // GROUP_SIZE

    # map program id -> (e, g)
    pid = tl.program_id(0)
    e = pid // G
    g = pid % G

    e = e.to(tl.int64)
    g = g.to(tl.int64)

    # number of valid tokens for this expert
    n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)

    cols = tl.arange(0, BLOCK).to(tl.int64)
    mask = cols < BLOCK

    base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
    base_gate_offset = base_input_offset + cols * stride_i_h
    base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
    base_yq_offset = e * stride_yq_e + g * GROUP_SIZE * stride_yq_h + cols * stride_yq_h
    base_ys_offset = e * stride_ys_e + g * stride_ys_g

    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
        gate = tl.load(
            input_ptr + base_gate_offset + t * stride_i_t, mask=mask, other=0.0
        ).to(tl.float32)
        up = tl.load(input_ptr + base_up_offset + t * stride_i_t, mask=mask, other=0.0)

        gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
        y = gate * up

        y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
        if ceil_ue8m0:
            y_s = tl.exp2(tl.ceil(tl.log2(y_s)))

        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

        tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
        tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)


def persistent_masked_m_silu_mul_quant(
    y: torch.Tensor,  # (E, T, 2*H)
    tokens_per_expert: torch.Tensor,  # (E,) number of valid tokens per expert
    num_parallel_tokens=16,
    group_size: int = 128,
    quant_scale_fmt: DeepGemmQuantScaleFMT = DeepGemmQuantScaleFMT.FLOAT32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales.

    y has shape (E, T, 2*H). The first half of the last dimension is
    silu-activated, multiplied by the second half, then quantized into FP8.
    We launch a fixed grid of threads to accommodate CUDA graphs. Let `P2`
    be a parallelization factor for persistent_masked_m_silu_mul_quant over the
    hidden dimension.

    Let `expert_offsets = [0] + [num_tokens.cumsum()]` and
    `total_tokens = expert_offsets[-1]`.
    persistent_masked_m_silu_mul_quant launches `total_tokens x P2`
    thread blocks. Each thread block contains `NUM_WARPS` warps.

    Every thread block needs to find its corresponding expert by warp-parallel
    scanning over the `expert_offsets` array.

    The i-th warp in the first thread block processes
    `[i * warp_chunk_size, (i + 1) * warp_chunk_size]` groups
    sequentially, where `warp_chunk_size = ((H / GROUP_SIZE) / P2) / NUM_WARPS`,
    pipelining loads and computes.

    The shared memory layout for 4 warps with a 2-stage pipeline for SiLU V2
    can be visualized like so:

     stage0                                  stage1
    ┌─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┬─────┬───┐
    │gate0│up0│gate1│up1│gate2│up2│gate3│up3│gate0│up0│gate1│up1│gate2│up2│gate3│up3│
    └─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┴─────┴───┘

    with the main difference between V1 and V2 being the global load
    stride between warps, and between half-warps. Regarding the latter stride,
    we assign the first half-warp of every warp to `gate` loads and the second
    half-warp to `up` loads.

    Returns `(y_q, y_s)` where
    * `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
    * `y_s` depends on quant_scale_fmt:
      - quant_scale_fmt == FLOAT32:
        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
      - quant_scale_fmt == UE8M0:
        `y_s`: Int32 tensor, shape (E, T, H // group_size // 4), strides (T*G, 1, T)
      - quant_scale_fmt == FLOAT32_CEIL_UE8M0:
        `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
    Let NUM_WARPS be the number of warps in a single thread block and
    `GROUP_SIZE = 128` be the size of the quantization group.
    """
    assert y.ndim == 3, "y must be (E, T, 2*H)"
    E, T, H2 = y.shape
    assert H2 % 2 == 0, "last dim of y must be even (2*H)"
    H = H2 // 2
    G = (H + group_size - 1) // group_size
    assert H % 8 == 0, "H must be divisible by 8"
    assert group_size == 128, "group_size must be 128"
    assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E

    tokens_per_expert = tokens_per_expert.to(device=y.device, dtype=torch.int32)

    fp8_dtype = torch.float8_e4m3fn
    y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)

    ys_shape, ys_strides, ys_dtype = scales_shape_stride_dtype(E, T, G, quant_scale_fmt)
    y_s = torch.empty_strided(
        ys_shape,
        ys_strides,
        dtype=ys_dtype,
        device=y.device,
    )

    ceil_ue8m0 = quant_scale_fmt in [
        DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0,
        DeepGemmQuantScaleFMT.UE8M0,
    ]

    cuda_arch = current_platform.get_device_capability(
        device_id=y.device.index
    ).to_int()

    if cuda_arch >= 80:
        torch.ops._C.persistent_masked_m_silu_mul_quant(
            y, tokens_per_expert, y_q, y_s, ceil_ue8m0
        )
    else:
        stride_cnt_e = tokens_per_expert.stride()[0]

        # Static grid over experts and H-groups.
        # A loop inside the kernel handles the token dim.
        grid = (E * G,)
        # strides (elements)
        stride_i_e, stride_i_t, stride_i_h = y.stride()
        stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()

        f_info = torch.finfo(fp8_dtype)
        fp8_max = f_info.max
        fp8_min = f_info.min
        eps: float = 1e-10
        assert y_s.dtype == torch.float32, (
            f"_silu_mul_fp8_quant_deep_gemm does not support {y_s.dtype} "
            "scales. Only torch.float32 is supported."
        )
        _silu_mul_fp8_quant_deep_gemm[grid](
            y,
            y_q,
            y_s,
            tokens_per_expert,
            H,
            group_size,
            stride_i_e,
            stride_i_t,
            stride_i_h,
            stride_yq_e,
            stride_yq_t,
            stride_yq_h,
            ys_strides[0],
            ys_strides[1],
            ys_strides[2],
            stride_cnt_e,
            eps,
            fp8_min,
            fp8_max,
            ceil_ue8m0,
            BLOCK=group_size,
            NUM_STAGES=4,
            num_warps=1,
        )

    return y_q, y_s
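# Illustrative usage (not part of the original file): with E=4 local experts,
# T=64 token slots per expert and H=2048 (so G = 2048 // 128 = 16 groups):
#
#   y = torch.randn(4, 64, 2 * 2048, device="cuda", dtype=torch.bfloat16)
#   counts = torch.tensor([64, 10, 0, 3], device="cuda", dtype=torch.int32)
#   y_q, y_s = persistent_masked_m_silu_mul_quant(y, counts)
#   # y_q: (4, 64, 2048) float8_e4m3fn; y_s: (4, 64, 16) float32
#   # (with the default FLOAT32 scale format)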
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int,
        num_dispatchers: int,
    ):
        """
        max_num_tokens: Maximum number of tokens from a DP Rank
        num_dispatchers: The number of DP dispatchers.
        quant_config: Quantization configuration
        """
        super().__init__(
            moe_config=moe_config,
            quant_config=quant_config,
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
        )
        assert self.block_shape == get_mk_alignment_for_contiguous_layout()
        assert self.quant_config.use_fp8_w8a8

    @staticmethod
    def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.BatchedExperts

    @staticmethod
    def _supports_current_device() -> bool:
        return is_deep_gemm_supported()

    @staticmethod
    def _supports_no_act_and_mul() -> bool:
        return False

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        SUPPORTED_W_A = [(kFp8Static128BlockSym, kFp8Dynamic128Sym)]
        return (weight_key, activation_key) in SUPPORTED_W_A

    @staticmethod
    def _supports_activation(activation: MoEActivation) -> bool:
        return activation == MoEActivation.SILU

    @staticmethod
    def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        return True

    def supports_chunking(self) -> bool:
        return False

    def supports_expert_map(self) -> bool:
        return False

    def supports_packed_ue8m0_act_scales(self) -> bool:
        """
        DeepGemm supports the packed ue8m0 activation scales format on
        sm100-family devices.
        """
        return (
            is_deep_gemm_e8m0_used()
            and current_platform.is_device_capability_family(100)
        )

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # Let PrepareAndFinalize::finalize() decide the impl.
        return TopKWeightAndReduceDelegate()

    def workspace_shapes(
        self,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        activation: MoEActivation,
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # FIXME (varun): We should be able to dispatch only from the leader
        # DP ranks in the case of TP > 1. At the moment, all the Ranks
        # end up sending their tokens. This needs to be fixed.
        assert self.num_dispatchers is not None
        assert self.max_num_tokens is not None
        num_dispatchers = self.num_dispatchers
        num_experts = local_num_experts
        max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens
        activation_out_dim = self.adjust_N_for_activation(N, activation)
        workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N))
        workspace2 = (num_experts, max_num_tokens * num_dispatchers, activation_out_dim)
        output = (num_experts, max_num_tokens * num_dispatchers, K)
        return (workspace13, workspace2, output)

    def estimate_expected_m(
        self, global_num_experts: int, max_tokens_per_expert: int, topk: int
    ) -> int:
        dp_meta = (
            get_forward_context().dp_metadata
            if is_forward_context_available()
            else None
        )
        if dp_meta is None:
            logger.warning_once(
                "DPMetadata unavailable. Defaulting expected_m to "
                f"{max_tokens_per_expert}.",
                scope="local",
            )
            return max_tokens_per_expert

        total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
        total_num_tokens_replicated = total_num_tokens * topk

        # Assume even load balancing
        assert global_num_experts != 0
        estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
        # clamp estimate
        estimate = max(estimate, 16)
        estimate = min(max_tokens_per_expert, estimate)
        return estimate
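    # Worked example for estimate_expected_m (illustrative numbers, not from
    # the original file): with 1000 tokens summed across DP ranks, topk=8 and
    # 64 global experts, the even-balance estimate is 1000 * 8 // 64 = 125,
    # which round_up(..., 16) turns into 128, then clamped to
    # [16, max_tokens_per_expert].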
    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: MoEActivation,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        assert expert_tokens_meta is not None
        expert_num_tokens = expert_tokens_meta.expert_num_tokens

        assert hidden_states.ndim == 3
        assert self.block_shape is not None

        a1q = hidden_states
        _, N, K = w1.size()

        assert w2.size(1) == K

        E, max_num_tokens, N, K, _ = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )

        workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))

        expected_m = self.estimate_expected_m(
            global_num_experts=global_num_experts,
            max_tokens_per_expert=max_num_tokens,
            topk=topk_ids.size(-1),
        )

        fp8_m_grouped_gemm_nt_masked(
            (a1q, a1q_scale),
            (w1, self.w1_scale),
            workspace1,
            expert_num_tokens,
            expected_m,
        )

        quant_scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
        a2q, a2q_scale = persistent_masked_m_silu_mul_quant(
            workspace1,
            expert_num_tokens,
            quant_scale_fmt=quant_scale_fmt,
        )

        fp8_m_grouped_gemm_nt_masked(
            (a2q, a2q_scale),
            (w2, self.w2_scale),
            output,
            expert_num_tokens,
            expected_m,
        )
1246
vllm/model_executor/layers/fused_moe/config.py
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    }
}
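The JSON tables in this commit map a token count M (the keys) to Triton launch parameters for the fused MoE kernels. A minimal sketch of how such a table is typically consumed, picking the entry whose key is closest to the runtime M (the function name and selection rule here follow the usual vLLM convention but are written from scratch as an illustration):

import json

def pick_moe_config(path: str, M: int) -> dict:
    # Load a tuned-config table and return the entry whose integer key
    # is closest to the runtime token count M.
    with open(path) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    return configs[min(configs, key=lambda k: abs(k - M))]

For example, M=100 would select the "96" entry of the table above.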
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "5120": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "9216": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "13312": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "17408": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "25600": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "33792": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "41984": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "50176": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "58368": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "5120": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "9216": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "13312": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "17408": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "25600": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "33792": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "41984": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "50176": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "58368": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "5120": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "9216": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "13312": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "17408": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "25600": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "33792": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "41984": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "50176": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "58368": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "5120": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "9216": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "13312": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "17408": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "25600": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "33792": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "41984": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "50176": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "58368": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 256,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "5120": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "9216": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "13312": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "17408": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "25600": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "33792": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "41984": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "50176": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "58368": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "5120": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "9216": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "13312": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "17408": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "25600": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "33792": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "41984": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "50176": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "58368": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "5120": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "9216": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "13312": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "17408": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "25600": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "33792": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "41984": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "50176": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "58368": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,218 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "5120": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "9216": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "13312": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "17408": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "25600": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "33792": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "41984": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "50176": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "58368": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,164 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1024": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1536": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    }
}
@@ -0,0 +1,200 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 1
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 1
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 1,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 1
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 1
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "1024": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "1536": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 1
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0,
        "matrix_instr_nonkdim": 16,
        "kpack": 2
    }
}
@@ -0,0 +1,164 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1024": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1536": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 2,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    }
}
@@ -0,0 +1,123 @@
{
    "triton_version": "3.4.0",
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "1024": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "2048": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "8192": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "16384": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
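Each of these JSON files is a Triton fused-MoE tuning table: the top-level keys are token-batch sizes M, and each entry records the tile shape (BLOCK_SIZE_M/N/K), the GROUP_SIZE_M swizzle factor, and the num_warps/num_stages launch parameters that benchmarked fastest for that M on the target device. A minimal sketch of how such a table can be consumed follows; the nearest-key lookup mirrors how vLLM's fused_moe layer picks a config at runtime, but the helper names here (load_moe_config, pick_config) are illustrative, not this repo's API.

import json

def load_moe_config(path: str) -> dict[int, dict[str, int]]:
    # Keep only the numeric batch-size keys; skip metadata such as "triton_version".
    with open(path) as f:
        raw = json.load(f)
    return {int(k): v for k, v in raw.items() if k.isdecimal()}

def pick_config(configs: dict[int, dict[str, int]], m: int) -> dict[str, int]:
    # Use the entry tuned for the batch size closest to the actual m.
    return configs[min(configs, key=lambda k: abs(k - m))]

For example, with the "1"/"2"/.../"4096" table above, a batch of 40 tokens lands nearest the "48" entry, so its tile shape and stage counts would be used.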
@@ -0,0 +1,139 @@
{
    "triton_version": "3.5.1",
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "320": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "768": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    }
}
@@ -0,0 +1,147 @@
{
    "triton_version": "3.5.0",
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    }
}
@@ -0,0 +1,147 @@
{
    "triton_version": "3.6.0",
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    }
}
@@ -0,0 +1,147 @@
{
    "triton_version": "3.5.0",
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 5
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 8,
        "num_stages": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    }
}
@@ -0,0 +1,122 @@
{
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "8192": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16384": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 8,
        "num_stages": 2
    }
}
@@ -0,0 +1,164 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 8,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 4,
        "num_warps": 8,
        "num_stages": 2,
        "waves_per_eu": 0
    }
}
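Unlike the neighboring tables, this file also records a waves_per_eu value for every entry. That field is an AMD/ROCm-specific occupancy hint (waves per execution unit) that Triton accepts as an extra launch option on ROCm backends; it has no meaning elsewhere, so a consumer has to separate it from the portable parameters. A hedged sketch of that split follows; the helper name and the is_rocm flag are hypothetical, for illustration only.

def split_launch_options(entry: dict[str, int], is_rocm: bool) -> tuple[dict, dict]:
    # Block sizes become kernel compile-time constants; the rest are launch options.
    opts = {"num_warps": entry["num_warps"], "num_stages": entry["num_stages"]}
    if is_rocm and "waves_per_eu" in entry:
        # Hypothetical: forwarded to the Triton kernel launch only on ROCm.
        opts["waves_per_eu"] = entry["waves_per_eu"]
    consts = {k: v for k, v in entry.items()
              if k.startswith("BLOCK_SIZE") or k == "GROUP_SIZE_M"}
    return consts, opts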
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 256,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 32,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 4
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 3
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 5
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 5
    },
    "256": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "1024": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    }
}
@@ -0,0 +1,146 @@
{
    "1": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "2": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 5
    },
    "4": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 4
    },
    "8": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 2
    },
    "16": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 64,
        "num_warps": 4,
        "num_stages": 2
    },
    "24": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 4
    },
    "32": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "48": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "64": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 4
    },
    "96": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "128": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 16,
        "num_warps": 4,
        "num_stages": 3
    },
    "256": {
        "BLOCK_SIZE_M": 32,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 128,
        "GROUP_SIZE_M": 32,
        "num_warps": 4,
        "num_stages": 2
    },
    "512": {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 128,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 4,
        "num_stages": 3
    },
    "1024": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "1536": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "2048": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    },
    "3072": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 16,
        "num_warps": 8,
        "num_stages": 4
    },
    "4096": {
        "BLOCK_SIZE_M": 128,
        "BLOCK_SIZE_N": 256,
        "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1,
        "num_warps": 8,
        "num_stages": 4
    }
}
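These JSON tables are tuned Triton launch configurations for the fused-MoE kernels: each top-level key is a token count M, and each entry records the GEMM tile shape (BLOCK_SIZE_M/N/K), the program-grouping factor used for L2 locality (GROUP_SIZE_M), and the Triton meta-parameters num_warps and num_stages that benchmarked fastest at that M. The sketch below shows one plausible way such a table is consumed at runtime; it is a minimal illustration rather than vLLM's actual loader, and the file path and function names are hypothetical.

import json
from functools import lru_cache


@lru_cache(maxsize=None)
def load_tuned_configs(path: str) -> dict[int, dict[str, int]]:
    # JSON object keys are always strings; convert them to ints so the
    # nearest-key lookup below can compare them numerically.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: dict[int, dict[str, int]], m: int) -> dict[str, int]:
    # Return the entry tuned for the token count closest to the actual M;
    # an exact key (e.g. m == 64) selects its own entry.
    nearest = min(configs, key=lambda tuned_m: abs(tuned_m - m))
    return configs[nearest]


# Hypothetical usage: for m == 40 the candidate keys 32 and 48 are both at
# distance 8, and min() keeps the first minimal key it encounters.
# cfg = pick_config(load_tuned_configs("tuned_moe_config.json"), 40)
# cfg["BLOCK_SIZE_M"], cfg["num_warps"], ...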
Some files were not shown because too many files have changed in this diff.