first commit
This commit is contained in:
0
vllm/model_executor/layers/__init__.py
Normal file
0
vllm/model_executor/layers/__init__.py
Normal file
BIN
vllm/model_executor/layers/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/layernorm.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/layernorm.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/mla.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/mla.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/pooler.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/pooler.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/resampler.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/resampler.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/utils.cpython-310.pyc
Normal file
BIN
vllm/model_executor/layers/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
575
vllm/model_executor/layers/activation.py
Normal file
575
vllm/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,575 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Custom activation functions."""
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import LazyDict
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@CustomOp.register("fatrelu_and_mul")
|
||||
class FatreluAndMul(CustomOp):
|
||||
"""An activation function for FATReLU.
|
||||
|
||||
The function computes x -> FATReLU(x[:d]) * x[d:] where
|
||||
d = x.shape[-1] // 2.
|
||||
This is used in openbmb/MiniCPM-S-1B-sft.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self, threshold: float = 0.):
|
||||
super().__init__()
|
||||
self.threshold = threshold
|
||||
if current_platform.is_cuda_alike():
|
||||
self.op = torch.ops._C.fatrelu_and_mul
|
||||
elif current_platform.is_cpu():
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
x1 = x[..., :d]
|
||||
x2 = x[..., d:]
|
||||
x1 = F.threshold(x1, self.threshold, 0.0)
|
||||
return x1 * x2
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x, self.threshold)
|
||||
return out
|
||||
|
||||
|
||||
@CustomOp.register("silu_and_mul")
|
||||
class SiluAndMul(CustomOp):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike():
|
||||
self.op = torch.ops._C.silu_and_mul
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.silu_and_mul
|
||||
elif current_platform.is_cpu():
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.silu(x[..., :d]) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
|
||||
@CustomOp.register("mul_and_silu")
|
||||
class MulAndSilu(CustomOp):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
The function computes x -> x[:d] * silu(x[d:]) where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike():
|
||||
self.op = torch.ops._C.mul_and_silu
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.silu_and_mul
|
||||
elif current_platform.is_cpu():
|
||||
self._forward_method = self.forward_native
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return x[..., :d] * F.silu(x[..., d:])
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
# TODO implement forward_xpu for MulAndSilu
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@CustomOp.register("gelu_and_mul_sparse")
|
||||
class GeluAndMulSparse(CustomOp):
|
||||
"""An activation function for GeluAndMulSparse.
|
||||
This activation function is used in Gemma3n. It computes:
|
||||
up_proj = self.up_proj(x)
|
||||
gate_proj = self.gate_proj(x)
|
||||
gate_proj = self._gaussian_topk(gate_proj) # sparsity
|
||||
activations = self.act_fn(gate_proj) # gelu
|
||||
down_proj = self.down_proj(activations * up_proj)
|
||||
Shapes:
|
||||
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
|
||||
return: (num_tokens, d) or (batch_size, seq_len, d)
|
||||
"""
|
||||
|
||||
def __init__(self, activation_sparsity: float, approximate: str = "none"):
|
||||
super().__init__()
|
||||
# Gelu.
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
|
||||
# Sparsity.
|
||||
if activation_sparsity == 0.0:
|
||||
raise ValueError(
|
||||
"activation_sparsity is 0.0. Please use GeluAndMul.")
|
||||
target_sparsity_tensor = torch.tensor(activation_sparsity,
|
||||
dtype=torch.float32)
|
||||
normal_dist = torch.distributions.normal.Normal(0, 1)
|
||||
self.std_multiplier = normal_dist.icdf(target_sparsity_tensor)
|
||||
|
||||
def _gaussian_topk(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Get % sparse percentile of the Gaussian distribution."""
|
||||
# NOTE(rob): for TP>1, we could all-gather to get the means/std.
|
||||
# But we do not do this because in expectation they are the same
|
||||
# and in practice the eval scores are good without gathering.
|
||||
mean = torch.mean(x, dim=-1, keepdim=True)
|
||||
std = torch.std(x, dim=-1, keepdim=True, unbiased=False)
|
||||
cutoff_x = mean + std * self.std_multiplier
|
||||
return nn.functional.relu(x - cutoff_x)
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
out = self._gaussian_topk(x[..., :d])
|
||||
out = F.gelu(out, approximate=self.approximate)
|
||||
return out * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@CustomOp.register("gelu_and_mul")
|
||||
class GeluAndMul(CustomOp):
|
||||
"""An activation function for GeGLU.
|
||||
|
||||
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
|
||||
|
||||
Shapes:
|
||||
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
|
||||
return: (batch_size, seq_len, d) or (num_tokens, d)
|
||||
"""
|
||||
|
||||
def __init__(self, approximate: str = "none"):
|
||||
super().__init__()
|
||||
self.approximate = approximate
|
||||
if approximate not in ("none", "tanh"):
|
||||
raise ValueError(f"Unknown approximate mode: {approximate}")
|
||||
if current_platform.is_cuda_alike() or current_platform.is_cpu():
|
||||
if approximate == "none":
|
||||
self.op = torch.ops._C.gelu_and_mul
|
||||
elif approximate == "tanh":
|
||||
self.op = torch.ops._C.gelu_tanh_and_mul
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
if approximate == "none":
|
||||
self.op = ipex_ops.gelu_and_mul
|
||||
else:
|
||||
self.op = ipex_ops.gelu_tanh_and_mul
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
d = x.shape[-1] // 2
|
||||
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f'approximate={repr(self.approximate)}'
|
||||
|
||||
|
||||
@CustomOp.register("swigluoai_and_mul")
|
||||
class SwigluOAIAndMul(CustomOp):
|
||||
# https://github.com/huggingface/transformers/blob/v4.55.0/src/transformers/models/gpt_oss/modeling_gpt_oss.py#L106-L110
|
||||
def __init__(self, alpha: float = 1.702, limit: float = 7.0):
|
||||
super().__init__()
|
||||
self.alpha = alpha
|
||||
self.limit = limit
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
|
||||
gate, up = x[..., ::2], x[..., 1::2]
|
||||
gate = gate.clamp(min=None, max=self.limit)
|
||||
up = up.clamp(min=-self.limit, max=self.limit)
|
||||
glu = gate * torch.sigmoid(gate * self.alpha)
|
||||
gated_output = (up + 1) * glu
|
||||
return gated_output
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1] // 2
|
||||
output_shape = (x.shape[:-1] + (d, ))
|
||||
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
|
||||
torch.ops._C.swigluoai_and_mul(out, x, self.alpha, self.limit)
|
||||
return out
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return f"alpha={repr(self.alpha)}, limit={repr(self.limit)}"
|
||||
|
||||
|
||||
@CustomOp.register("gelu_new")
|
||||
class NewGELU(CustomOp):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike() or current_platform.is_cpu():
|
||||
self.op = torch.ops._C.gelu_new
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.gelu_new
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
c = math.sqrt(2.0 / math.pi)
|
||||
return 0.5 * x * (1.0 + torch.tanh(c *
|
||||
(x + 0.044715 * torch.pow(x, 3.0))))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.op(x)
|
||||
|
||||
|
||||
@CustomOp.register("gelu_fast")
|
||||
class FastGELU(CustomOp):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike() or current_platform.is_cpu():
|
||||
self.op = torch.ops._C.gelu_fast
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.gelu_fast
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
|
||||
(1.0 + 0.044715 * x * x)))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.op(x)
|
||||
|
||||
|
||||
@CustomOp.register("quick_gelu")
|
||||
class QuickGELU(CustomOp):
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda_alike() or current_platform.is_cpu():
|
||||
self.op = torch.ops._C.gelu_quick
|
||||
elif current_platform.is_xpu():
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
self.op = ipex_ops.gelu_quick
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = torch.empty_like(x)
|
||||
self.op(out, x)
|
||||
return out
|
||||
|
||||
# TODO implement forward_xpu for QuickGELU
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@CustomOp.register("relu2")
|
||||
class ReLUSquaredActivation(CustomOp):
|
||||
"""
|
||||
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
||||
"""
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
return torch.square(F.relu(x))
|
||||
|
||||
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
#TODO : implement cuda kernels
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@CustomOp.register("xielu")
|
||||
class XIELU(CustomOp):
|
||||
"""
|
||||
Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
|
||||
If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
|
||||
Otherwise, we emit a single warning and use xIELU Python
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
alpha_p_init: float = 0.8,
|
||||
alpha_n_init: float = 0.8,
|
||||
beta: float = 0.5,
|
||||
eps: float = -1e-6,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
with_vector_loads: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.alpha_p = nn.Parameter(
|
||||
torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) -
|
||||
1).unsqueeze(0))
|
||||
self.alpha_n = nn.Parameter(
|
||||
torch.log(
|
||||
torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) -
|
||||
1).unsqueeze(0))
|
||||
self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
|
||||
self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
|
||||
self.with_vector_loads = with_vector_loads
|
||||
# Temporary until xIELU CUDA fully implemented
|
||||
self._beta_scalar = float(self.beta.detach().cpu().float().item())
|
||||
self._eps_scalar = float(self.eps.detach().cpu().float().item())
|
||||
|
||||
self._xielu_cuda_obj = None
|
||||
try:
|
||||
import xielu.ops # noqa: F401
|
||||
|
||||
self._xielu_cuda_obj = torch.classes.xielu.XIELU()
|
||||
msg = "Using experimental xIELU CUDA."
|
||||
try:
|
||||
from torch._dynamo import allow_in_graph
|
||||
|
||||
self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
|
||||
msg += " Enabled torch._dynamo for xIELU CUDA."
|
||||
except Exception as err:
|
||||
msg += (f" Could not enable torch._dynamo for xIELU ({err}) - "
|
||||
"this may result in slower performance.")
|
||||
self._xielu_cuda_fn = self._xielu_cuda
|
||||
logger.warning_once(msg)
|
||||
except Exception as err:
|
||||
logger.warning_once(
|
||||
"CUDA-fused xIELU not available (%s) –"
|
||||
" falling back to a Python version.\n"
|
||||
"For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
|
||||
str(err),
|
||||
)
|
||||
|
||||
def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
|
||||
alpha_p = nn.functional.softplus(self.alpha_p)
|
||||
alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
|
||||
return torch.where(
|
||||
x > 0,
|
||||
alpha_p * x * x + self.beta * x,
|
||||
(torch.expm1(torch.min(x, self.eps)) - x) * alpha_n +
|
||||
self.beta * x,
|
||||
)
|
||||
|
||||
def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Firewall function to prevent torch.compile from seeing .item()"""
|
||||
assert self._xielu_cuda_obj is not None, (
|
||||
"XIELU CUDA object must not be None")
|
||||
original_shape = x.shape
|
||||
# CUDA kernel expects 3D tensors, reshape if needed
|
||||
while x.dim() < 3:
|
||||
x = x.unsqueeze(0)
|
||||
if x.dim() > 3:
|
||||
x = x.view(-1, 1, x.size(-1))
|
||||
if original_shape != x.shape:
|
||||
logger.warning_once(
|
||||
"Warning: xIELU input tensor expects 3 dimensions"
|
||||
" but got (shape: %s). Reshaping to (shape: %s).",
|
||||
original_shape,
|
||||
x.shape,
|
||||
)
|
||||
result = self._xielu_cuda_obj.forward(
|
||||
x,
|
||||
self.alpha_p,
|
||||
self.alpha_n,
|
||||
# Temporary until xIELU CUDA fully implemented ->
|
||||
# self.{beta,eps}.item()
|
||||
self._beta_scalar,
|
||||
self._eps_scalar,
|
||||
self.with_vector_loads,
|
||||
)
|
||||
return result.view(original_shape)
|
||||
|
||||
def forward_native(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self._xielu_cuda_obj is not None and input.is_cuda:
|
||||
if not torch._dynamo.is_compiling():
|
||||
return self._xielu_cuda_fn(input)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"torch._dynamo is compiling, using Python version of xIELU."
|
||||
)
|
||||
return self._xielu_python(input)
|
||||
|
||||
def forward_cuda(self, input: torch.Tensor) -> torch.Tensor:
|
||||
return self.forward_native(input)
|
||||
|
||||
|
||||
class ScaledActivation(nn.Module):
|
||||
"""An activation function with post-scale parameters.
|
||||
|
||||
This is used for some quantization methods like AWQ.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act_module: nn.Module,
|
||||
intermediate_size: int,
|
||||
input_is_parallel: bool = True,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.act = act_module
|
||||
self.input_is_parallel = input_is_parallel
|
||||
if input_is_parallel:
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
intermediate_size_per_partition = divide(intermediate_size,
|
||||
tp_size)
|
||||
else:
|
||||
intermediate_size_per_partition = intermediate_size
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.scales = nn.Parameter(
|
||||
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
|
||||
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.act(x) / self.scales
|
||||
|
||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
||||
param_data = param.data
|
||||
if self.input_is_parallel:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = param_data.shape[0]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.shape == loaded_weight.shape
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = LazyDict({
|
||||
"gelu":
|
||||
lambda: nn.GELU(),
|
||||
"gelu_fast":
|
||||
lambda: FastGELU(),
|
||||
"gelu_new":
|
||||
lambda: NewGELU(),
|
||||
"gelu_pytorch_tanh":
|
||||
lambda: nn.GELU(approximate="tanh"),
|
||||
"relu":
|
||||
lambda: nn.ReLU(),
|
||||
"relu2":
|
||||
lambda: ReLUSquaredActivation(),
|
||||
"silu":
|
||||
lambda: nn.SiLU(),
|
||||
"quick_gelu":
|
||||
lambda: QuickGELU(),
|
||||
"tanh":
|
||||
lambda: nn.Tanh(),
|
||||
"sigmoid":
|
||||
lambda: nn.Sigmoid(),
|
||||
"xielu":
|
||||
lambda: XIELU(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_fn(act_fn_name: str) -> nn.Module:
|
||||
"""Get an activation function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
|
||||
if act_fn_name.startswith("torch.nn.modules."):
|
||||
activation_name = act_fn_name.split(".")[-1]
|
||||
if activation_name == "identity":
|
||||
return nn.Identity()
|
||||
act_fn_name = activation_name
|
||||
|
||||
if act_fn_name not in _ACTIVATION_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
return _ACTIVATION_REGISTRY[act_fn_name]
|
||||
|
||||
|
||||
_ACTIVATION_AND_MUL_REGISTRY = LazyDict({
|
||||
"gelu":
|
||||
lambda: GeluAndMul(),
|
||||
"silu":
|
||||
lambda: SiluAndMul(),
|
||||
"geglu":
|
||||
lambda: GeluAndMul(),
|
||||
"swigluoai":
|
||||
lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
|
||||
})
|
||||
|
||||
|
||||
def get_act_and_mul_fn(act_fn_name: str) -> nn.Module:
|
||||
"""Get an activation-and-mul (i.e. SiluAndMul) function by name."""
|
||||
act_fn_name = act_fn_name.lower()
|
||||
if act_fn_name not in _ACTIVATION_AND_MUL_REGISTRY:
|
||||
raise ValueError(
|
||||
f"Activation function {act_fn_name!r} is not supported.")
|
||||
|
||||
return _ACTIVATION_AND_MUL_REGISTRY[act_fn_name]
|
||||
23
vllm/model_executor/layers/attention_layer_base.py
Normal file
23
vllm/model_executor/layers/attention_layer_base.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Base class for attention-like layers."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
class AttentionLayerBase(ABC):
|
||||
"""
|
||||
Base class for attention-like layers (Attention, Mamba, etc.)
|
||||
that support the v1 engine.
|
||||
|
||||
This provides a common interface for getting attention backends
|
||||
from different layer types.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
"""Get the attention backend class for this layer."""
|
||||
pass
|
||||
8
vllm/model_executor/layers/fla/__init__.py
Normal file
8
vllm/model_executor/layers/fla/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
Binary file not shown.
17
vllm/model_executor/layers/fla/ops/__init__.py
Normal file
17
vllm/model_executor/layers/fla/ops/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
from .chunk import chunk_gated_delta_rule
|
||||
from .fused_recurrent import fused_recurrent_gated_delta_rule
|
||||
from .layernorm_guard import RMSNormGated
|
||||
|
||||
__all__ = [
|
||||
"RMSNormGated",
|
||||
"chunk_gated_delta_rule",
|
||||
"fused_recurrent_gated_delta_rule",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
225
vllm/model_executor/layers/fla/ops/chunk.py
Normal file
225
vllm/model_executor/layers/fla/ops/chunk.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
|
||||
from .chunk_o import chunk_fwd_o
|
||||
from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
output_dtype=torch.float32)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
w=w,
|
||||
u=u,
|
||||
g=g,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_new,
|
||||
h=h,
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
elif SUPPRESS_LEVEL >= 3:
|
||||
return g, o, A, final_state, w, h, v_new
|
||||
|
||||
|
||||
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@input_guard
|
||||
@torch.amp.custom_fwd(device_type='cuda')
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
q = l2norm_fwd(q)
|
||||
k = l2norm_fwd(k)
|
||||
|
||||
g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
)
|
||||
ctx.scale = scale
|
||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||
return o.to(q.dtype), final_state
|
||||
|
||||
|
||||
@torch.compiler.disable
|
||||
def chunk_gated_delta_rule(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
head_first: bool = False,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
g (torch.Tensor):
|
||||
(forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, H, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
output_final_state (Optional[bool]):
|
||||
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
head_first (Optional[bool]):
|
||||
Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
|
||||
Default: `False`.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, K, V = 4, 2048, 4, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
|
||||
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
|
||||
>>> o, ht = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = chunk_gated_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
output_final_state=True,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
|
||||
assert len(
|
||||
beta.shape
|
||||
) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
|
||||
|
||||
if head_first:
|
||||
raise DeprecationWarning(
|
||||
"head_first is deprecated and will be removed in a future version. "
|
||||
"Please use head_first=False for now instead.",
|
||||
stacklevel=2)
|
||||
q, k, v, beta, g = map(
|
||||
lambda x: rearrange(x, 'b h t ... -> b t h ...'),
|
||||
(q, k, v, beta, g))
|
||||
if not head_first and q.shape[1] < q.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
if q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if initial_state is not None and initial_state.shape[0] != len(
|
||||
cu_seqlens) - 1:
|
||||
raise ValueError(
|
||||
f"The number of initial states is expected to be equal to the number of input sequences, "
|
||||
f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
|
||||
)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
o, final_state = ChunkGatedDeltaRuleFunction.apply(
|
||||
q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens,
|
||||
use_qk_l2norm_in_kernel)
|
||||
if head_first:
|
||||
o = rearrange(o, 'b t h ... -> b h t ...')
|
||||
return o, final_state
|
||||
290
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
Normal file
290
vllm/model_executor/layers/fla/ops/chunk_delta_h.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .op import exp
|
||||
from .utils import is_nvidia_hopper, use_cuda_graph
|
||||
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'USE_INITIAL_STATE': lambda args: args['h0'] is not None,
|
||||
'STORE_FINAL_STATE': lambda args: args['ht'] is not None,
|
||||
'SAVE_NEW_VALUE': lambda args: args['v_new'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BV': BV}, num_warps=num_warps, num_stages=num_stages)
|
||||
for num_warps in [2, 4] for num_stages in [2, 3, 4] for BV in [32, 64]
|
||||
],
|
||||
key=['H', 'K', 'V', 'BT', 'USE_G'],
|
||||
use_cuda_graph=use_cuda_graph,
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
|
||||
k,
|
||||
v,
|
||||
w,
|
||||
v_new,
|
||||
g,
|
||||
h,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
chunk_offsets,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr,
|
||||
STORE_FINAL_STATE: tl.constexpr,
|
||||
SAVE_NEW_VALUE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_nh = tl.program_id(0), tl.program_id(1)
|
||||
i_n, i_h = i_nh // H, i_nh % H
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = tl.load(chunk_offsets + i_n).to(tl.int32)
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
NT = tl.cdiv(T, BT)
|
||||
boh = i_n * NT
|
||||
|
||||
# [BK, BV]
|
||||
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 64:
|
||||
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 128:
|
||||
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
if K > 192:
|
||||
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
|
||||
|
||||
# calculate offset
|
||||
h += (boh * H + i_h) * K * V
|
||||
v += (bos * H + i_h) * V
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
w += (bos * H + i_h) * K
|
||||
if SAVE_NEW_VALUE:
|
||||
v_new += (bos * H + i_h) * V
|
||||
stride_v = H * V
|
||||
stride_h = H * K * V
|
||||
stride_k = Hg * K
|
||||
stride_w = H * K
|
||||
if USE_INITIAL_STATE:
|
||||
h0 = h0 + i_nh * K * V
|
||||
if STORE_FINAL_STATE:
|
||||
ht = ht + i_nh * K * V
|
||||
|
||||
# load initial state
|
||||
if USE_INITIAL_STATE:
|
||||
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV),
|
||||
(1, 0))
|
||||
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 64:
|
||||
p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 128:
|
||||
p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
|
||||
if K > 192:
|
||||
p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
# main recurrence
|
||||
for i_t in range(NT):
|
||||
p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(0, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(64, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h2,
|
||||
b_h2.to(p_h2.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(128, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h3,
|
||||
b_h3.to(p_h3.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1),
|
||||
(192, i_v * BV), (64, BV), (1, 0))
|
||||
tl.store(p_h4,
|
||||
b_h4.to(p_h4.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1),
|
||||
(i_t * BT, i_v * BV), (BT, BV),
|
||||
(1, 0)) if SAVE_NEW_VALUE else None
|
||||
b_v_new = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype))
|
||||
if K > 64:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype))
|
||||
if K > 128:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype))
|
||||
if K > 192:
|
||||
p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192),
|
||||
(BT, 64), (1, 0))
|
||||
b_w = tl.load(p_w, boundary_check=(0, 1))
|
||||
b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype))
|
||||
b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
if SAVE_NEW_VALUE:
|
||||
p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1),
|
||||
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
tl.store(p_v_new,
|
||||
b_v_new.to(p_v_new.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
if USE_G:
|
||||
m_t = (i_t * BT + tl.arange(0, BT)) < T
|
||||
last_idx = min((i_t + 1) * BT, T) - 1
|
||||
b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
|
||||
p_g = tl.make_block_ptr(g + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None]
|
||||
b_g_last = exp(b_g_last)
|
||||
b_h1 = b_h1 * b_g_last
|
||||
if K > 64:
|
||||
b_h2 = b_h2 * b_g_last
|
||||
if K > 128:
|
||||
b_h3 = b_h3 * b_g_last
|
||||
if K > 192:
|
||||
b_h4 = b_h4 * b_g_last
|
||||
b_v_new = b_v_new.to(k.dtype.element_ty)
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h1 += tl.dot(b_k, b_v_new)
|
||||
if K > 64:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h2 += tl.dot(b_k, b_v_new)
|
||||
if K > 128:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h3 += tl.dot(b_k, b_v_new)
|
||||
if K > 192:
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT),
|
||||
(64, BT), (0, 1))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_h4 += tl.dot(b_k, b_v_new)
|
||||
|
||||
# epilogue
|
||||
if STORE_FINAL_STATE:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV),
|
||||
(1, 0))
|
||||
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
|
||||
if K > 64:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
tl.store(p_ht,
|
||||
b_h2.to(p_ht.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 128:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
tl.store(p_ht,
|
||||
b_h3.to(p_ht.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
if K > 192:
|
||||
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV),
|
||||
(64, BV), (1, 0))
|
||||
tl.store(p_ht,
|
||||
b_h4.to(p_ht.dtype.element_ty),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_gated_delta_rule_fwd_h(
|
||||
k: torch.Tensor,
|
||||
w: torch.Tensor,
|
||||
u: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None,
|
||||
initial_state: Optional[torch.Tensor] = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, u.shape[-1]
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = len(cu_seqlens) - 1, len(
|
||||
chunk_indices), prepare_chunk_offsets(cu_seqlens, BT)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, K, V)
|
||||
final_state = k.new_empty(
|
||||
N, H, K, V, dtype=torch.float32) if output_final_state else None
|
||||
|
||||
v_new = torch.empty_like(u) if save_new_value else None
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), N * H)
|
||||
|
||||
chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
|
||||
k=k,
|
||||
v=u,
|
||||
w=w,
|
||||
v_new=v_new,
|
||||
g=g,
|
||||
h=h,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_offsets=chunk_offsets,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT)
|
||||
return h, v_new, final_state
|
||||
177
vllm/model_executor/layers/fla/ops/chunk_o.py
Normal file
177
vllm/model_executor/layers/fla/ops/chunk_o.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
|
||||
|
||||
BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
|
||||
NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_G': lambda args: args['g'] is not None,
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({
|
||||
'BK': BK,
|
||||
'BV': BV
|
||||
},
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages) for BK in BKV_LIST
|
||||
for BV in BKV_LIST for num_warps in NUM_WARPS
|
||||
for num_stages in [2, 3, 4]
|
||||
],
|
||||
key=['H', 'K', 'V', 'BT'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_fwd_kernel_o(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
|
||||
if IS_VARLEN:
|
||||
i_tg = i_t
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
NT = tl.cdiv(T, BT)
|
||||
else:
|
||||
NT = tl.cdiv(T, BT)
|
||||
i_tg = i_b * NT + i_t
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
# offset calculation
|
||||
q += (bos * Hg + i_h // (H // Hg)) * K
|
||||
k += (bos * Hg + i_h // (H // Hg)) * K
|
||||
v += (bos * H + i_h) * V
|
||||
o += (bos * H + i_h) * V
|
||||
h += (i_tg * H + i_h).to(tl.int64) * K * V
|
||||
|
||||
b_o = tl.zeros([BT, BV], dtype=tl.float32)
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK),
|
||||
(BT, BK), (1, 0))
|
||||
p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT),
|
||||
(BK, BT), (0, 1))
|
||||
p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV),
|
||||
(BK, BV), (1, 0))
|
||||
# [BT, BK]
|
||||
b_q = tl.load(p_q, boundary_check=(0, 1))
|
||||
# [BK, BT]
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
# [BK, BV]
|
||||
b_h = tl.load(p_h, boundary_check=(0, 1))
|
||||
|
||||
# [BT, BK] @ [BK, BV] -> [BT, BV]
|
||||
b_o += tl.dot(b_q, b_h)
|
||||
# [BT, BK] @ [BK, BT] -> [BT, BT]
|
||||
b_A += tl.dot(b_q, b_k)
|
||||
|
||||
if USE_G:
|
||||
g += bos * H + i_h
|
||||
p_g = tl.make_block_ptr(g, (T, ), (H, ), (i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_o = b_o * exp(b_g)[:, None]
|
||||
b_A = b_A * exp(b_g[:, None] - b_g[None, :])
|
||||
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
m_t = o_t < T
|
||||
m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
|
||||
p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV),
|
||||
(BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
|
||||
# to fix mma -> mma layout conversion
|
||||
# already solved by triton v3.2 or higher
|
||||
b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_fwd_o(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
h: torch.Tensor,
|
||||
g: Optional[torch.Tensor] = None, # cumsum of log decay
|
||||
scale: Optional[float] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
if FLA_GDN_FIX_BT:
|
||||
BT = 64
|
||||
else:
|
||||
BT = min(chunk_size, max(16, triton.next_power_of_2(T)))
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
|
||||
o = torch.empty_like(v)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(V, meta['BV']), NT, B * H)
|
||||
|
||||
chunk_fwd_kernel_o[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
h,
|
||||
g,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
scale,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
)
|
||||
return o
|
||||
140
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
Normal file
140
vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'IS_VARLEN': lambda args: args['cu_seqlens'] is not None,
|
||||
'USE_G': lambda args: args['g_cumsum'] is not None
|
||||
})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BK': BK}, num_warps=num_warps, num_stages=num_stages)
|
||||
for BK in [32, 64, 128] for num_warps in [2, 4, 8]
|
||||
for num_stages in [2, 3, 4]
|
||||
],
|
||||
key=['H', 'K', 'BT', 'IS_VARLEN'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_scaled_dot_kkt_fwd_kernel(
|
||||
k,
|
||||
beta,
|
||||
g_cumsum,
|
||||
A,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
Hg: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
USE_G: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
o_t = i_t * BT + tl.arange(0, BT)
|
||||
m_t = o_t < T
|
||||
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
|
||||
b_A = tl.zeros([BT, BT], dtype=tl.float32)
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
|
||||
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
|
||||
(1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = b_k * b_beta[:, None]
|
||||
b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k))
|
||||
|
||||
if USE_G:
|
||||
p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
b_g = tl.load(p_g, boundary_check=(0, ))
|
||||
b_g_diff = b_g[:, None] - b_g[None, :]
|
||||
b_A = b_A * exp(b_g_diff)
|
||||
|
||||
m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
|
||||
b_A = tl.where(m_A, b_A, 0)
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_scaled_dot_kkt_fwd(
|
||||
k: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: Optional[torch.Tensor] = None,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
chunk_size: int = 64,
|
||||
output_dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
r"""
|
||||
Compute beta * K * K^T.
|
||||
|
||||
Args:
|
||||
k (torch.Tensor):
|
||||
The key tensor of shape `[B, T, H, K]`.
|
||||
beta (torch.Tensor):
|
||||
The beta tensor of shape `[B, T, H]`.
|
||||
g_cumsum (torch.Tensor):
|
||||
The cumulative sum of the gate tensor of shape `[B, T, H]`.
|
||||
Default: None
|
||||
cu_seqlens (torch.LongTensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float32`
|
||||
|
||||
Returns:
|
||||
beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
|
||||
"""
|
||||
|
||||
B, T, Hg, K = k.shape
|
||||
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
beta=beta,
|
||||
g_cumsum=g_cumsum,
|
||||
A=A,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
BT=BT,
|
||||
)
|
||||
return A
|
||||
226
vllm/model_executor/layers/fla/ops/cumsum.py
Normal file
226
vllm/model_executor/layers/fla/ops/cumsum.py
Normal file
@@ -0,0 +1,226 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import check_shared_mem, input_guard
|
||||
|
||||
BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]
|
||||
],
|
||||
key=['B', 'H', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_scalar_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T, ), (1, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
p_o = tl.make_block_ptr(o + bos * H + i_h, (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
# [BT]
|
||||
b_s = tl.load(p_s, boundary_check=(0, )).to(tl.float32)
|
||||
b_o = tl.cumsum(b_s, axis=0)
|
||||
if REVERSE:
|
||||
b_z = tl.sum(b_s, axis=0)
|
||||
b_o = -b_o + b_z[None] + b_s
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, ))
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'BS': BS}, num_warps=num_warps) for BS in BS_LIST
|
||||
for num_warps in [2, 4, 8]
|
||||
],
|
||||
key=['B', 'H', 'S', 'BT', 'IS_VARLEN', 'REVERSE'])
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def chunk_local_cumsum_vector_kernel(
|
||||
s,
|
||||
o,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
S: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BS: tl.constexpr,
|
||||
REVERSE: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
HEAD_FIRST: tl.constexpr,
|
||||
):
|
||||
i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
o_i = tl.arange(0, BT)
|
||||
if REVERSE:
|
||||
m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)
|
||||
else:
|
||||
m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)
|
||||
|
||||
if HEAD_FIRST:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
else:
|
||||
p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1),
|
||||
(i_t * BT, i_s * BS), (BT, BS), (1, 0))
|
||||
# [BT, BS]
|
||||
b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_o = tl.dot(m_s, b_s, allow_tf32=False)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def chunk_local_cumsum_scalar(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T = g.shape
|
||||
else:
|
||||
B, T, H = g.shape
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B * H)
|
||||
chunk_local_cumsum_scalar_kernel[grid](g_org,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse)
|
||||
return g
|
||||
|
||||
|
||||
def chunk_local_cumsum_vector(
|
||||
g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float) -> torch.Tensor:
|
||||
if head_first:
|
||||
B, H, T, S = g.shape
|
||||
else:
|
||||
B, T, H, S = g.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, chunk_size) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
assert chunk_size == 2**(chunk_size.bit_length() -
|
||||
1), "chunk_size must be a power of 2"
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)
|
||||
|
||||
# keep cumulative normalizer in fp32
|
||||
# this kernel is equivalent to
|
||||
# g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
|
||||
chunk_local_cumsum_vector_kernel[grid](g_org,
|
||||
g,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
S=S,
|
||||
BT=BT,
|
||||
HEAD_FIRST=head_first,
|
||||
REVERSE=reverse)
|
||||
return g
|
||||
|
||||
|
||||
@input_guard
|
||||
def chunk_local_cumsum(g: torch.Tensor,
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: Optional[torch.dtype] = torch.float,
|
||||
**kwargs) -> torch.Tensor:
|
||||
if not head_first and g.shape[1] < g.shape[2]:
|
||||
warnings.warn(
|
||||
f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). "
|
||||
"This may indicate the inputs were passed in head-first format [B, H, T, ...] "
|
||||
"when head_first=False was specified. "
|
||||
"Please verify your input tensor format matches the expected shape [B, T, H, ...].",
|
||||
stacklevel=2)
|
||||
if cu_seqlens is not None:
|
||||
assert g.shape[
|
||||
0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
elif len(g.shape) == 4:
|
||||
return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens,
|
||||
head_first, output_dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported input shape {g.shape}. "
|
||||
f"which should be (B, T, H, D) if `head_first=False` "
|
||||
f"or (B, H, T, D) otherwise")
|
||||
366
vllm/model_executor/layers/fla/ops/fused_recurrent.py
Normal file
366
vllm/model_executor/layers/fla/ops/fused_recurrent.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .op import exp
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
'USE_INITIAL_STATE':
|
||||
lambda args: args['h0'] is not None,
|
||||
'IS_VARLEN':
|
||||
lambda args: args['cu_seqlens'] is not None,
|
||||
"IS_CONTINUOUS_BATCHING":
|
||||
lambda args: args['ssm_state_indices'] is not None,
|
||||
"IS_SPEC_DECODING":
|
||||
lambda args: args['num_accepted_tokens'] is not None,
|
||||
})
|
||||
@triton.jit(do_not_specialize=['N', 'T'])
|
||||
def fused_recurrent_gated_delta_rule_fwd_kernel(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
o,
|
||||
h0,
|
||||
ht,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
scale,
|
||||
N: tl.int64, # num of sequences
|
||||
T: tl.int64, # num of tokens
|
||||
B: tl.constexpr,
|
||||
H: tl.constexpr,
|
||||
HV: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
V: tl.constexpr,
|
||||
BK: tl.constexpr,
|
||||
BV: tl.constexpr,
|
||||
stride_init_state_token: tl.constexpr,
|
||||
stride_final_state_token: tl.constexpr,
|
||||
stride_indices_seq: tl.constexpr,
|
||||
stride_indices_tok: tl.constexpr,
|
||||
USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
|
||||
INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
|
||||
IS_BETA_HEADWISE: tl.
|
||||
constexpr, # whether beta is headwise vector or scalar,
|
||||
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
IS_CONTINUOUS_BATCHING: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
):
|
||||
i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
|
||||
i_n, i_hv = i_nh // HV, i_nh % HV
|
||||
i_h = i_hv // (HV // H)
|
||||
if IS_VARLEN:
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64)
|
||||
all = T
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_n * T, i_n * T + T
|
||||
all = B * T
|
||||
|
||||
if T == 0:
|
||||
# no tokens to process for this sequence
|
||||
return
|
||||
|
||||
o_k = i_k * BK + tl.arange(0, BK)
|
||||
o_v = i_v * BV + tl.arange(0, BV)
|
||||
|
||||
p_q = q + (bos * H + i_h) * K + o_k
|
||||
p_k = k + (bos * H + i_h) * K + o_k
|
||||
p_v = v + (bos * HV + i_hv) * V + o_v
|
||||
if IS_BETA_HEADWISE:
|
||||
p_beta = beta + (bos * HV + i_hv) * V + o_v
|
||||
else:
|
||||
p_beta = beta + bos * HV + i_hv
|
||||
p_g = g + bos * HV + i_hv
|
||||
p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
|
||||
|
||||
mask_k = o_k < K
|
||||
mask_v = o_v < V
|
||||
mask_h = mask_k[:, None] & mask_v[None, :]
|
||||
|
||||
b_h = tl.zeros([BK, BV], dtype=tl.float32)
|
||||
if USE_INITIAL_STATE:
|
||||
if IS_CONTINUOUS_BATCHING:
|
||||
if IS_SPEC_DECODING:
|
||||
i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
|
||||
else:
|
||||
i_t = 0
|
||||
p_h0 = h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_init_state_token
|
||||
else:
|
||||
p_h0 = h0 + bos * HV * K * V
|
||||
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
|
||||
|
||||
for i_t in range(0, T):
|
||||
b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
|
||||
b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
|
||||
b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
|
||||
b_g = tl.load(p_g).to(tl.float32)
|
||||
|
||||
if USE_QK_L2NORM_IN_KERNEL:
|
||||
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
|
||||
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
|
||||
b_q = b_q * scale
|
||||
# [BK, BV]
|
||||
b_h *= exp(b_g)
|
||||
# [BV]
|
||||
b_v -= tl.sum(b_h * b_k[:, None], 0)
|
||||
if IS_BETA_HEADWISE:
|
||||
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
|
||||
else:
|
||||
b_beta = tl.load(p_beta).to(tl.float32)
|
||||
b_v *= b_beta
|
||||
# [BK, BV]
|
||||
b_h += b_k[:, None] * b_v[None, :]
|
||||
# [BV]
|
||||
b_o = tl.sum(b_h * b_q[:, None], 0)
|
||||
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
|
||||
|
||||
# keep the states for multi-query tokens
|
||||
if INPLACE_FINAL_STATE:
|
||||
p_ht = ht + tl.load(ssm_state_indices + i_n * stride_indices_seq +
|
||||
i_t).to(tl.int64) * stride_final_state_token
|
||||
else:
|
||||
p_ht = ht + (bos + i_t) * stride_final_state_token
|
||||
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
|
||||
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
|
||||
|
||||
p_q += H * K
|
||||
p_k += H * K
|
||||
p_o += HV * V
|
||||
p_v += HV * V
|
||||
p_g += HV
|
||||
p_beta += HV * (V if IS_BETA_HEADWISE else 1)
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule_fwd(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
num_warps = 1
|
||||
|
||||
o = q.new_empty(NK, *v.shape)
|
||||
if inplace_final_state:
|
||||
final_state = initial_state
|
||||
else:
|
||||
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
|
||||
|
||||
stride_init_state_token = initial_state.stride(0)
|
||||
stride_final_state_token = final_state.stride(0)
|
||||
|
||||
if ssm_state_indices is None:
|
||||
stride_indices_seq, stride_indices_tok = 1, 1
|
||||
elif ssm_state_indices.ndim == 1:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
|
||||
else:
|
||||
stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
|
||||
|
||||
grid = (NK, NV, N * HV)
|
||||
fused_recurrent_gated_delta_rule_fwd_kernel[grid](
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
g=g,
|
||||
beta=beta,
|
||||
o=o,
|
||||
h0=initial_state,
|
||||
ht=final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
scale=scale,
|
||||
N=N,
|
||||
T=T,
|
||||
B=B,
|
||||
H=H,
|
||||
HV=HV,
|
||||
K=K,
|
||||
V=V,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
stride_init_state_token=stride_init_state_token,
|
||||
stride_final_state_token=stride_final_state_token,
|
||||
stride_indices_seq=stride_indices_seq,
|
||||
stride_indices_tok=stride_indices_tok,
|
||||
IS_BETA_HEADWISE=beta.ndim == v.ndim,
|
||||
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
|
||||
INPLACE_FINAL_STATE=inplace_final_state,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
)
|
||||
o = o.squeeze(0)
|
||||
return o, final_state
|
||||
|
||||
|
||||
class FusedRecurrentFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
scale: float,
|
||||
initial_state: torch.Tensor,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False):
|
||||
o, final_state = fused_recurrent_gated_delta_rule_fwd(
|
||||
q=q.contiguous(),
|
||||
k=k.contiguous(),
|
||||
v=v.contiguous(),
|
||||
g=g.contiguous(),
|
||||
beta=beta.contiguous(),
|
||||
scale=scale,
|
||||
initial_state=initial_state,
|
||||
inplace_final_state=inplace_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
ssm_state_indices=ssm_state_indices,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
)
|
||||
|
||||
return o, final_state
|
||||
|
||||
|
||||
def fused_recurrent_gated_delta_rule(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
g: torch.Tensor,
|
||||
beta: torch.Tensor = None,
|
||||
scale: float = None,
|
||||
initial_state: torch.Tensor = None,
|
||||
inplace_final_state: bool = True,
|
||||
cu_seqlens: Optional[torch.LongTensor] = None,
|
||||
ssm_state_indices: Optional[torch.Tensor] = None,
|
||||
num_accepted_tokens: Optional[torch.Tensor] = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
Args:
|
||||
q (torch.Tensor):
|
||||
queries of shape `[B, T, H, K]`.
|
||||
k (torch.Tensor):
|
||||
keys of shape `[B, T, H, K]`.
|
||||
v (torch.Tensor):
|
||||
values of shape `[B, T, HV, V]`.
|
||||
GVA is applied if `HV > H`.
|
||||
g (torch.Tensor):
|
||||
g (decays) of shape `[B, T, HV]`.
|
||||
beta (torch.Tensor):
|
||||
betas of shape `[B, T, HV]`.
|
||||
scale (Optional[int]):
|
||||
Scale factor for the RetNet attention scores.
|
||||
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
|
||||
initial_state (Optional[torch.Tensor]):
|
||||
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
|
||||
For equal-length input sequences, `N` equals the batch size `B`.
|
||||
Default: `None`.
|
||||
inplace_final_state: bool:
|
||||
Whether to store the final state in-place to save memory.
|
||||
Default: `True`.
|
||||
cu_seqlens (torch.LongTensor):
|
||||
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
|
||||
consistent with the FlashAttention API.
|
||||
ssm_state_indices (Optional[torch.Tensor]):
|
||||
Indices to map the input sequences to the initial/final states.
|
||||
num_accepted_tokens (Optional[torch.Tensor]):
|
||||
Number of accepted tokens for each sequence during decoding.
|
||||
|
||||
Returns:
|
||||
o (torch.Tensor):
|
||||
Outputs of shape `[B, T, HV, V]`.
|
||||
final_state (torch.Tensor):
|
||||
Final state of shape `[N, HV, K, V]`.
|
||||
|
||||
Examples::
|
||||
>>> import torch
|
||||
>>> import torch.nn.functional as F
|
||||
>>> from einops import rearrange
|
||||
>>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
|
||||
# inputs with equal lengths
|
||||
>>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
|
||||
>>> q = torch.randn(B, T, H, K, device='cuda')
|
||||
>>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
|
||||
>>> v = torch.randn(B, T, HV, V, device='cuda')
|
||||
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
|
||||
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
|
||||
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
|
||||
>>> o, ht = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
)
|
||||
# for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
|
||||
>>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
|
||||
# for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
|
||||
>>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
|
||||
>>> o_var, ht_var = fused_gated_recurrent_delta_rule(
|
||||
q, k, v, g, beta,
|
||||
initial_state=h0,
|
||||
cu_seqlens=cu_seqlens
|
||||
)
|
||||
"""
|
||||
if cu_seqlens is not None and q.shape[0] != 1:
|
||||
raise ValueError(
|
||||
f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
|
||||
f"Please flatten variable-length inputs before processing.")
|
||||
if scale is None:
|
||||
scale = k.shape[-1]**-0.5
|
||||
else:
|
||||
assert scale > 0, "scale must be positive"
|
||||
if beta is None:
|
||||
beta = torch.ones_like(q[..., 0])
|
||||
o, final_state = FusedRecurrentFunction.apply(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
g,
|
||||
beta,
|
||||
scale,
|
||||
initial_state,
|
||||
inplace_final_state,
|
||||
cu_seqlens,
|
||||
ssm_state_indices,
|
||||
num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
39
vllm/model_executor/layers/fla/ops/index.py
Normal file
39
vllm/model_executor/layers/fla/ops/index.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
from .utils import tensor_cache
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
|
||||
return cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_indices(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
indices = torch.cat([
|
||||
torch.arange(n)
|
||||
for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
|
||||
])
|
||||
return torch.stack([indices.eq(0).cumsum(0) - 1, indices],
|
||||
1).to(cu_seqlens)
|
||||
|
||||
|
||||
@tensor_cache
|
||||
def prepare_chunk_offsets(cu_seqlens: torch.LongTensor,
|
||||
chunk_size: int) -> torch.LongTensor:
|
||||
return torch.cat([
|
||||
cu_seqlens.new_tensor([0]),
|
||||
triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
|
||||
]).cumsum(-1)
|
||||
143
vllm/model_executor/layers/fla/ops/l2norm.py
Normal file
143
vllm/model_executor/layers/fla/ops/l2norm.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
BT_LIST = [8, 16, 32, 64, 128]
|
||||
|
||||
USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16, 32]
|
||||
],
|
||||
key=['D'])
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel1(
|
||||
x,
|
||||
y,
|
||||
D,
|
||||
BD: tl.constexpr,
|
||||
eps,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
x += i_t * D
|
||||
y += i_t * D
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BD)
|
||||
mask = cols < D
|
||||
b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=0)
|
||||
b_rstd = 1 / tl.sqrt(b_var + eps)
|
||||
# tl.store(Rstd + i_t, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
b_y = b_x * b_rstd
|
||||
tl.store(y + cols, b_y, mask=mask)
|
||||
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config({'BT': BT}, num_warps=num_warps)
|
||||
for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST
|
||||
],
|
||||
key=['D'])
|
||||
@triton.jit(do_not_specialize=["NB"])
|
||||
def l2norm_fwd_kernel(
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB,
|
||||
T,
|
||||
D: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
BD: tl.constexpr,
|
||||
):
|
||||
i_t = tl.program_id(0)
|
||||
p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_var = tl.sum(b_x * b_x, axis=1)
|
||||
b_y = b_x / tl.sqrt(b_var + eps)[:, None]
|
||||
p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
|
||||
tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
|
||||
xoffset = tl.program_id(0) * MBLOCK
|
||||
row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
|
||||
xmask = row_idx < M
|
||||
rindex = tl.arange(0, N)[None, :]
|
||||
xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
|
||||
square = tl.broadcast_to(xs * xs, [MBLOCK, N])
|
||||
square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
|
||||
rsqrt = tl.rsqrt(square_sum + eps)
|
||||
tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
|
||||
|
||||
|
||||
def l2norm_fwd(x: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
output_dtype: Optional[torch.dtype] = None):
|
||||
x_shape_og = x.shape
|
||||
x = x.view(-1, x.shape[-1])
|
||||
# allocate output
|
||||
if output_dtype is None:
|
||||
y = torch.empty_like(x)
|
||||
else:
|
||||
y = torch.empty_like(x, dtype=output_dtype)
|
||||
assert y.stride(-1) == 1
|
||||
T, D = x.shape[0], x.shape[-1]
|
||||
# rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
|
||||
if D > BD:
|
||||
raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
|
||||
|
||||
if not USE_DEFAULT_FLA_NORM:
|
||||
MBLOCK = 32
|
||||
# M, N = x.shape
|
||||
l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK), )](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
T,
|
||||
D,
|
||||
MBLOCK,
|
||||
)
|
||||
else:
|
||||
if D <= 512:
|
||||
NB = triton.cdiv(T, 2048)
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(T, meta['BT']), )
|
||||
|
||||
l2norm_fwd_kernel[grid](
|
||||
x,
|
||||
y,
|
||||
eps,
|
||||
NB=NB,
|
||||
T=T,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
else:
|
||||
l2norm_fwd_kernel1[(T, )](
|
||||
x,
|
||||
y,
|
||||
eps=eps,
|
||||
D=D,
|
||||
BD=BD,
|
||||
)
|
||||
|
||||
return y.view(x_shape_og)
|
||||
337
vllm/model_executor/layers/fla/ops/layernorm_guard.py
Normal file
337
vllm/model_executor/layers/fla/ops/layernorm_guard.py
Normal file
@@ -0,0 +1,337 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Tri Dao
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
|
||||
# ruff: noqa: E501
|
||||
# Based on the Triton LayerNorm tutorial: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html
|
||||
# For the backward pass, we keep weight_grad and bias_grad in registers and accumulate.
|
||||
# This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
|
||||
# The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .utils import input_guard
|
||||
|
||||
|
||||
def rms_norm_ref(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
upcast=True):
|
||||
dtype = x.dtype
|
||||
weight = weight.float()
|
||||
bias = bias.float() if bias is not None else None
|
||||
if upcast:
|
||||
x = x.float()
|
||||
z = z.float() if z is not None else z
|
||||
if z is not None and not norm_before_gate:
|
||||
x = x * F.silu(z)
|
||||
if group_size is None:
|
||||
rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
|
||||
out = (x * rstd * weight) + bias if bias is not None else (x * rstd *
|
||||
weight)
|
||||
else:
|
||||
x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
|
||||
rstd = 1 / torch.sqrt((x_group.square()).mean(dim=-1, keepdim=True) +
|
||||
eps)
|
||||
out = rearrange(x_group * rstd, "... g d -> ... (g d)") * weight
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
if z is not None and norm_before_gate:
|
||||
out *= F.silu(z)
|
||||
return out.to(dtype)
|
||||
|
||||
|
||||
@triton.heuristics({
|
||||
"HAS_BIAS": lambda args: args["B"] is not None,
|
||||
"HAS_Z": lambda args: args["Z"] is not None,
|
||||
})
|
||||
@triton.jit
|
||||
def layer_norm_fwd_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row, # how much to increase the pointer when moving by 1 row
|
||||
stride_y_row,
|
||||
stride_z_row,
|
||||
M, # number of rows in X
|
||||
N, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def layer_norm_fwd(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
eps: float,
|
||||
z: torch.Tensor = None,
|
||||
out: torch.Tensor = None,
|
||||
group_size: int = None,
|
||||
norm_before_gate: bool = True,
|
||||
is_rms_norm: bool = False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
layer_norm_fwd_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@input_guard
|
||||
@staticmethod
|
||||
def forward(ctx,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
ctx.save_for_backward(x, weight, bias, mean, rstd, z)
|
||||
ctx.x_shape_og = x_shape_og
|
||||
ctx.eps = eps
|
||||
ctx.group_size = group_size
|
||||
ctx.norm_before_gate = norm_before_gate
|
||||
ctx.is_rms_norm = is_rms_norm
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, is_rms_norm)
|
||||
|
||||
|
||||
def rmsnorm_fn(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
return LayerNormFn.apply(x, weight, bias, z, eps, group_size,
|
||||
norm_before_gate, True)
|
||||
|
||||
|
||||
class LayerNormGated(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
torch.nn.init.zeros_(self.bias)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return layernorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
|
||||
|
||||
class RMSNormGated(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
eps: float = 1e-5,
|
||||
group_size: Optional[int] = None,
|
||||
norm_before_gate: bool = False,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
"""If group_size is not None, we do GroupNorm with each group having group_size elements.
|
||||
group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
|
||||
self.register_parameter("bias", None)
|
||||
self.group_size = group_size
|
||||
self.norm_before_gate = norm_before_gate
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
torch.nn.init.ones_(self.weight)
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))
|
||||
"""
|
||||
return rmsnorm_fn(x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
z=z,
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate)
|
||||
39
vllm/model_executor/layers/fla/ops/op.py
Normal file
39
vllm/model_executor/layers/fla/ops/op.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
import os
|
||||
|
||||
from vllm.triton_utils import tl, tldevice, triton
|
||||
|
||||
if os.environ.get('FLA_USE_FAST_OPS', '0') == '1':
|
||||
div = tldevice.fast_dividef
|
||||
exp = tldevice.fast_expf
|
||||
log = tldevice.fast_logf
|
||||
log2 = tldevice.fast_log2f
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def div_normal(x, y):
|
||||
return x / y
|
||||
|
||||
div = div_normal
|
||||
exp = tl.exp
|
||||
log = tl.log
|
||||
log2 = tl.log2
|
||||
|
||||
|
||||
if not hasattr(tl, 'gather'):
|
||||
|
||||
@triton.jit
|
||||
def gather(src, index, axis, _builder=None):
|
||||
# This is a fallback implementation when tl.gather is not supported
|
||||
# In order to pass triton compiler, there is no actual gather operation
|
||||
return src
|
||||
else:
|
||||
gather = tl.gather
|
||||
365
vllm/model_executor/layers/fla/ops/solve_tril.py
Normal file
365
vllm/model_executor/layers/fla/ops/solve_tril.py
Normal file
@@ -0,0 +1,365 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .utils import input_guard
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5]
|
||||
],
|
||||
key=['BT'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def solve_tril_16x16_kernel(
|
||||
A,
|
||||
Ad,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
T,
|
||||
H: tl.constexpr,
|
||||
BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A = A + (bos * H + i_h) * BT
|
||||
Ad = Ad + (bos * H + i_h) * 16
|
||||
|
||||
offset = (i_t * 16) % BT
|
||||
p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16),
|
||||
(1, 0))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
|
||||
b_A = -tl.where(
|
||||
tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0)
|
||||
|
||||
o_i = tl.arange(0, 16)
|
||||
for i in range(1, min(16, T - i_t * 16)):
|
||||
b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
|
||||
b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
|
||||
mask = o_i == i
|
||||
b_A = tl.where(mask[:, None], b_a, b_A)
|
||||
b_A += o_i[:, None] == o_i[None, :]
|
||||
tl.store(p_Ai,
|
||||
b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4, 5]
|
||||
],
|
||||
key=['H', 'BT', 'IS_VARLEN'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def merge_16x16_to_32x32_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
|
||||
T, H: tl.constexpr, BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A += (bos * H + i_h) * 32
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 32
|
||||
|
||||
p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
|
||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'),
|
||||
Ai_11,
|
||||
input_precision='ieee')
|
||||
tl.store(p_Ai_11,
|
||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_22,
|
||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_21,
|
||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
for num_warps in [2, 4, 8] for num_stages in [2, 3, 4, 5]
|
||||
],
|
||||
key=['H', 'BT', 'IS_VARLEN'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def merge_16x16_to_64x64_inverse_kernel(A, Ad, Ai, cu_seqlens, chunk_indices,
|
||||
T, H: tl.constexpr, BT: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
|
||||
A += (bos * H + i_h) * 64
|
||||
Ad += (bos * H + i_h) * 16
|
||||
Ai += (bos * H + i_h) * 64
|
||||
|
||||
p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32),
|
||||
(16, 16), (1, 0))
|
||||
p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0),
|
||||
(16, 16), (1, 0))
|
||||
|
||||
A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
|
||||
A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32)
|
||||
Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32)
|
||||
|
||||
Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision='ieee'),
|
||||
Ai_11,
|
||||
input_precision='ieee')
|
||||
Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision='ieee'),
|
||||
Ai_22,
|
||||
input_precision='ieee')
|
||||
Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision='ieee'),
|
||||
Ai_33,
|
||||
input_precision='ieee')
|
||||
|
||||
Ai_31 = -tl.dot(Ai_33,
|
||||
tl.dot(A_31, Ai_11, input_precision='ieee') +
|
||||
tl.dot(A_32, Ai_21, input_precision='ieee'),
|
||||
input_precision='ieee')
|
||||
Ai_42 = -tl.dot(Ai_44,
|
||||
tl.dot(A_42, Ai_22, input_precision='ieee') +
|
||||
tl.dot(A_43, Ai_32, input_precision='ieee'),
|
||||
input_precision='ieee')
|
||||
Ai_41 = -tl.dot(Ai_44,
|
||||
tl.dot(A_41, Ai_11, input_precision='ieee') +
|
||||
tl.dot(A_42, Ai_21, input_precision='ieee') +
|
||||
tl.dot(A_43, Ai_31, input_precision='ieee'),
|
||||
input_precision='ieee')
|
||||
|
||||
p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32),
|
||||
(16, 16), (1, 0))
|
||||
tl.store(p_Ai_11,
|
||||
Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_22,
|
||||
Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_33,
|
||||
Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_44,
|
||||
Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_21,
|
||||
Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_31,
|
||||
Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_32,
|
||||
Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_41,
|
||||
Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_42,
|
||||
Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_43,
|
||||
Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
fill_zeros = tl.zeros((16, 16), dtype=tl.float32)
|
||||
p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48),
|
||||
(16, 16), (1, 0))
|
||||
p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48),
|
||||
(16, 16), (1, 0))
|
||||
tl.store(p_Ai_12,
|
||||
fill_zeros.to(p_Ai_12.dtype.element_ty,
|
||||
fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_13,
|
||||
fill_zeros.to(p_Ai_13.dtype.element_ty,
|
||||
fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_14,
|
||||
fill_zeros.to(p_Ai_14.dtype.element_ty,
|
||||
fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_23,
|
||||
fill_zeros.to(p_Ai_23.dtype.element_ty,
|
||||
fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_24,
|
||||
fill_zeros.to(p_Ai_24.dtype.element_ty,
|
||||
fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
tl.store(p_Ai_34,
|
||||
fill_zeros.to(p_Ai_34.dtype.element_ty,
|
||||
fp_downcast_rounding="rtne"),
|
||||
boundary_check=(0, 1))
|
||||
|
||||
|
||||
@input_guard
|
||||
def solve_tril(A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
output_dtype: torch.dtype = torch.float) -> torch.Tensor:
|
||||
"""
|
||||
Compute the inverse of the lower triangular matrix
|
||||
A should be strictly lower triangular, i.e., A.triu() == 0.
|
||||
|
||||
Args:
|
||||
A (torch.Tensor):
|
||||
[B, T, H, K]
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float`
|
||||
|
||||
Returns:
|
||||
(I + A)^-1 with the same shape as A
|
||||
"""
|
||||
assert A.shape[-1] in [16, 32, 64]
|
||||
|
||||
B, T, H, BT = A.shape
|
||||
Ad = torch.empty(B,
|
||||
T,
|
||||
H,
|
||||
16,
|
||||
device=A.device,
|
||||
dtype=torch.float if BT != 16 else output_dtype)
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, 16) if cu_seqlens is not None else None
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16)
|
||||
solve_tril_16x16_kernel[NT, B * H](
|
||||
A=A,
|
||||
Ad=Ad,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
)
|
||||
if BT == 16:
|
||||
return Ad
|
||||
|
||||
Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype)
|
||||
merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
||||
merge_fn[NT, B * H](
|
||||
A=A,
|
||||
Ad=Ad,
|
||||
Ai=Ai,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
BT=BT,
|
||||
)
|
||||
return Ai
|
||||
180
vllm/model_executor/layers/fla/ops/utils.py
Normal file
180
vllm/model_executor/layers/fla/ops/utils.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
# ruff: noqa: E501
|
||||
import contextlib
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
|
||||
FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
|
||||
FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
|
||||
|
||||
SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
|
||||
|
||||
|
||||
def tensor_cache(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator that caches the most recent results of a function with tensor inputs.
|
||||
|
||||
This decorator will store the output of the decorated function for the most recent set of input tensors.
|
||||
The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
|
||||
|
||||
Args:
|
||||
fn (Callable[..., torch.Tensor]):
|
||||
The function to be decorated. It should take tensor inputs and return tensor outputs.
|
||||
|
||||
Returns:
|
||||
Callable[..., torch.Tensor]:
|
||||
A wrapped version of the input function with single-entry caching.
|
||||
"""
|
||||
|
||||
cache_entries: tuple[Optional[tuple], Optional[dict], Any] = []
|
||||
cache_size = 4
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal cache_entries, cache_size
|
||||
for i, entry in enumerate(cache_entries):
|
||||
last_args, last_kwargs, last_result = entry
|
||||
if len(args) == len(last_args) and len(kwargs) == len(last_kwargs) \
|
||||
and all(a is b for a, b in zip(args, last_args)) \
|
||||
and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()):
|
||||
cache_entries = cache_entries[:i] + cache_entries[i + 1:] + [
|
||||
(args, kwargs, last_result)
|
||||
]
|
||||
return last_result
|
||||
|
||||
result = fn(*args, **kwargs)
|
||||
|
||||
if len(cache_entries) >= cache_size:
|
||||
cache_entries = cache_entries[1:]
|
||||
cache_entries.append((args, kwargs, result))
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def input_guard(
|
||||
fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
|
||||
"""
|
||||
A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
|
||||
"""
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
contiguous_args = (i if not isinstance(i, torch.Tensor) else
|
||||
i.contiguous() for i in args)
|
||||
contiguous_kwargs = {
|
||||
k: (v if not isinstance(v, torch.Tensor) else v.contiguous())
|
||||
for k, v in kwargs.items()
|
||||
}
|
||||
|
||||
tensor = None
|
||||
for arg in args:
|
||||
if isinstance(arg, torch.Tensor):
|
||||
tensor = arg
|
||||
break
|
||||
if tensor is None:
|
||||
for value in kwargs.values():
|
||||
if isinstance(value, torch.Tensor):
|
||||
tensor = value
|
||||
break
|
||||
|
||||
if tensor is not None:
|
||||
ctx = torch.cuda.device(tensor.device.index)
|
||||
else:
|
||||
ctx = contextlib.nullcontext()
|
||||
|
||||
with ctx:
|
||||
return fn(*contiguous_args, **contiguous_kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@functools.cache
|
||||
def get_available_device() -> str:
|
||||
try:
|
||||
return triton.runtime.driver.active.get_current_target().backend
|
||||
except BaseException:
|
||||
return 'cpu'
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _check_platform() -> Literal['nvidia', 'amd', 'intel', 'musa']:
|
||||
device = get_available_device()
|
||||
mapping = {
|
||||
"cuda": "nvidia",
|
||||
"hip": "amd",
|
||||
"xpu": "intel",
|
||||
}
|
||||
# return the mapped value, or the original if not found
|
||||
return mapping.get(device, device)
|
||||
|
||||
|
||||
# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
|
||||
# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
|
||||
# Therefore, we need to check the triton backend to determine the actual GPU vendor.
|
||||
device = get_available_device() if get_available_device() != 'hip' else 'cuda'
|
||||
device_torch_lib = getattr(torch, device)
|
||||
device_platform = _check_platform()
|
||||
|
||||
is_amd = (device_platform == 'amd')
|
||||
is_intel = (device_platform == 'intel')
|
||||
is_nvidia = (device_platform == 'nvidia')
|
||||
is_intel_alchemist = (is_intel
|
||||
and 'Intel(R) Arc(TM) A' in torch.xpu.get_device_name(0))
|
||||
is_nvidia_hopper = (is_nvidia
|
||||
and ('NVIDIA H' in torch.cuda.get_device_name(0)
|
||||
or torch.cuda.get_device_capability()[0] >= 9))
|
||||
use_cuda_graph = (is_nvidia
|
||||
and os.environ.get('FLA_USE_CUDA_GRAPH', '0') == '1')
|
||||
|
||||
|
||||
def get_all_max_shared_mem():
|
||||
try:
|
||||
return [
|
||||
triton.runtime.driver.active.utils.get_device_properties(i)
|
||||
['max_shared_mem'] for i in range(device_torch_lib.device_count())
|
||||
]
|
||||
except BaseException:
|
||||
return [-1]
|
||||
|
||||
|
||||
class Backend(Enum):
|
||||
ADA = 101376 # RTX 4090
|
||||
AMPERE = 166912 # A100
|
||||
HOPPER = 232448 # H100
|
||||
DEFAULT = 102400 # Default
|
||||
|
||||
@classmethod
|
||||
def get_shared_memory(cls, arch: str) -> int:
|
||||
try:
|
||||
return cls[arch.upper()].value
|
||||
except KeyError:
|
||||
return cls.DEFAULT.value
|
||||
|
||||
|
||||
@functools.cache
|
||||
def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
|
||||
try:
|
||||
device_shared_mem_list = get_all_max_shared_mem()
|
||||
max_shared_memory = device_shared_mem_list[tensor_idx]
|
||||
return max_shared_memory >= Backend.get_shared_memory(arch)
|
||||
except Exception:
|
||||
return False
|
||||
114
vllm/model_executor/layers/fla/ops/wy_fast.py
Normal file
114
vllm/model_executor/layers/fla/ops/wy_fast.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
|
||||
#
|
||||
# This file contains code copied from the flash-linear-attention project.
|
||||
# The original source code was licensed under the MIT license and included
|
||||
# the following copyright notice:
|
||||
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
|
||||
|
||||
# ruff: noqa: E501
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
|
||||
|
||||
@triton.heuristics({'IS_VARLEN': lambda args: args['cu_seqlens'] is not None})
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({}, num_warps=num_warps, num_stages=num_stages)
|
||||
for num_warps in [2, 4, 8] for num_stages in [2, 3, 4]
|
||||
],
|
||||
key=['H', 'K', 'V', 'BT', 'BK', 'BV', 'IS_VARLEN'],
|
||||
)
|
||||
@triton.jit(do_not_specialize=['T'])
|
||||
def recompute_w_u_fwd_kernel(k, v, beta, w, u, A, g, cu_seqlens, chunk_indices,
|
||||
T, H: tl.constexpr, Hg: tl.constexpr,
|
||||
K: tl.constexpr, V: tl.constexpr,
|
||||
BT: tl.constexpr, BK: tl.constexpr,
|
||||
BV: tl.constexpr, IS_VARLEN: tl.constexpr):
|
||||
i_t, i_bh = tl.program_id(0), tl.program_id(1)
|
||||
i_b, i_h = i_bh // H, i_bh % H
|
||||
if IS_VARLEN:
|
||||
i_n, i_t = tl.load(chunk_indices + i_t * 2).to(
|
||||
tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
|
||||
bos, eos = tl.load(cu_seqlens + i_n).to(
|
||||
tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
|
||||
T = eos - bos
|
||||
else:
|
||||
bos, eos = i_b * T, i_b * T + T
|
||||
p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T, ), (H, ),
|
||||
(i_t * BT, ), (BT, ), (0, ))
|
||||
p_g = tl.make_block_ptr(g + (bos * H + i_h), (T, ), (H, ), (i_t * BT, ),
|
||||
(BT, ), (0, ))
|
||||
p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1),
|
||||
(i_t * BT, 0), (BT, BT), (1, 0))
|
||||
b_beta = tl.load(p_beta, boundary_check=(0, ))
|
||||
b_A = tl.load(p_A, boundary_check=(0, 1))
|
||||
b_g = tl.exp(tl.load(p_g, boundary_check=(0, )))
|
||||
|
||||
for i_v in range(tl.cdiv(V, BV)):
|
||||
p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1),
|
||||
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1),
|
||||
(i_t * BT, i_v * BV), (BT, BV), (1, 0))
|
||||
b_v = tl.load(p_v, boundary_check=(0, 1))
|
||||
b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
|
||||
b_u = tl.dot(b_A, b_vb, allow_tf32=False)
|
||||
tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
for i_k in range(tl.cdiv(K, BK)):
|
||||
p_k = tl.make_block_ptr(k + (bos * Hg + i_h // (H // Hg)) * K, (T, K),
|
||||
(Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK),
|
||||
(1, 0))
|
||||
p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1),
|
||||
(i_t * BT, i_k * BK), (BT, BK), (1, 0))
|
||||
b_k = tl.load(p_k, boundary_check=(0, 1))
|
||||
b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
|
||||
b_w = tl.dot(b_A, b_kb)
|
||||
tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
|
||||
|
||||
|
||||
def recompute_w_u_fwd(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
beta: torch.Tensor,
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: Optional[torch.LongTensor],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = prepare_chunk_indices(
|
||||
cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
u = torch.empty_like(v)
|
||||
w = k.new_empty(B, T, H, K)
|
||||
recompute_w_u_fwd_kernel[(NT, B * H)](
|
||||
k=k,
|
||||
v=v,
|
||||
beta=beta,
|
||||
w=w,
|
||||
u=u,
|
||||
A=A,
|
||||
g=g_cumsum,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
T=T,
|
||||
H=H,
|
||||
Hg=Hg,
|
||||
K=K,
|
||||
V=V,
|
||||
BT=BT,
|
||||
BK=BK,
|
||||
BV=BV,
|
||||
)
|
||||
return w, u
|
||||
89
vllm/model_executor/layers/fused_moe/__init__.py
Normal file
89
vllm/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Optional
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.fused_moe.modular_kernel import (
|
||||
FusedMoEActivationFormat, FusedMoEPermuteExpertsUnpermute,
|
||||
FusedMoEPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.utils import activation_without_mul
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
_config: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_config(config):
|
||||
global _config
|
||||
old_config = _config
|
||||
_config = config
|
||||
yield
|
||||
_config = old_config
|
||||
|
||||
|
||||
def get_config() -> Optional[dict[str, Any]]:
|
||||
return _config
|
||||
|
||||
|
||||
__all__ = [
|
||||
"FusedMoE",
|
||||
"FusedMoEConfig",
|
||||
"FusedMoEMethodBase",
|
||||
"FusedMoeWeightScaleSupported",
|
||||
"FusedMoEPermuteExpertsUnpermute",
|
||||
"FusedMoEActivationFormat",
|
||||
"FusedMoEPrepareAndFinalize",
|
||||
"activation_without_mul",
|
||||
"override_config",
|
||||
"get_config",
|
||||
]
|
||||
|
||||
if HAS_TRITON:
|
||||
# import to register the custom ops
|
||||
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassBatchedExpertsFp8, CutlassExpertsFp8, cutlass_moe_fp4,
|
||||
cutlass_moe_fp8)
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
|
||||
DeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import (
|
||||
TritonExperts, fused_experts, fused_topk, get_config_file_name,
|
||||
grouped_topk)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
|
||||
__all__ += [
|
||||
"fused_topk",
|
||||
"fused_experts",
|
||||
"get_config_file_name",
|
||||
"grouped_topk",
|
||||
"cutlass_moe_fp8",
|
||||
"cutlass_moe_fp4",
|
||||
"CutlassExpertsFp8",
|
||||
"CutlassBatchedExpertsFp8",
|
||||
"TritonExperts",
|
||||
"BatchedTritonExperts",
|
||||
"DeepGemmExperts",
|
||||
"BatchedDeepGemmExperts",
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
else:
|
||||
# Some model classes directly use the custom ops. Add placeholders
|
||||
# to avoid import errors.
|
||||
def _raise_exception(method: str):
|
||||
raise NotImplementedError(
|
||||
f"{method} is not implemented as lack of triton.")
|
||||
|
||||
fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
|
||||
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
322
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
Normal file
322
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
Normal file
@@ -0,0 +1,322 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from math import log2
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
deep_gemm_block_shape)
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
|
||||
is_deep_gemm_e8m0_used)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _silu_mul_fp8_quant_deep_gemm(
|
||||
# Pointers ------------------------------------------------------------
|
||||
input_ptr, # 16-bit activations (E, T, 2*H)
|
||||
y_q_ptr, # fp8 quantized activations (E, T, H)
|
||||
y_s_ptr, # 16-bit scales (E, T, G)
|
||||
counts_ptr, # int32 num tokens per expert (E)
|
||||
# Sizes ---------------------------------------------------------------
|
||||
H: tl.constexpr, # hidden dimension (per output)
|
||||
GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
|
||||
# Strides for input (elements) ---------------------------------------
|
||||
stride_i_e,
|
||||
stride_i_t,
|
||||
stride_i_h,
|
||||
# Strides for y_q (elements) -----------------------------------------
|
||||
stride_yq_e,
|
||||
stride_yq_t,
|
||||
stride_yq_h,
|
||||
# Strides for y_s (elements) -----------------------------------------
|
||||
stride_ys_e,
|
||||
stride_ys_t,
|
||||
stride_ys_g,
|
||||
# Stride for counts (elements)
|
||||
stride_counts_e,
|
||||
# Numeric params ------------------------------------------------------
|
||||
eps: tl.constexpr,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
use_ue8m0: tl.constexpr,
|
||||
# Meta ---------------------------------------------------------------
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_STAGES: tl.constexpr,
|
||||
):
|
||||
G = H // GROUP_SIZE
|
||||
|
||||
# map program id -> (e, g)
|
||||
pid = tl.program_id(0)
|
||||
e = pid // G
|
||||
g = pid % G
|
||||
|
||||
e = e.to(tl.int64)
|
||||
g = g.to(tl.int64)
|
||||
|
||||
# number of valid tokens for this expert
|
||||
n_tokens = tl.load(counts_ptr + e * stride_counts_e).to(tl.int64)
|
||||
|
||||
cols = tl.arange(0, BLOCK).to(tl.int64)
|
||||
mask = cols < BLOCK
|
||||
|
||||
base_input_offset = e * stride_i_e + g * GROUP_SIZE * stride_i_h
|
||||
base_gate_offset = base_input_offset + cols * stride_i_h
|
||||
base_up_offset = base_input_offset + H * stride_i_h + cols * stride_i_h
|
||||
base_yq_offset = (e * stride_yq_e + g * GROUP_SIZE * stride_yq_h +
|
||||
cols * stride_yq_h)
|
||||
base_ys_offset = e * stride_ys_e + g * stride_ys_g
|
||||
|
||||
for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
|
||||
gate = tl.load(input_ptr + base_gate_offset + t * stride_i_t,
|
||||
mask=mask,
|
||||
other=0.0).to(tl.float32)
|
||||
up = tl.load(input_ptr + base_up_offset + t * stride_i_t,
|
||||
mask=mask,
|
||||
other=0.0)
|
||||
|
||||
gate = gate * (1.0 / (1.0 + tl.exp(-gate)))
|
||||
y = gate * up
|
||||
|
||||
y_s = tl.maximum(tl.max(tl.abs(y)), eps) / fp8_max
|
||||
if use_ue8m0:
|
||||
y_s = tl.exp2(tl.ceil(tl.log2(y_s)))
|
||||
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + base_yq_offset + t * stride_yq_t, y_q, mask=mask)
|
||||
tl.store(y_s_ptr + base_ys_offset + t * stride_ys_t, y_s)
|
||||
|
||||
|
||||
def silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
y: torch.Tensor, # (E, T, 2*H)
|
||||
tokens_per_expert: torch.Tensor, # (E,) number of valid tokens per expert
|
||||
num_parallel_tokens=16,
|
||||
group_size: int = 128,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Quantize silu(y[..., :H]) * y[..., H:] to FP8 with group per-token scales
|
||||
y has shape (E, T, 2*H). The first half of the last dimension is
|
||||
silu-activated, multiplied by the second half, then quantized into FP8.
|
||||
Returns `(y_q, y_s)` where
|
||||
* `y_q`: FP8 tensor, shape (E, T, H), same layout as y[..., :H]
|
||||
* `y_s`: FP32 tensor, shape (E, T, H // group_size), strides (T*G, 1, T)
|
||||
"""
|
||||
assert y.ndim == 3, "y must be (E, T, 2*H)"
|
||||
E, T, H2 = y.shape
|
||||
assert H2 % 2 == 0, "last dim of y must be even (2*H)"
|
||||
H = H2 // 2
|
||||
G = (H + group_size - 1) // group_size
|
||||
assert H % 8 == 0, "H must be divisible by 8"
|
||||
assert group_size == 128, "H must be divisible by 8"
|
||||
assert tokens_per_expert.ndim == 1 and tokens_per_expert.shape[0] == E
|
||||
|
||||
tokens_per_expert = tokens_per_expert.to(device=y.device,
|
||||
dtype=torch.int32)
|
||||
|
||||
fp8_dtype = torch.float8_e4m3fn
|
||||
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||
|
||||
stride_ys_e = T * G
|
||||
stride_ys_t = 1
|
||||
stride_ys_g = T
|
||||
y_s = torch.empty_strided((E, T, G),
|
||||
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||
dtype=torch.float32,
|
||||
device=y.device)
|
||||
|
||||
use_ue8m0 = is_deep_gemm_e8m0_used()
|
||||
|
||||
if E <= 16:
|
||||
max_empirical_parallelism = 64
|
||||
elif E <= 32:
|
||||
max_empirical_parallelism = 16
|
||||
else:
|
||||
max_empirical_parallelism = 4
|
||||
|
||||
# We never want to launch more than Tx number of threads
|
||||
# This computes the clip.
|
||||
num_parallel_tokens = max(
|
||||
1,
|
||||
min(max_empirical_parallelism, 2**int(log2(min(num_parallel_tokens,
|
||||
T)))))
|
||||
cuda_arch = current_platform.get_device_capability(
|
||||
device_id=y.device.index).to_int()
|
||||
|
||||
if cuda_arch >= 80:
|
||||
torch.ops._C.silu_mul_fp8_quant_deep_gemm_cuda(y, tokens_per_expert,
|
||||
y_q, y_s, group_size,
|
||||
use_ue8m0,
|
||||
num_parallel_tokens)
|
||||
else:
|
||||
# Default to triton if not on cuda or if arch is too old
|
||||
y_q = torch.empty((E, T, H), dtype=fp8_dtype, device=y.device)
|
||||
|
||||
stride_cnt_e = tokens_per_expert.stride()[0]
|
||||
|
||||
# Static grid over experts and H-groups.
|
||||
# A loop inside the kernel handles the token dim
|
||||
grid = (E * G, )
|
||||
# strides (elements)
|
||||
stride_i_e, stride_i_t, stride_i_h = y.stride()
|
||||
stride_yq_e, stride_yq_t, stride_yq_h = y_q.stride()
|
||||
|
||||
# desired scale strides (elements): (T*G, 1, T)
|
||||
stride_ys_e = T * G
|
||||
stride_ys_t = 1
|
||||
stride_ys_g = T
|
||||
y_s = torch.empty_strided(
|
||||
(E, T, G),
|
||||
(stride_ys_e, stride_ys_t, stride_ys_g),
|
||||
dtype=torch.float32,
|
||||
device=y.device,
|
||||
)
|
||||
f_info = torch.finfo(fp8_dtype)
|
||||
fp8_max = f_info.max
|
||||
fp8_min = f_info.min
|
||||
eps: float = 1e-10
|
||||
_silu_mul_fp8_quant_deep_gemm[grid](
|
||||
y,
|
||||
y_q,
|
||||
y_s,
|
||||
tokens_per_expert,
|
||||
H,
|
||||
group_size,
|
||||
stride_i_e,
|
||||
stride_i_t,
|
||||
stride_i_h,
|
||||
stride_yq_e,
|
||||
stride_yq_t,
|
||||
stride_yq_h,
|
||||
stride_ys_e,
|
||||
stride_ys_t,
|
||||
stride_ys_g,
|
||||
stride_cnt_e,
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
is_deep_gemm_e8m0_used(),
|
||||
BLOCK=group_size,
|
||||
NUM_STAGES=4,
|
||||
num_warps=1,
|
||||
)
|
||||
|
||||
return y_q, y_s
|
||||
|
||||
|
||||
class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
):
|
||||
"""
|
||||
max_num_tokens: Maximum number of tokens from a DP Rank
|
||||
num_dispatchers: The number of DP dispatchers.
|
||||
quant_config: Quantization configuration
|
||||
"""
|
||||
super().__init__(quant_config)
|
||||
assert self.block_shape == deep_gemm_block_shape()
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.num_dispatchers = num_dispatchers
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
return (mk.FusedMoEActivationFormat.BatchedExperts,
|
||||
mk.FusedMoEActivationFormat.BatchedExperts)
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
return False
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
# Let PrepareAndFinalize::finalize() decide the impl.
|
||||
return TopKWeightAndReduceDelegate()
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
assert a.dim() == 2
|
||||
# FIXME (varun): We should be able to dispatch only from the leader
|
||||
# DP ranks in the case of TP > 1. At the moment, all the Ranks
|
||||
# end up sending their tokens. This needs to be fixed.
|
||||
num_dispatchers = self.num_dispatchers
|
||||
num_experts = local_num_experts
|
||||
max_num_tokens = a.size(
|
||||
0) if self.max_num_tokens is None else self.max_num_tokens
|
||||
workspace13 = (num_experts, max_num_tokens * num_dispatchers,
|
||||
max(K, N))
|
||||
workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2))
|
||||
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
||||
return (workspace13, workspace2, output, a.dtype)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert expert_tokens_meta is not None
|
||||
expert_num_tokens = expert_tokens_meta.expert_num_tokens
|
||||
|
||||
assert hidden_states.ndim == 3
|
||||
assert self.block_shape is not None
|
||||
|
||||
a1q = hidden_states
|
||||
_, N, K = w1.size()
|
||||
|
||||
assert w2.size(1) == K
|
||||
|
||||
E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size(
|
||||
hidden_states, w1, w2, topk_ids)
|
||||
|
||||
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||
|
||||
# (from deepgemm docs) : A value hint (which is a value on CPU)
|
||||
# for the M expectation of each batch, correctly setting this value
|
||||
# may lead to better performance.
|
||||
expected_m = max_num_tokens
|
||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, self.w1_scale),
|
||||
workspace1, expert_num_tokens, expected_m)
|
||||
|
||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm_cuda(
|
||||
workspace1, expert_num_tokens)
|
||||
|
||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, self.w2_scale),
|
||||
output, expert_num_tokens, expected_m)
|
||||
@@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
|
||||
deep_gemm_block_shape)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts)
|
||||
|
||||
|
||||
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
|
||||
self.batched_triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = (allow_deep_gemm
|
||||
and self.quant_config.use_fp8_w8a8 and
|
||||
self.block_shape == deep_gemm_block_shape())
|
||||
|
||||
self.batched_deep_gemm_experts = BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
) if self.allow_deep_gemm else None
|
||||
|
||||
assert (self.batched_deep_gemm_experts is not None
|
||||
or self.batched_triton_experts is not None)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.batched_triton_experts is not None:
|
||||
assert (self.batched_deep_gemm_experts is None
|
||||
or self.batched_deep_gemm_experts.activation_formats
|
||||
== self.batched_triton_experts.activation_formats)
|
||||
return self.batched_triton_experts.activation_formats
|
||||
else:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return ((bdge is None or bdge.supports_chunking())
|
||||
and (bte is None or bte.supports_chunking()))
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return ((bdge is None or bdge.supports_expert_map())
|
||||
and (bte is None or bte.supports_expert_map()))
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
|
||||
bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
|
||||
is_bdge_war = bdge_war is not None
|
||||
is_bte_war = bte_war is not None
|
||||
|
||||
if is_bdge_war and is_bte_war:
|
||||
assert bdge_war == bte_war, (
|
||||
"Both implementations should agree on WeightAndReduce impls. "
|
||||
f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}")
|
||||
|
||||
if bdge_war is not None:
|
||||
return bdge_war
|
||||
|
||||
assert bte_war is not None
|
||||
return bte_war
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
a: torch.Tensor,
|
||||
aq: torch.Tensor,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_metadata: Optional[mk.ExpertTokensMetadata],
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
|
||||
expert_tokens_metadata)
|
||||
else:
|
||||
assert self.batched_triton_experts is not None
|
||||
return self.batched_triton_experts.workspace_shapes(
|
||||
a, aq, M, N, K, topk, global_num_experts, local_num_experts,
|
||||
expert_tokens_metadata)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
experts = (self.batched_deep_gemm_experts
|
||||
if self.allow_deep_gemm else self.batched_triton_experts)
|
||||
assert experts is not None
|
||||
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
activation, global_num_experts, expert_map, a1q_scale,
|
||||
a2_scale, workspace13, workspace2, expert_tokens_meta,
|
||||
apply_router_weight_on_input)
|
||||
804
vllm/model_executor/layers/fused_moe/config.py
Normal file
804
vllm/model_executor/layers/fused_moe/config.py
Normal file
@@ -0,0 +1,804 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
GroupShape)
|
||||
from vllm.utils import cdiv, has_triton_kernels
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if has_triton_kernels():
|
||||
try:
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"Failed to import Triton kernels. Please make sure your triton "
|
||||
"version is compatible.")
|
||||
|
||||
|
||||
def _get_config_dtype_str(
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a16: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
use_mxfp4_w4a4: bool = False,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
tuning info for a particular quantization scheme. See
|
||||
try_get_optimal_moe_config in fused_moe.py.
|
||||
"""
|
||||
if use_fp8_w8a8:
|
||||
return "fp8_w8a8"
|
||||
elif use_int8_w8a16:
|
||||
return "int8_w8a16"
|
||||
elif use_int4_w4a16:
|
||||
return "int4_w4a16"
|
||||
elif use_mxfp4_w4a4:
|
||||
return "mxfp4_w4a4"
|
||||
elif dtype == torch.float:
|
||||
# avoiding cases where kernel fails when float32 MoE
|
||||
# use fp16/bfloat16 configs
|
||||
return "float32"
|
||||
return None
|
||||
|
||||
|
||||
def _quant_flags_to_group_shape(
|
||||
quant_dtype: Union[torch.dtype, str, None],
|
||||
per_act_token_quant: bool,
|
||||
per_out_ch_quant: bool,
|
||||
block_shape: Optional[list[int]],
|
||||
) -> tuple[Optional[GroupShape], Optional[GroupShape]]:
|
||||
"""
|
||||
Convert MoE quantization flags into more generic GroupShapes.
|
||||
"""
|
||||
a_shape: Optional[GroupShape]
|
||||
w_shape: Optional[GroupShape]
|
||||
if block_shape is not None:
|
||||
assert not per_act_token_quant
|
||||
assert not per_out_ch_quant
|
||||
# TODO(bnell): this is not quite right for activations since first
|
||||
# dim should be 1.
|
||||
a_shape = GroupShape(row=block_shape[0], col=block_shape[1])
|
||||
w_shape = GroupShape(row=block_shape[0], col=block_shape[1])
|
||||
else:
|
||||
w_shape = None
|
||||
a_shape = None if quant_dtype is None else GroupShape.PER_TENSOR
|
||||
|
||||
if per_act_token_quant:
|
||||
a_shape = GroupShape.PER_TOKEN
|
||||
|
||||
if per_out_ch_quant:
|
||||
w_shape = GroupShape.PER_TOKEN
|
||||
|
||||
return a_shape, w_shape
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEQuantDesc:
|
||||
"""
|
||||
A quantization descriptor for fused MoE ops. This class can describe
|
||||
either activations or weights.
|
||||
"""
|
||||
|
||||
# The quantized type of this parameters. None means unquantized or
|
||||
# already quantized.
|
||||
# TODO (bnell): use scalar_type instead of Union.
|
||||
dtype: Union[torch.dtype, str, None] = None
|
||||
|
||||
# A field that describes the quantization group shape, from quant_utils.py.
|
||||
# * (-1, -1) for per-tensor quantization
|
||||
# * (1, -1) for per-row quantization
|
||||
# * (-1, 1) for per-column quantization
|
||||
# * (128, 128) for 128x128 deepseek style block quantization
|
||||
# * (1, 128) for deepseek style activation quantization
|
||||
# (i.e. per-token-per-group)
|
||||
shape: Optional[GroupShape] = None
|
||||
|
||||
# Quantization scales.
|
||||
# TODO(bnell): maybe put PrecisionConfigs in subclass of QuantDesc?
|
||||
scale: Union[torch.Tensor, "PrecisionConfig", None] = None
|
||||
|
||||
# Quantization alphas or gscales, used for nvfp4 types.
|
||||
# TODO(bnell): put some of these in subclasses
|
||||
alpha_or_gscale: Optional[torch.Tensor] = None
|
||||
|
||||
# Zero points for int4/int8 types
|
||||
zp: Optional[torch.Tensor] = None
|
||||
|
||||
# Biases for GPT triton MoE
|
||||
bias: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
# TODO(bnell): have subclasses for specific moe methods?
|
||||
# e.g. for specific arguments bias, precision, etc.
|
||||
@dataclass
|
||||
class FusedMoEQuantConfig:
|
||||
"""
|
||||
The FusedMoEQuantConfig contains all the quantization parameters for
|
||||
a single FusedMoEMethodBase operation. It consists of four
|
||||
FusedMoEQuantDescs, one for each activation and set of weights.
|
||||
|
||||
Each FusedMoEMethodBase must implement a get_fused_moe_quant_config
|
||||
method to construct a FusedMoEQuantConfig for use with that class.
|
||||
|
||||
FusedMoEQuant configs are only used for modular kernels, fused_experts
|
||||
(from fused_moe.py), cutlass_moe_fp[48], rocm_aiter_fused_experts and
|
||||
triton_kernel_moe_forward. Other MoE methods can ignore the
|
||||
FusedMoEQuantConfig (for now) and hardcode it to None.
|
||||
|
||||
There are currently some restrictions on what can be expressed:
|
||||
- Most MoE ops only support similar quantization strategies for
|
||||
each parameter, e.g. both weights must have the same GroupShape
|
||||
and both activations must share the same GroupShape. One exception to
|
||||
this is the cutlass moe which allows per channel quantization on the
|
||||
outputs. Note: this restrictions are not always rigorously checked.
|
||||
- Not all fused MoE functions support all the parameters, e.g. zero points,
|
||||
global scales, alphas and biases are not universally supported.
|
||||
- Fully general GroupShapes are not allowed. Activations only support
|
||||
per token, per tensor or K-blocked.
|
||||
- Weights are not required to have a GroupShape since they have already
|
||||
been quantized.
|
||||
|
||||
Other notes:
|
||||
- PrecisionConfigs are specific to GPT OSS Triton.
|
||||
- As a follow up it would probably make sense to subclass FusedMoEQuantDesc
|
||||
or FusedMoEQuantConfig for particular FusedMoEMethodBase subclasses
|
||||
so that only the required quantization parameters are used/stored.
|
||||
"""
|
||||
|
||||
# TODO(bnell) make sure a1_scales/a2_scales don't interfere with chunking
|
||||
_a1: FusedMoEQuantDesc
|
||||
_a2: FusedMoEQuantDesc
|
||||
_w1: FusedMoEQuantDesc
|
||||
_w2: FusedMoEQuantDesc
|
||||
|
||||
def __post_init__(self):
|
||||
assert (not self.per_act_token_quant
|
||||
or self.block_shape is None), "illegal quantization"
|
||||
|
||||
#
|
||||
# Convenience accessors for various properties.
|
||||
#
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Union[torch.dtype, str, None]:
|
||||
return self._a1.dtype
|
||||
|
||||
@property
|
||||
def is_quantized(self) -> bool:
|
||||
return self.quant_dtype is not None
|
||||
|
||||
@property
|
||||
def is_per_act_token(self) -> bool:
|
||||
return self._a1.shape == GroupShape.PER_TOKEN
|
||||
|
||||
@property
|
||||
def per_act_token_quant(self) -> bool:
|
||||
return self._a1.shape == GroupShape.PER_TOKEN
|
||||
|
||||
@property
|
||||
def per_out_ch_quant(self) -> bool:
|
||||
return self._w1.shape == GroupShape.PER_TOKEN
|
||||
|
||||
@property
|
||||
def is_per_tensor(self) -> bool:
|
||||
return self._a1.shape == GroupShape.PER_TENSOR
|
||||
|
||||
@property
|
||||
def block_shape(self) -> Optional[list[int]]:
|
||||
if (self._a1.shape is not None
|
||||
and self._a1.shape != GroupShape.PER_TENSOR
|
||||
and self._a1.shape != GroupShape.PER_TOKEN):
|
||||
return [self._a1.shape.row, self._a1.shape.col]
|
||||
else:
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_block_quantized(self) -> bool:
|
||||
return self.block_shape is not None
|
||||
|
||||
@property
|
||||
def a1_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._a1.scale is None or isinstance(self._a1.scale,
|
||||
torch.Tensor)
|
||||
return self._a1.scale
|
||||
|
||||
@property
|
||||
def a1_gscale(self) -> Optional[torch.Tensor]:
|
||||
return self._a1.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def a2_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._a2.scale is None or isinstance(self._a2.scale,
|
||||
torch.Tensor)
|
||||
return self._a2.scale
|
||||
|
||||
@property
|
||||
def a2_gscale(self) -> Optional[torch.Tensor]:
|
||||
return self._a2.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def w1_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._w1.scale is None or isinstance(self._w1.scale,
|
||||
torch.Tensor)
|
||||
return self._w1.scale
|
||||
|
||||
@property
|
||||
def w1_zp(self) -> Optional[torch.Tensor]:
|
||||
return self._w1.zp
|
||||
|
||||
@property
|
||||
def w1_bias(self) -> Optional[torch.Tensor]:
|
||||
return self._w1.bias
|
||||
|
||||
@property
|
||||
def w1_precision(self) -> Optional["PrecisionConfig"]:
|
||||
assert self._w1.scale is None or isinstance(self._w1.scale,
|
||||
PrecisionConfig)
|
||||
return self._w1.scale
|
||||
|
||||
@property
|
||||
def g1_alphas(self) -> Optional[torch.Tensor]:
|
||||
return self._w1.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def w2_scale(self) -> Optional[torch.Tensor]:
|
||||
assert self._w2.scale is None or isinstance(self._w2.scale,
|
||||
torch.Tensor)
|
||||
return self._w2.scale
|
||||
|
||||
@property
|
||||
def w2_zp(self) -> Optional[torch.Tensor]:
|
||||
return self._w2.zp
|
||||
|
||||
@property
|
||||
def w2_bias(self) -> Optional[torch.Tensor]:
|
||||
return self._w2.bias
|
||||
|
||||
@property
|
||||
def w2_precision(self) -> Optional["PrecisionConfig"]:
|
||||
assert self._w2.scale is None or isinstance(self._w2.scale,
|
||||
PrecisionConfig)
|
||||
return self._w2.scale
|
||||
|
||||
@property
|
||||
def g2_alphas(self) -> Optional[torch.Tensor]:
|
||||
return self._w2.alpha_or_gscale
|
||||
|
||||
@property
|
||||
def use_fp8_w8a8(self) -> bool:
|
||||
return self.quant_dtype == torch.float8_e4m3fn
|
||||
|
||||
@property
|
||||
def use_int8_w8a8(self) -> bool:
|
||||
return self.quant_dtype == torch.int8
|
||||
|
||||
@property
|
||||
def use_int8_w8a16(self) -> bool:
|
||||
return (self._a1.dtype is None and self._w1.dtype == torch.int8)
|
||||
|
||||
@property
|
||||
def use_int4_w4a16(self) -> bool:
|
||||
return (self._a1.dtype is None and self._w1.dtype == "int4")
|
||||
|
||||
@property
|
||||
def use_mxfp4_w4a4(self) -> bool:
|
||||
return (self._a1.dtype == "mxfp4" and self._w1.dtype == "mxfp4")
|
||||
|
||||
@property
|
||||
def use_mxfp4_w4a16(self) -> bool:
|
||||
return (self._a1.dtype is None and self._w1.dtype == "mxfp4")
|
||||
|
||||
@property
|
||||
def use_nvfp4_w4a4(self) -> bool:
|
||||
return self.quant_dtype == "nvfp4"
|
||||
|
||||
def config_name(self, dtype: torch.dtype) -> Optional[str]:
|
||||
"""
|
||||
Return a string used to construct the filename that contains the
|
||||
tuning info for a particular quantization scheme. See
|
||||
try_get_optimal_moe_config in fused_moe.py.
|
||||
"""
|
||||
return _get_config_dtype_str(
|
||||
use_fp8_w8a8=self.use_fp8_w8a8,
|
||||
use_int8_w8a16=self.use_int8_w8a16,
|
||||
use_int4_w4a16=self.use_int4_w4a16,
|
||||
use_mxfp4_w4a4=self.use_mxfp4_w4a4,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def scale_shape(
|
||||
self,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int]]:
|
||||
"""
|
||||
Construct the proper activation scale shape for this
|
||||
config.
|
||||
"""
|
||||
if self.is_quantized:
|
||||
if self.is_block_quantized:
|
||||
assert self.block_shape is not None
|
||||
_, block_k = self.block_shape
|
||||
k_tiles = cdiv(hidden_dim, block_k)
|
||||
return (max_tokens, k_tiles)
|
||||
elif self.is_per_act_token:
|
||||
return (max_tokens, 1)
|
||||
else:
|
||||
return (1, 1)
|
||||
else:
|
||||
return None
|
||||
|
||||
def batched_scale_shape(
|
||||
self,
|
||||
num_experts: int,
|
||||
max_tokens: int,
|
||||
hidden_dim: int,
|
||||
) -> Optional[tuple[int, int, int]]:
|
||||
"""
|
||||
Construct the proper activation batched scale shape for this
|
||||
config, e.g. (num experts, *scale_shape).
|
||||
"""
|
||||
if self.is_quantized:
|
||||
scale_shape = self.scale_shape(max_tokens, hidden_dim)
|
||||
assert scale_shape is not None
|
||||
return (num_experts, *scale_shape)
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def make(
|
||||
quant_dtype: Union[torch.dtype, str, None] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig", None] = None,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
g1_alphas: Optional[torch.Tensor] = None,
|
||||
g2_alphas: Optional[torch.Tensor] = None,
|
||||
a1_gscale: Optional[torch.Tensor] = None,
|
||||
a2_gscale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
w1_zp: Optional[torch.Tensor] = None,
|
||||
w2_zp: Optional[torch.Tensor] = None,
|
||||
) -> "FusedMoEQuantConfig":
|
||||
"""
|
||||
General builder function for a FusedMoEQuantConfig.
|
||||
- quant_dtype: Optional quantization type. None if activations are
|
||||
unquantized or quantized prior to calling. Note: "nvfp4" and
|
||||
"mxfp4" are the only valid string values for quant_dtype.
|
||||
- per_act_token_quant: Activations have per token quantization.
|
||||
- per_out_ch_quant: Outputs have per channel quantization. (only
|
||||
for cutlass).
|
||||
- block_shape: Optional block size for block-wise quantization.
|
||||
Incompatible with per_act_token and per_out_ch quant.
|
||||
- w1_scale: Optional scale to be used for w1.
|
||||
- w2_scale: Optional scale to be used for w2.
|
||||
- a1_scale: Optional scale to be used for a1.
|
||||
- a2_scale: Optional scale to be used for a2.
|
||||
- g1_alphas: Optional global quantization scales for w1 (for nvfp4).
|
||||
- g2_alphas: Optional global quantization scales for w2 (for nvfp4).
|
||||
- a1_gscale: Optional global quantization scales for a1 (for nvfp4).
|
||||
- a2_gscale: Optional global quantization scales for a2 (for nvfp4).
|
||||
- w1_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w2_bias: Optional biases for w1 (GPT OSS Triton).
|
||||
- w1_zp: Optional w1 zero points for int4/int8 quantization.
|
||||
- w2_zp: Optional w2 zero points for int4/int8 quantization.
|
||||
"""
|
||||
assert (not isinstance(quant_dtype, str) or quant_dtype == "nvfp4"
|
||||
or quant_dtype == "mxfp4")
|
||||
a_shape, w_shape = _quant_flags_to_group_shape(quant_dtype,
|
||||
per_act_token_quant,
|
||||
per_out_ch_quant,
|
||||
block_shape)
|
||||
quant_config = FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(quant_dtype, a_shape, a1_scale, a1_gscale),
|
||||
_a2=FusedMoEQuantDesc(quant_dtype, a_shape, a2_scale, a2_gscale),
|
||||
_w1=FusedMoEQuantDesc(quant_dtype, w_shape, w1_scale, g1_alphas,
|
||||
w1_zp, w1_bias),
|
||||
_w2=FusedMoEQuantDesc(quant_dtype, w_shape, w2_scale, g2_alphas,
|
||||
w2_zp, w2_bias),
|
||||
)
|
||||
assert quant_config.per_act_token_quant == per_act_token_quant
|
||||
assert quant_config.per_out_ch_quant == per_out_ch_quant
|
||||
assert quant_config.block_shape == block_shape
|
||||
return quant_config
|
||||
|
||||
|
||||
def fp8_w8a8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
per_act_token_quant: bool = False,
|
||||
per_out_ch_quant: bool = False,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for fp8 activations and fp8 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(torch.float8_e4m3fn,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=per_out_ch_quant,
|
||||
block_shape=block_shape)
|
||||
|
||||
|
||||
def int8_w8a8_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
a1_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
per_act_token_quant: bool = False,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for int8 activations and int8 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
torch.int8,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
per_act_token_quant=per_act_token_quant,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a16_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations and mxfp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(),
|
||||
_a2=FusedMoEQuantDesc(),
|
||||
_w1=FusedMoEQuantDesc("mxfp4", None, w1_scale, None, None, w1_bias),
|
||||
_w2=FusedMoEQuantDesc("mxfp4", None, w2_scale, None, None, w2_bias),
|
||||
)
|
||||
|
||||
|
||||
def mxfp4_w4a4_moe_quant_config(
|
||||
w1_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
w2_scale: Union[torch.Tensor, "PrecisionConfig"],
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
w1_bias: Optional[torch.Tensor] = None,
|
||||
w2_bias: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and mxfp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
"mxfp4",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
|
||||
def nvfp4_moe_quant_config(
|
||||
g1_alphas: torch.Tensor,
|
||||
g2_alphas: torch.Tensor,
|
||||
a1_gscale: torch.Tensor,
|
||||
a2_gscale: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for mxfp4 activations and nvp4 weights.
|
||||
"""
|
||||
return FusedMoEQuantConfig.make(
|
||||
"nvfp4",
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
a1_gscale=a1_gscale,
|
||||
a2_gscale=a2_gscale,
|
||||
g1_alphas=g1_alphas,
|
||||
g2_alphas=g2_alphas,
|
||||
per_act_token_quant=False,
|
||||
per_out_ch_quant=False,
|
||||
block_shape=None,
|
||||
)
|
||||
|
||||
|
||||
def int4_w4a16_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for 16-bit float activations and int4 weights.
|
||||
Note: Activations are pre-quantized.
|
||||
"""
|
||||
group_shape = GroupShape(*block_shape) if block_shape is not None else None
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(shape=group_shape),
|
||||
_a2=FusedMoEQuantDesc(shape=group_shape),
|
||||
_w1=FusedMoEQuantDesc("int4", group_shape, w1_scale, None, w1_zp),
|
||||
_w2=FusedMoEQuantDesc("int4", group_shape, w2_scale, None, w2_zp),
|
||||
)
|
||||
|
||||
|
||||
def int8_w8a16_moe_quant_config(
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
w1_zp: Optional[torch.Tensor],
|
||||
w2_zp: Optional[torch.Tensor],
|
||||
block_shape: Optional[list[int]] = None,
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for 16-bit float activations and int8 weights.
|
||||
Note: Activations are pre-quantized.
|
||||
"""
|
||||
group_shape = GroupShape(*block_shape) if block_shape is not None else None
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(shape=group_shape),
|
||||
_a2=FusedMoEQuantDesc(shape=group_shape),
|
||||
_w1=FusedMoEQuantDesc(torch.int8, group_shape, w1_scale, None, w1_zp),
|
||||
_w2=FusedMoEQuantDesc(torch.int8, group_shape, w2_scale, None, w2_zp),
|
||||
)
|
||||
|
||||
|
||||
def biased_moe_quant_config(
|
||||
w1_bias: Optional[torch.Tensor],
|
||||
w2_bias: Optional[torch.Tensor],
|
||||
) -> FusedMoEQuantConfig:
|
||||
"""
|
||||
Construct a quant config for unquantized activations with biases.
|
||||
"""
|
||||
return FusedMoEQuantConfig(
|
||||
_a1=FusedMoEQuantDesc(),
|
||||
_a2=FusedMoEQuantDesc(),
|
||||
_w1=FusedMoEQuantDesc(bias=w1_bias),
|
||||
_w2=FusedMoEQuantDesc(bias=w2_bias),
|
||||
)
|
||||
|
||||
|
||||
# A FusedMoEQuantConfig constant for an unquantized MoE op.
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEParallelConfig:
|
||||
tp_size: int
|
||||
dp_size: int
|
||||
ep_size: int
|
||||
tp_rank: int
|
||||
dp_rank: int
|
||||
ep_rank: int
|
||||
|
||||
use_ep: bool # whether to use EP or not
|
||||
|
||||
@property
|
||||
def use_all2all_kernels(self):
|
||||
return self.dp_size > 1 and self.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "pplx")
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput")
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return (self.use_all2all_kernels
|
||||
and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
|
||||
|
||||
@staticmethod
|
||||
def make(tp_size_: int, dp_size_: int,
|
||||
vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
|
||||
"""
|
||||
Determine MoE parallel configuration. Based on the input `tp_size_`,
|
||||
`dp_size_` and vllm's parallel config, determine what
|
||||
level's of parallelism to use in the fused moe layer.
|
||||
|
||||
Args:
|
||||
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
|
||||
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
|
||||
vllm_parallel_config (ParallelConfig): vLLM's parallel config
|
||||
object which contains the `enable_expert_parallel` flag.
|
||||
|
||||
Examples:
|
||||
When there is no parallelism requested,
|
||||
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
|
||||
unaltered and the ranks set to 0.
|
||||
|
||||
Expert Parallelism is considered only when either `dp_size_` or
|
||||
`tp_size_` is non trivial.
|
||||
|
||||
When TP = 2, DP = 1 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
|
||||
legend : {size, rank}
|
||||
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
|
||||
- Comment : Tensors are sharded across 2 devices.
|
||||
|
||||
When TP = 1, DP = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1 : TP = {2, 1} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 2 decvices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = False, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
|
||||
- device 1: TP = {4, 1} DP = {2, 0} EP = {1, 0}
|
||||
- device 2: TP = {4, 2} DP = {2, 1} EP = {1, 0}
|
||||
- device 3: TP = {4, 3} DP = {2, 1} EP = {1, 0}
|
||||
- Comment: There are 2 engine instances and the tensors are sharded
|
||||
across 4 devices.
|
||||
|
||||
When, TP = 2, DP = 1 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
|
||||
- Comment: The experts are split between the 2 devices.
|
||||
|
||||
When, TP = 1, DP = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 1} EP = {2, 1}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 2 devices.
|
||||
|
||||
When TP = 2, DP = 2 and EP = True, the configuration on different
|
||||
devices:
|
||||
|
||||
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
|
||||
- device 1: TP = {1, 0} DP = {2, 0} EP = {4, 1}
|
||||
- device 2: TP = {1, 0} DP = {2, 1} EP = {4, 2}
|
||||
- device 3: TP = {1, 0} DP = {2, 1} EP = {4, 3}
|
||||
- Comment: There are 2 engine instances and the experts are split
|
||||
between the 4 devices.
|
||||
"""
|
||||
|
||||
def flatten_tp_across_dp(dp_rank: int):
|
||||
tp_rank = 0 if tp_size_ == 1 else get_tensor_model_parallel_rank()
|
||||
# There are actually dp_size_ * tp_size_ devices. Update tp_size
|
||||
# and tp_rank so we shard across all devices.
|
||||
tp_size = dp_size_ * tp_size_
|
||||
tp_rank = dp_rank * tp_size_ + tp_rank
|
||||
return tp_size, tp_rank
|
||||
|
||||
use_ep = (dp_size_ * tp_size_ > 1
|
||||
and vllm_parallel_config.enable_expert_parallel)
|
||||
|
||||
dp_size = dp_size_
|
||||
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
|
||||
tp_size, tp_rank = flatten_tp_across_dp(dp_rank)
|
||||
|
||||
if not use_ep:
|
||||
return FusedMoEParallelConfig(tp_size=tp_size,
|
||||
tp_rank=tp_rank,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=1,
|
||||
ep_rank=0,
|
||||
use_ep=False)
|
||||
# DP + EP / TP + EP / DP + TP + EP
|
||||
assert use_ep
|
||||
# In EP, each device owns a set of experts fully. There is no tensor
|
||||
# parallel update tp_size, tp_rank, ep_size and ep_rank to reflect that.
|
||||
ep_size = tp_size
|
||||
ep_rank = tp_rank
|
||||
return FusedMoEParallelConfig(tp_size=1,
|
||||
tp_rank=0,
|
||||
dp_size=dp_size,
|
||||
dp_rank=dp_rank,
|
||||
ep_size=ep_size,
|
||||
ep_rank=ep_rank,
|
||||
use_ep=True)
|
||||
|
||||
|
||||
# Adapted from pplx-kernels tests/all_to_all_utils.py
|
||||
@dataclass
|
||||
class FusedMoEConfig:
|
||||
num_experts: int
|
||||
experts_per_token: int
|
||||
hidden_dim: int
|
||||
|
||||
num_local_experts: int
|
||||
moe_parallel_config: FusedMoEParallelConfig
|
||||
|
||||
# The activation type.
|
||||
in_dtype: torch.dtype
|
||||
|
||||
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
|
||||
|
||||
has_bias: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.dp_size > 1:
|
||||
logger.debug_once("Using FusedMoEConfig::max_num_tokens=%d",
|
||||
self.max_num_tokens)
|
||||
|
||||
assert self.max_num_tokens > 0
|
||||
|
||||
@property
|
||||
def tp_size(self):
|
||||
return self.moe_parallel_config.tp_size
|
||||
|
||||
@property
|
||||
def dp_size(self):
|
||||
return self.moe_parallel_config.dp_size
|
||||
|
||||
@property
|
||||
def ep_size(self):
|
||||
return self.moe_parallel_config.ep_size
|
||||
|
||||
@property
|
||||
def tp_rank(self):
|
||||
return self.moe_parallel_config.tp_rank
|
||||
|
||||
@property
|
||||
def dp_rank(self):
|
||||
return self.moe_parallel_config.dp_rank
|
||||
|
||||
@property
|
||||
def ep_rank(self):
|
||||
return self.moe_parallel_config.ep_rank
|
||||
|
||||
@property
|
||||
def use_ep(self):
|
||||
return self.moe_parallel_config.use_ep
|
||||
|
||||
@property
|
||||
def use_pplx_kernels(self):
|
||||
return self.moe_parallel_config.use_pplx_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ht_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ht_kernels
|
||||
|
||||
@property
|
||||
def use_deepep_ll_kernels(self):
|
||||
return self.moe_parallel_config.use_deepep_ll_kernels
|
||||
|
||||
@property
|
||||
def use_flashinfer_cutlass_kernels(self):
|
||||
"""
|
||||
Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
|
||||
"""
|
||||
return (envs.VLLM_USE_FLASHINFER_MOE_FP4
|
||||
and has_flashinfer_cutlass_fused_moe()
|
||||
and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 256,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,218 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"5120": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"9216": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"13312": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"17408": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"25600": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"33792": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"41984": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"50176": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"58368": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 16,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 4,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 1,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 2,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 8,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 1
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2,
|
||||
"waves_per_eu": 0,
|
||||
"matrix_instr_nonkdim": 16,
|
||||
"kpack": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
{
|
||||
"triton_version": "3.4.0",
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"8192": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16384": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 256,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 32,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
{
|
||||
"1": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"24": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"48": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"96": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 32,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"1536": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 64,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
{
|
||||
"2": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 64,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"4": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"16": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"32": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"64": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"128": {
|
||||
"BLOCK_SIZE_M": 16,
|
||||
"BLOCK_SIZE_N": 64,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 4,
|
||||
"num_stages": 5
|
||||
},
|
||||
"256": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"512": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"1024": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 3
|
||||
},
|
||||
"2048": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"3072": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"4096": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 16,
|
||||
"num_warps": 8,
|
||||
"num_stages": 4
|
||||
},
|
||||
"8192": {
|
||||
"BLOCK_SIZE_M": 128,
|
||||
"BLOCK_SIZE_N": 128,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 3
|
||||
},
|
||||
"16384": {
|
||||
"BLOCK_SIZE_M": 64,
|
||||
"BLOCK_SIZE_N": 256,
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user