commit b2ef04d792 (parent dfa6476b58)
Date: 2026-01-09 13:34:11 +08:00
538 changed files with 105693 additions and 2 deletions

View File

@@ -0,0 +1,173 @@
"""Custom activation functions."""
import math
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
class SiluAndMul(nn.Module):
"""An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (num_tokens, 2 * d) or (batch_size, seq_len, 2 * d)
return: (num_tokens, d) or (batch_size, seq_len, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.silu_and_mul(out, x)
return out
class GeluAndMul(nn.Module):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
def __init__(self, approximate: str = "none"):
super().__init__()
self.approximate = approximate
if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}")
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
if self.approximate == "none":
ops.gelu_and_mul(out, x)
elif self.approximate == "tanh":
ops.gelu_tanh_and_mul(out, x)
return out
def extra_repr(self) -> str:
return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
ops.gelu_new(out, x)
return out
class FastGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x)
ops.gelu_fast(out, x)
return out
class ScaledActivation(nn.Module):
"""An activation function with post-scale parameters.
This is used for some quantization methods like AWQ.
"""
def __init__(
self,
act_module: nn.Module,
intermediate_size: int,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.act = act_module
self.input_is_parallel = input_is_parallel
if input_is_parallel:
tp_size = get_tensor_model_parallel_world_size()
intermediate_size_per_partition = divide(intermediate_size,
tp_size)
else:
intermediate_size_per_partition = intermediate_size
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.scales = nn.Parameter(
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.act(x) / self.scales
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
if self.input_is_parallel:
tp_rank = get_tensor_model_parallel_rank()
shard_size = param_data.shape[0]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
}
def get_act_fn(
act_fn_name: str,
quant_config: Optional[QuantizationConfig] = None,
intermediate_size: Optional[int] = None,
input_is_parallel: bool = True,
params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
if act_fn_name not in _ACTIVATION_REGISTRY:
raise ValueError(
f"Activation function {act_fn_name!r} is not supported.")
act_fn = _ACTIVATION_REGISTRY[act_fn_name]
if (quant_config is not None
and act_fn_name in quant_config.get_scaled_act_names()):
if intermediate_size is None:
raise ValueError("intermediate_size must be specified for scaled "
"activation functions.")
return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
params_dtype)
return act_fn
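
# Minimal usage sketch (illustrative only; nothing below is used by vLLM).
# It exercises the PyTorch-native _forward() path, which runs even without
# the compiled _custom_ops extension; shapes are made up.
def _demo_silu_and_mul() -> None:
    x = torch.randn(4, 2 * 1024)        # (num_tokens, 2 * d)
    out = SiluAndMul()._forward(x)      # silu(x[..., :d]) * x[..., d:]
    assert out.shape == (4, 1024)
    act = get_act_fn("gelu_new")        # registry lookup, no quant scaling
    assert isinstance(act, NewGELU)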

View File

@@ -0,0 +1,7 @@
from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_moe, get_config_file_name)
__all__ = [
"fused_moe",
"get_config_file_name",
]

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,140 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,146 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"48": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"64": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"128": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 4
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}

View File

@@ -0,0 +1,10 @@
This directory contains tuned configurations for the fused_moe kernel.
For each combination of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
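
As a concrete illustration of the naming scheme (the device string below is
hypothetical; the real one is whatever get_device_name() returns on the
machine), Mixtral at TP2 with E = 8 experts and N = 7168 on an H100 would
resolve to a file named like:

E, N = 8, 7168                          # Mixtral, TP2
device_name = "NVIDIA_H100_80GB_HBM3"   # illustrative value
print(f"E={E},N={N},device_name={device_name}.json")
# -> E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json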

View File

@@ -0,0 +1,479 @@
"""Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip
logger = init_logger(__name__)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
a_scale_ptr,
b_scale_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
# (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
use_fp8: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can
be any shape representing batches and K is the feature dimension of
each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is
the number of experts, K is the input feature dimension, and N is
the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the
total number of tokens post padding, topk is the number of times
each token is repeated, and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens,
repeated topk times and arranged by the expert index they are
assigned to.
- expert_ids: A tensor containing the indices of the expert for each
block. It determines which expert matrix from B should be used for
each block in A.
This kernel performs the multiplication of a token by its corresponding
expert matrix as determined by `expert_ids`. The sorting of
`sorted_token_ids` by expert index and padding ensures divisibility by
BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
if use_fp8:
a_scale = tl.load(a_scale_ptr)
b_scale = tl.load(b_scale_ptr + off_experts)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
else:
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
accumulator = accumulator * moe_weight[:, None]
if use_fp8:
accumulator = (accumulator * a_scale * b_scale).to(compute_type)
else:
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
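
# Illustrative sketch (not called anywhere): the grouped pid -> (pid_m, pid_n)
# mapping at the top of the kernel, rehearsed in plain Python with made-up
# sizes. Consecutive pids sweep a GROUP_SIZE_M-row band of C column by
# column, so the same rows of A stay hot in L2.
def _demo_grouped_pid_mapping(num_pid_m: int = 4,
                              num_pid_n: int = 3,
                              group_size: int = 2) -> None:
    num_pid_in_group = group_size * num_pid_n
    for pid in range(num_pid_m * num_pid_n):
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * group_size
        group_size_m = min(num_pid_m - first_pid_m, group_size)
        pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
        pid_n = (pid % num_pid_in_group) // group_size_m
        print(pid, "->", (pid_m, pid_n))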
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the
top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according
to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding,
ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process
so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions
align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts,
with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append the padding tokens [12, 12, 12, 12], one per expert, so
that each expert's token count becomes divisible by block_size.
- After sorting by expert index, we obtain token_ids
[3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in
the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
return sorted_ids, expert_ids, num_tokens_post_pad
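
# Pure-Python rehearsal of the docstring example above (illustrative only;
# the real work happens inside ops.moe_align_block_size). It reproduces the
# pad-and-sort semantics for those exact numbers.
def _demo_moe_align_block_size() -> None:
    topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]]
    block_size = 4
    flat = [e for row in topk_ids for e in row]
    pad_id = len(flat)                  # 12 marks a padding slot
    buckets: Dict[int, list] = {}
    for tok, expert in enumerate(flat):
        buckets.setdefault(expert, []).append(tok)
    sorted_ids: list = []
    for expert in sorted(buckets):
        ids = buckets[expert]
        sorted_ids += ids + [pad_id] * ((-len(ids)) % block_size)
    assert sorted_ids == [3, 6, 9, 12, 0, 4, 10, 12,
                          1, 7, 11, 12, 2, 5, 8, 12]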
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
if not use_fp8:
assert A_scale is None
assert B_scale is None
else:
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
fused_moe_kernel[grid](
A,
B,
C,
A_scale,
B_scale,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=compute_type,
use_fp8=use_fp8,
**config,
)
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
device_name = torch.musa.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
@functools.lru_cache
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the fused_moe kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default
# configuration
return None
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override
for the kernel configuration.
- use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- w1_scale (Optional[torch.Tensor]): Optional scale to be used for
w1.
- w2_scale (Optional[torch.Tensor]): Optional scale to be used for
w2.
- a1_scale (Optional[torch.Tensor]): Optional scale for the activations
fed into the w1 matmul (only used when use_fp8 is True).
- a2_scale (Optional[torch.Tensor]): Optional scale for the activations
fed into the w2 matmul (only used when use_fp8 is True).
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
if override_config:
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
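# Example invocation (an illustrative sketch, not part of the API surface;
# the tensor sizes below are hypothetical and follow the shape assertions
# at the top of fused_moe):
#
#   M, K, E, topk, inter = 256, 4096, 8, 2, 14336
#   hidden_states = torch.randn(M, K, device="cuda", dtype=torch.float16)
#   w1 = torch.randn(E, 2 * inter, K, device="cuda", dtype=torch.float16)
#   w2 = torch.randn(E, K, inter, device="cuda", dtype=torch.float16)
#   gating_output = torch.randn(M, E, device="cuda", dtype=torch.float16)
#   out = fused_moe(hidden_states, w1, w2, gating_output, topk,
#                   renormalize=True)
#   assert out.shape == (M, K)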

View File

@@ -0,0 +1,71 @@
"""Custom normalization layers."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
class RMSNorm(nn.Module):
"""Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
Refer to https://arxiv.org/abs/1910.07467
"""
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def _forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
variance = x.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype) * self.weight
if residual is None:
return x
else:
return x, residual
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
ops.fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
out = torch.empty_like(x)
ops.rms_norm(
out,
x,
self.weight.data,
self.variance_epsilon,
)
return out
def extra_repr(self) -> str:
s = f"hidden_size={self.weight.data.size(0)}"
s += f", eps={self.variance_epsilon}"
return s
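# A minimal equivalence check (illustrative sketch; assumes a CUDA-like
# device and a dtype supported by the fused op):
#
#   layer = RMSNorm(hidden_size=4096).to("cuda", torch.float16)
#   x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
#   torch.testing.assert_close(layer.forward(x), layer._forward(x),
#                              atol=1e-2, rtol=1e-2)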

View File

@@ -0,0 +1,709 @@
from abc import abstractmethod
from typing import List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
split_tensor_along_last_dim,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None:
return shard_size, shard_offset
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for QKVLinear
is a list containing the widths of Wq, Wk, and Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization.
Args:
separate_bias_add: If true, add bias separately after matrix
multiplication.
"""
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
output_size_per_partition = sum(output_partition_sizes)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if self.separate_bias_add:
if bias is not None:
return F.linear(x, weight) + bias
return F.linear(x, weight)
return F.linear(x.to(weight.device), weight, bias)
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0})
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
return s
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, tp_size)
if output_sizes is None:
output_sizes = [output_size]
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)
param_data = param.data
if output_dim is not None:
shard_size = param_data.shape[output_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={get_tensor_model_parallel_world_size()}"
s += f", gather_output={self.gather_output}"
return s
class MergedColumnParallelLinear(ColumnParallelLinear):
"""Packed linear layers with column parallelism.
Similar to ColumnParallelLinear, but the weight matrix is concatenated
along the output dimension. When the weight matrix is loaded, the
different partitions are sharded separately.
Args:
input_size: input dimension of the linear layer.
output_sizes: list of output dimensions of the linear layer.
bias: If true, add bias.
gather_output: If true, call all-gather on output and make the output
available to all GPUs, otherwise, every GPU will have
its own output.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_sizes: List[int],
bias: bool = True,
gather_output: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
super().__init__(input_size, sum(output_sizes), bias, gather_output,
skip_bias_add, params_dtype, quant_config,
self.output_sizes)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
shard_offsets = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id < len(self.output_sizes)
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
if output_dim is not None:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size
shard_size = self.output_sizes[loaded_shard_id] // tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. We
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
# Divide the weight matrix along the last dimension.
tp_size = get_tensor_model_parallel_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
output_sizes = [
self.num_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size,
self.num_kv_heads * tp_size * self.head_size
]
super().__init__(input_size, output_size, bias, False, skip_bias_add,
params_dtype, quant_config, output_sizes)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
is_metadata = getattr(param, "is_metadata", False)
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
if loaded_shard_id is None:
# Loaded weight is already packed.
if output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
tp_rank = get_tensor_model_parallel_rank()
assert loaded_shard_id in ["q", "k", "v"]
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# for the packing.
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
# Special case for Marlin.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = tp_rank
else:
shard_id = tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
# Special case for AQLM codebooks.
elif is_metadata:
# metadata indicates fixed size concatenated along dim 0
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index(loaded_shard_id)
param_data = param_data.narrow(0, shard_index * shard_size,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(
param_data, loaded_weight, loaded_shard_id)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
         -   -
        | A_1 |
        | .   |
    A = | .   |        X = [X_1, ..., X_p]
        | .   |
        | A_p |
         -   -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
# Special case for Fp8 scales.
fp8_scales_shard_indexer = getattr(param, "fp8_scales_shard_indexer",
None)
tp_rank = get_tensor_model_parallel_rank()
input_dim = getattr(param, "input_dim", None)
param_data = param.data
if input_dim is not None:
shard_size = param_data.shape[input_dim]
start_idx = tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
# Special case for Fp8 scales.
elif fp8_scales_shard_indexer is not None:
param_data, loaded_weight = fp8_scales_shard_indexer(param_data,
loaded_weight,
shard_id=0)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def forward(self, input_):
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
s += f", reduce_results={self.reduce_results}"
return s

View File

@@ -0,0 +1,115 @@
"""A layer that compute logits from hidden_stats."""
from typing import Optional
import torch
import torch.nn as nn
from vllm.distributed import tensor_model_parallel_gather
from vllm.model_executor.sampling_metadata import SamplingMetadata
class LogitsProcessor(nn.Module):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
1. Gather logits from model hidden_states.
2. Scale logits if needed.
3. Apply logits processors (if any).
"""
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None,
scale: Optional[float] = 1.0,
logits_as_input: bool = False) -> None:
"""
Args:
scale: A scaling factor to apply to the logits.
"""
super().__init__()
self.scale = scale
self.vocab_size = vocab_size
# Whether the input is logits (default is hidden states).
self.logits_as_input = logits_as_input
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if self.logits_as_input:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
if logits is not None:
logits *= self.scale
# Apply logits processors (if any).
logits = _apply_logits_processors(logits, sampling_metadata)
return logits
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def extra_repr(self) -> str:
s = f"vocab_size={self.vocab_size}"
s += f", forg_vocab_size={self.org_vocab_size}"
s += f", scale={self.scale}, logits_as_input={self.logits_as_input}"
return s
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
return hidden_states.index_select(0,
sampling_metadata.selected_token_indices)
def _apply_logits_processors(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
found_logits_processors = False
logits_processed = 0
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors
if logits_processors:
found_logits_processors = True
for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
token_ids = seq_group.seq_data[seq_id].output_token_ids
for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
logits[logits_row_idx] = logits_row
logits_processed += len(seq_group.sample_indices) + len(
seq_group.prompt_logprob_indices)
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_processed == logits.shape[0]
return logits
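# Example of a user-supplied logits processor (a hypothetical callable; the
# only contract assumed here is (output_token_ids, logits_row) -> logits_row,
# matching how it is applied in _apply_logits_processors above):
#
#   def ban_token_42(token_ids: List[int],
#                    logits: torch.Tensor) -> torch.Tensor:
#       logits[42] = float("-inf")  # forbid the (hypothetical) token id 42
#       return logits
#
# Such callables are attached per request via
# SamplingParams.logits_processors.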

View File

@@ -0,0 +1,157 @@
from typing import Optional, Union
import torch
import triton
import triton.language as tl
def seeded_uniform(
*size,
seeds: torch.Tensor,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
pin_memory: Optional[bool] = False,
) -> torch.Tensor:
"""Similar to torch.rand, but allows for seeds to be set per row.
seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
If it is 3d, the additional seeds needed will be derived automatically
in a deterministic fashion:
[
row 0: [columns_with_seed_0], [columns_with_seed_0 ^ 1], ...
]
"""
n_dims = len(size)
if n_dims > 3:
raise ValueError("seeded_uniform only supports up to 3D tensors")
if out is None:
out = torch.empty(*size,
dtype=dtype,
device=device,
pin_memory=pin_memory)
elif out.shape != size:
raise ValueError("shape of out and size must be the same")
if n_dims == 3:
n_rows, n_3d, n_cols = out.shape
stride_row = out.stride(0)
stride_3d = out.stride(1)
elif n_dims == 2:
n_rows, n_cols = out.shape
n_3d = 1
stride_row = out.stride(0)
stride_3d = 1
else:
n_cols = out.shape[0]
n_rows = 1
n_3d = 1
stride_row = 1
stride_3d = 1
if seeds.ndim != 1:
raise ValueError("seeds must be a 1D tensor")
if seeds.numel() != n_rows:
raise ValueError(
"seeds must have the same number of elements as out has rows")
# The philox PRNG Triton uses generates 4 random numbers at once.
# Therefore, the most efficient use of it is to divide the
# block size by 4, and then save the generated random numbers to
# each of the 4 slices of the tensor.
full_block_size = triton.next_power_of_2(n_cols)
philox_block_size = max(full_block_size // 4, 1)
n_slices = full_block_size // philox_block_size
num_warps = 4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if philox_block_size >= 8192:
num_warps = 32
elif philox_block_size >= 4096:
num_warps = 16
elif philox_block_size >= 2048:
num_warps = 8
_seeded_uniform_triton[(n_rows, n_3d)](
out,
seeds,
stride_row,
stride_3d,
seeds.stride(0),
n_rows,
n_3d,
n_cols,
n_slices=n_slices,
num_warps=num_warps,
block_size=philox_block_size,
)
return out
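# Usage sketch (illustrative): rows that share a seed produce identical
# values, which is the property per-sequence sampling relies on.
#
#   seeds = torch.tensor([1234, 1234, 5678], device="cuda")
#   out = seeded_uniform(3, 1024, seeds=seeds,
#                        device="cuda", dtype=torch.float32)
#   assert torch.equal(out[0], out[1]) and not torch.equal(out[0], out[2])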
@triton.jit
def _seeded_uniform_triton(
out_ptr: torch.Tensor,
seed_ptr: torch.Tensor,
out_row_stride: int,
out_3d_stride: int,
seed_row_stride: int,
n_rows: int,
n_3d: int,
n_cols: int,
n_slices: tl.constexpr,
block_size: tl.constexpr,
):
"""
Generate a random float32 number in [0, 1) for each element in the output
tensor. The random numbers in a row are generated using the seed for that row.
Args:
out_ptr: The output tensor.
seed_ptr: The per-row seeds to use for random number generation.
out_row_stride: The stride between rows of the output tensor.
out_3d_stride: The stride between 3D slices of the output tensor.
seed_row_stride: The stride between rows of the seed tensor.
n_rows: The number of rows in the output tensor.
n_3d: The size of second dimension of the output tensor,
if output tensor is 3D.
n_cols: The number of columns in the output tensor.
n_slices: The number of philox outputs to use.
block_size: The number of elements written per philox slice.
"""
tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
# Get the row index.
row_idx = tl.program_id(axis=0)
three_d_idx = tl.program_id(axis=1)
philox_offsets = tl.arange(0, block_size)
# Get the seed for the current element.
seed = tl.load(seed_ptr + row_idx * seed_row_stride)
if three_d_idx > 0:
seed ^= three_d_idx
# Generate random numbers in [0, 1).
out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)
output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
three_d_idx * out_3d_stride)
out1_offsets = philox_offsets
tl.store(output_row_start_ptr + out1_offsets,
out1,
mask=out1_offsets < n_cols)
if n_slices > 1:
out2_offsets = tl.arange(block_size, block_size * 2)
tl.store(output_row_start_ptr + out2_offsets,
out2,
mask=out2_offsets < n_cols)
if n_slices > 2:
out3_offsets = tl.arange(block_size * 2, block_size * 3)
tl.store(output_row_start_ptr + out3_offsets,
out3,
mask=out3_offsets < n_cols)
if n_slices > 3:
out4_offsets = tl.arange(block_size * 3, block_size * 4)
tl.store(output_row_start_ptr + out4_offsets,
out4,
mask=out4_offsets < n_cols)

View File

@@ -0,0 +1,406 @@
import math
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
_EPS = 1e-6
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072
def get_num_triton_sampler_splits(n_cols: int) -> int:
"""Get the number of splits to use for Triton sampling.
Triton has a limit on the number of columns it can handle, so we need to
split the tensor and call the kernel multiple times if it's too large.
"""
return math.ceil(n_cols / MAX_TRITON_N_COLS)
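# For example, a 32000-column vocabulary fits in one split, while a
# hypothetical 200000-column tensor needs math.ceil(200000 / 131072) = 2.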
def _multi_split_sample(
probs: torch.Tensor,
seeds: torch.Tensor,
n_splits: int,
sampled_tokens_size: Tuple[int, int],
sampled_logprobs_size: Tuple[int, int],
sample_indices: torch.Tensor,
logprobs: torch.Tensor,
*,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
):
"""Sample tokens where vocab size is split into multiple parts
(too large for Triton otherwise)."""
assert seeds.ndim == 2 and seeds.shape[0] == n_splits
split_probs = probs.tensor_split(n_splits, 1)
split_logprobs = logprobs.tensor_split(n_splits, 1)
sampled_tokens_tmp = [
torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
for _ in range(n_splits)
]
sampled_logprobs_tmp = [
torch.empty(sampled_logprobs_size,
dtype=probs.dtype,
device=probs.device) for _ in range(n_splits)
]
# We are purposefully using sampled_tokens_size as we need to always
# save modified probs in this case.
sampled_modified_probs_tmp = [
torch.empty(sampled_tokens_size,
dtype=probs.dtype,
device=probs.device) for _ in range(n_splits)
]
for i in range(n_splits):
n_samples = sample_indices.shape[0]
n_cols = split_probs[i].shape[1]
n_best = sampled_tokens_tmp[i].shape[1]
uniform_noise = seeded_uniform(n_samples,
n_best,
n_cols,
seeds=seeds[i].flatten(),
device=split_probs[i].device,
dtype=split_probs[i].dtype)
# TODO(yard1): See if we can remove the contiguous() calls.
# Will need kernel support.
_sample(
split_probs[i].contiguous(),
split_logprobs[i].contiguous(),
sample_indices,
sampled_tokens_tmp[i],
sampled_logprobs_tmp[i],
sampled_modified_probs_tmp[i],
seeds[i],
uniform_noise,
modify_greedy_probs=False,
save_logprobs=save_logprobs,
save_modified_probs=True,
)
if i > 0:
# Add offset to sampled tokens
sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
sampled_tokens = torch.stack(sampled_tokens_tmp)
sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
# Reduce the results from the splits.
sampled_modified_probs, indices = torch.max(sampled_modified_probs,
dim=0,
keepdim=True)
sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
if save_logprobs:
sampled_logprobs = torch.stack(sampled_logprobs_tmp)
sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
else:
sampled_logprobs = None
sampled_modified_probs = sampled_modified_probs.squeeze(0)
if modify_greedy_probs:
# We need to modify the greedy probs for the sampled tokens.
# We can't do this in the kernel as we need to know the
# sampled tokens.
probs.fill_(0.0)
probs.scatter_(1, sampled_tokens, 1.0)
return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
def sample(
probs: torch.Tensor,
seeds: torch.Tensor,
*,
max_best_of: int = 1,
sample_indices: Optional[torch.Tensor] = None,
logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
_save_modified_probs: bool = False, # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Sample tokens from probs. with per-sequence seeds.
Can sample from a subset of sequences through sample_indices.
Args:
probs: Probabilities to sample from.
shape = [batch_size, vocab_size]
seeds: Per-sequence seed values.
shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
max_best_of: Number of samples to generate per sequence.
Sequence seed will be incremented by 1 each time.
sample_indices: Indices of sequences to sample from.
If not provided, will sample from all sequences.
shape = [n]
logprobs: Log-probabilities of the sampled tokens.
Only used for saving the logprobs if save_logprobs is True.
shape = [batch_size, vocab_size]
modify_greedy_probs: Whether to modify the greedy probabilities
for speculative sampling (sampled token = 1.0,
everything else = 0.0).
save_logprobs: Whether to save the log-probabilities of the
sampled tokens to a tensor.
_save_modified_probs: Whether to save the modified probabilities
(including gumbel noise) of the sampled tokens to a tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
This is exposed only for testing.
Returns:
sampled_tokens: shape = [n, max_best_of]
sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
sampled_modified_probs: shape = [n, max_best_of]
if save_modified_probs else None
"""
if sample_indices is None:
sample_indices = torch.arange(0, probs.shape[0], device=probs.device)
sampled_tokens_size = (sample_indices.size(0), max_best_of)
if save_logprobs:
if logprobs is None:
raise ValueError(
"logprobs tensor must be provided if save_logprobs is True")
sampled_logprobs_size = sampled_tokens_size
else:
# Empty tensors to invoke the kernel
sampled_logprobs_size = (0, 0)
logprobs = probs
assert logprobs is not None
if _save_modified_probs:
sampled_modified_probs_size = sampled_tokens_size
else:
# Empty tensors to invoke the kernel
sampled_modified_probs_size = (0, 0)
# If the number of columns in probs is too large for Triton to handle,
# we split the tensor and sample from each split separately, and then
# do an argmax+gather to combine the results.
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if n_splits > 1:
(sampled_tokens, sampled_logprobs,
sampled_modified_probs) = _multi_split_sample(
probs,
seeds,
n_splits,
sampled_tokens_size,
sampled_logprobs_size,
sample_indices,
logprobs=logprobs,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs)
else:
sampled_tokens = torch.empty(sampled_tokens_size,
dtype=torch.long,
device=probs.device)
sampled_logprobs = torch.empty(sampled_logprobs_size,
dtype=probs.dtype,
device=probs.device)
sampled_modified_probs = torch.empty(sampled_modified_probs_size,
dtype=probs.dtype,
device=probs.device)
n_samples = sample_indices.shape[0]
n_cols = probs.shape[1]
uniform_noise = seeded_uniform(n_samples,
max_best_of,
n_cols,
seeds=seeds.flatten(),
device=probs.device,
dtype=probs.dtype)
_sample(
probs,
logprobs,
sample_indices,
sampled_tokens,
sampled_logprobs,
sampled_modified_probs,
seeds,
uniform_noise,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
save_modified_probs=_save_modified_probs,
)
return (sampled_tokens, sampled_logprobs if save_logprobs else None,
sampled_modified_probs if _save_modified_probs else None)
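# Usage sketch (illustrative; a seed of 0 selects greedy decoding for that
# row, everything else is seeded random sampling; the [n_splits, n] seed
# layout assumed here follows _multi_split_sample's assertion):
#
#   probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)
#   seeds = torch.tensor([[0, 7, 7, 13]], device="cuda")  # row 0 is greedy
#   tokens, _, _ = sample(probs, seeds)
#   assert tokens.shape == (4, 1)  # max_best_of defaults to 1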
def _sample(probs: torch.Tensor,
logprobs: torch.Tensor,
sample_indices: torch.Tensor,
output_samples: torch.Tensor,
output_logprobs: torch.Tensor,
output_modified_probs: torch.Tensor,
seeds: torch.Tensor,
uniform_noise: torch.Tensor,
*,
modify_greedy_probs: bool = False,
save_logprobs: bool = True,
save_modified_probs: bool = False) -> torch.Tensor:
"""Sample tokens from probs.
Args:
probs [batch_size, vocab_size]: probs to sample from.
logprobs [batch_size, vocab_size]: logprobs (used when
save_logprobs is True).
sample_indices [n]: Indices of the samples to use for each row of probs.
output_samples [n, n_best]: Output tensor to store samples in.
output_logprobs [n, n_best]: Output tensor to store logprobs in.
output_modified_probs [n, n_best]: Output tensor to store
probs of chosen tokens in (modified with noise).
seeds [n]: Seeds to use for sampling. If the seed is 0, we use
greedy sampling. Note this is ONLY used for determining
whether to use random sampling or not. The actual random
noise should be passed as uniform_noise.
uniform_noise [batch_size, n_best, vocab_size]: Uniform
noise to use for random sampling (will be converted
to exponential gumbel noise by the kernel).
modify_greedy_probs: If True, we modify the probs tensor in-place
to encode the sampling method used for each row. This is used
in speculative decoding. Only applies in greedy decoding.
save_logprobs: If True, we save the logprobs of the sampled tokens
in the output_logprobs tensor.
save_modified_probs: If True, we save the modified probs (with noise)
of the sampled tokens in the output_modified_probs tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
"""
n_samples = sample_indices.shape[0]
n_cols = probs.shape[1]
n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1
# The block size is the smallest power of two greater than the number of
# columns in probs
block_size = triton.next_power_of_2(n_cols)
num_warps = 4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if block_size >= 8192:
num_warps = 32
elif block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
# Enqueue kernel. The 1D launch grid is simple: we have one kernel
# instance per row of the probs matrix
_sample_triton[(n_samples, n_best)](
sample_indices,
output_samples,
output_logprobs,
output_modified_probs,
probs,
logprobs,
seeds,
uniform_noise,
output_samples.stride(0),
probs.stride(0),
uniform_noise.stride(0),
uniform_noise.stride(1) if n_best > 1 else 1,
n_samples,
n_cols,
n_best,
num_warps=num_warps,
block_size=block_size,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
save_modified_probs=save_modified_probs,
)
return output_samples, output_logprobs, output_modified_probs
@triton.jit
def _uniform_to_exponential(uniform_noise):
"""Convert uniform samples to exponential samples."""
# tl.rand returns values in [0, 1), so we clamp lower bound
# to _EPS to avoid log(0) and thus division by 0 later
lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
uniform_noise = tl.maximum(uniform_noise, lb)
# Use the inversion method to turn uniform samples
# into exponential samples
exponential_noise = -tl.log(uniform_noise)
return exponential_noise
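# Why dividing probs by exponential noise and taking the argmax samples
# correctly: argmax(p_i / E_i) with E_i ~ Exp(1) is the "exponential race"
# form of the Gumbel-max trick, so index i wins with probability p_i.
# A quick empirical check (illustrative, CPU-side):
#
#   p = torch.tensor([0.1, 0.3, 0.6])
#   e = -torch.rand(100_000, 3).clamp_min(1e-6).log()
#   freq = torch.bincount((p / e).argmax(dim=-1),
#                         minlength=3).float() / 100_000
#   # freq is close to p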
@triton.jit
def _sample_triton(
sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
output_logprobs_ptr: torch.Tensor,
output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
uniform_noise_ptr: torch.Tensor, output_row_stride: int,
probs_row_stride: int, uniform_noise_row_stride: int,
uniform_noise_best_stride: int, n_samples: int, n_cols: int,
n_best: int, block_size: tl.constexpr,
modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
save_modified_probs: tl.constexpr):
# The rows are independent, so we parallelize across those
sample_idx = tl.program_id(0)
best_idx = tl.program_id(1)
# Load the row index from DRAM
row_idx = tl.load(sample_indices_ptr + sample_idx)
seed = tl.load(seeds_ptr + sample_idx)
uses_random_sampling = seed != 0
# The stride represents how much we need to increase the
# pointer to advance 1 row
row_start_ptr = probs_ptr + row_idx * probs_row_stride
# The block size is the next power of two greater than n_cols,
# so we can fit each row in a single block
col_offsets = tl.arange(0, block_size)
# Load the row into SRAM, using a mask since block_size may be > n_cols
row = tl.load(row_start_ptr + col_offsets,
mask=col_offsets < n_cols,
other=float("-inf"))
if uses_random_sampling:
uniform_noise_start_ptr = (uniform_noise_ptr +
sample_idx * uniform_noise_row_stride +
best_idx * uniform_noise_best_stride)
uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
mask=col_offsets < n_cols,
other=0.5)
exponential_noise = _uniform_to_exponential(uniform_noise)
row /= exponential_noise
sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
# clamp sampled token to n_cols - 1
# this should not be necessary, but we do it
# just in case
if sampled_token >= n_cols:
sampled_token = n_cols - 1
# Write back output to DRAM
output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
best_idx)
tl.store(output_row_start_ptr, sampled_token)
if modify_greedy_probs: # noqa
if not uses_random_sampling:
# Set the probability of the sampled token to 1, all other
# tokens to zero. This is used in speculative decoding where
# the sampling method must be encoded within the sampled
# probability distributions.
row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
tl.store(row_start_ptr + col_offsets,
row,
mask=col_offsets < n_cols)
if save_modified_probs:
output_row_start_ptr = (output_modified_probs_ptr +
sample_idx * output_row_stride + best_idx)
tl.store(output_row_start_ptr, sampled_value)
if save_logprobs:
# Load the row into SRAM, using a mask since block_size
# may be > n_cols
sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
sampled_token)
# Write back output to DRAM
output_row_start_ptr = (output_logprobs_ptr +
sample_idx * output_row_stride + best_idx)
tl.store(output_row_start_ptr, sampled_logprob)

View File

@@ -0,0 +1,35 @@
from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig,
"awq": AWQConfig,
"fp8": Fp8Config,
"gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig,
"gptq_marlin": GPTQMarlinConfig,
"marlin": MarlinConfig,
}
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
if quantization not in QUANTIZATION_METHODS:
raise ValueError(f"Invalid quantization method: {quantization}")
return QUANTIZATION_METHODS[quantization]
__all__ = [
"QuantizationConfig",
"get_quantization_config",
"QUANTIZATION_METHODS",
]

View File

@@ -0,0 +1,376 @@
# Supports AQLM compression, see https://github.com/Vahe1994/AQLM
# and https://arxiv.org/pdf/2401.06118.pdf
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
def get_int_dtype(nbits: int) -> torch.dtype:
if nbits <= 8:
return torch.int8
if nbits <= 16:
return torch.int16
if nbits <= 32:
return torch.int32
if nbits <= 64:
return torch.int64
raise ValueError(f"No dtype available for {nbits}-bit codebooks")
@torch.inference_mode()
def unpack_int_data(data: torch.IntTensor, nbits: int) -> torch.IntTensor:
return data.to(torch.int64) % (2**nbits)
def dequantize_weight(codes: torch.Tensor,
codebooks: torch.Tensor,
scales: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Decode float weights from quantization codes. Differentiable.
:param codes: tensor of integer quantization codes, shape
[*dims, num_out_groups, num_in_groups, num_codebooks]
:param codebooks: tensor of vectors for each quantization code,
[num_codebooks, codebook_size, out_group_size, in_group_size]
:param scales: weight will be multiplied by this factor, must be
broadcastable with
[*dims, out_groups, num_in_groups, out_group_size, in_group_size]
:return: reconstructed weight tensor of shape
[*dims, num_out_groups * out_group_size, num_in_groups * in_group_size]
"""
num_out_groups, num_in_groups, num_codebooks = codes.shape[-3:]
num_codebooks, codebook_size, out_group_size, in_group_size = \
codebooks.shape
out_features = num_out_groups * out_group_size
in_features = num_in_groups * in_group_size
codebook_offsets = torch.arange(
0, num_codebooks * codebook_size, codebook_size,
device=codes.device) # shape: [num_codebooks]
reconstructed_weight_flat = F.embedding_bag(
codes.flatten(0, -2) + codebook_offsets,
codebooks.flatten(0, 1).flatten(-2, -1),
mode="sum"
) # [prod(dims) * num_out_groups * num_in_groups, out_group_size
# * in_group_size]
reconstructed_weight_groupwise = reconstructed_weight_flat.view(
list(codes.shape[:-3]) +
[num_out_groups, num_in_groups, out_group_size, in_group_size])
if scales is not None:
reconstructed_weight_groupwise = reconstructed_weight_groupwise.mul(
scales)
return reconstructed_weight_groupwise.swapaxes(
-3, -2).reshape(list(codes.shape[:-3]) + [out_features, in_features])
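# Shape sketch (illustrative): a single 8-bit codebook with 1x8 groups
# reconstructing a 16x32 weight from randomly chosen codes.
#
#   codes = torch.randint(0, 256, (16, 4, 1), dtype=torch.int16)
#   codebooks = torch.randn(1, 256, 1, 8)
#   w = dequantize_weight(unpack_int_data(codes, 8), codebooks)
#   assert w.shape == (16, 32)  # 16*1 out_features, 4*8 in_features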
def dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
bias: Optional[torch.Tensor],
) -> torch.Tensor:
dequantized_weight = dequantize_weight(
unpack_int_data(codes, codebooks.shape[1].bit_length() - 1),
codebooks,
scales,
)
return F.linear(input, dequantized_weight, bias)
# Generic dequantization, slow but flexible.
def generic_dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.
Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
output_shape = input.shape[:-1] + (scales.shape[0], )
output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
num_outputs = len(output_partition_sizes)
# break the inputs and codebooks apart then combine the outputs.
# Surprisingly (to me) this is faster than doing 3 de-quants and 1 big
# multiply at the end.
num_codebooks = codebooks.shape[0] // num_outputs
assert (scales.shape[0] == codes.shape[0])
assert (sum(output_partition_sizes) == scales.shape[0])
output_offset = 0
codebooks_offset = 0
for output_size in output_partition_sizes:
shard_output = dequantize_gemm(
input, codes.narrow(0, output_offset, output_size),
codebooks.narrow(0, codebooks_offset, num_codebooks),
scales.narrow(0, output_offset, output_size), None
if bias is None else bias.narrow(0, output_offset, output_size))
output_slice = output.narrow(-1, output_offset, output_size)
assert (output_slice.shape == shard_output.shape)
output_slice.copy_(shard_output)
output_offset += output_size
codebooks_offset += num_codebooks
return output
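# A minimal end-to-end sketch of the partitioned layout this function
# expects (illustrative sizes): two fused output shards of 2 features
# each, one 16-entry codebook per shard over 1x8 groups.
def _demo_generic_dequantize_gemm() -> None:
    output_partition_sizes = [2, 2]
    codes = torch.randint(0, 16, (4, 1, 1), dtype=torch.int8)
    codebooks = torch.randn(2, 16, 1, 8)
    scales = torch.randn(4, 1, 1, 1)
    x = torch.randn(3, 8)
    out = generic_dequantize_gemm(x, codes, codebooks, scales,
                                  output_partition_sizes, None)
    assert out.shape == (3, 4)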
# Optimized dequantize/decompression kernels; they support the 1x16 and
# 2x8 formats at 6 and 9 times the speed of the generic version above,
# respectively.
def optimized_dequantize_gemm(
input: torch.Tensor, # [..., in_features]
codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks]
codebooks: torch.Tensor,
# [num_codebooks, codebook_size, out_group_size, in_group_size]
scales: torch.Tensor, # [num_out_groups, 1, 1, 1]
output_partition_sizes: torch.IntTensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes)
if bias is None:
# scaling the output is fastest, so we do that when possible.
output = F.linear(input, weights, bias)
orig_shape = output.shape
flattened_output = output.view(-1, output.size(-1))
f_scales = scales.view(-1, scales.shape[0])
b_scales = f_scales.expand(flattened_output.shape[0], -1)
flattened_output *= b_scales
return output.view(orig_shape)
else:
b_scales = scales.view(scales.shape[:-3] + (-1, )).expand(
-1, weights.shape[1])
weights *= b_scales
return F.linear(input, weights, bias)
class AQLMConfig(QuantizationConfig):
"""Config class for AQLM.
Reference: https://github.com/Vahe1994/AQLM
"""
def __init__(
self,
in_group_size: int,
nbits_per_codebook: int,
num_codebooks: int,
out_group_size: int,
) -> None:
self.in_group_size = in_group_size
self.nbits_per_codebook = nbits_per_codebook
self.num_codebooks = num_codebooks
self.out_group_size = out_group_size
# out_group_size > 1 is untested, and probably won't work as-is.
assert (self.out_group_size == 1)
self.pack_factor = (self.in_group_size * self.out_group_size)
def __repr__(self) -> str:
return (f"AQLMConfig(in_group_size={self.in_group_size}, "
f"nbits_per_codebook={self.nbits_per_codebook}, "
f"num_codebooks={self.num_codebooks}, "
f"out_group_size={self.out_group_size})")
@classmethod
def get_name(cls) -> str:
return "aqlm"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 70
@classmethod
def get_config_filenames(cls) -> List[str]:
return [] # no extra configs.
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig":
in_group_size = cls.get_from_keys(config, ["in_group_size"])
nbits_per_codebook = cls.get_from_keys(config, ["nbits_per_codebook"])
num_code_books = cls.get_from_keys(config, ["num_codebooks"])
out_group_size = cls.get_from_keys(config, ["out_group_size"])
return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase):
return AQLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class AQLMLinearMethod(LinearMethodBase):
"""Linear method for AQLM.
Args:
quant_config: The AQLM quantization config.
"""
def __init__(self, quant_config: AQLMConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
del output_size # Unused.
del input_size # Unused.
if params_dtype != torch.half:
raise ValueError("Only half is currently supported by aqlm")
if input_size_per_partition % self.quant_config.in_group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.out_group_size != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
codes = Parameter(
torch.empty(
# There could actually be two pack factors, one along input and
# one along output, but we don't currently support
# out_group_size, and only the one along input needs to be
# marked with "packed_dim" in order for QKVLinear to work.
output_size_per_partition,
input_size_per_partition // self.quant_config.pack_factor,
self.quant_config.num_codebooks,
dtype=get_int_dtype(self.quant_config.nbits_per_codebook),
),
requires_grad=False,
)
set_weight_attrs(
codes,
{
"input_dim": 1,
"output_dim": 0,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
codebooks = Parameter(
torch.empty(
self.quant_config.num_codebooks * len(output_partition_sizes),
2**self.quant_config.nbits_per_codebook,
self.quant_config.out_group_size,
self.quant_config.in_group_size,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
codebooks,
{
# metadata indicates fixed size concatenated along dim 0
"is_metadata":
True,
"output_partition_sizes":
torch.tensor(output_partition_sizes, device='cpu'),
},
)
scales = Parameter(
torch.empty(
(
output_size_per_partition //
self.quant_config.out_group_size,
1,
1,
1,
),
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"output_dim": 0,
"packed_dim": 0,
"pack_factor": self.quant_config.out_group_size
},
)
layer.register_parameter("codes", codes)
set_weight_attrs(codes, extra_weight_attrs)
layer.register_parameter("codebooks", codebooks)
set_weight_attrs(codebooks, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
codebooks = layer.codebooks
codes = layer.codes
scales = layer.scales
output_partition_sizes = getattr(codebooks, "output_partition_sizes",
None)
nbooks = codes.shape[2]
ingroups = codebooks.shape[3]
outgroups = codebooks.shape[2]
bits = codebooks.shape[1]
# We support these formats with dedicated gemm and decompression
# kernels.
if ingroups == 8 and outgroups == 1 and (
(bits == 256 and nbooks == 2) or (bits == 65536 and nbooks == 1)):
# thresholds determined by timings on an A6000, one GPU
use_gemv = math.prod(x.shape[:-1]) <= 6
return ops.aqlm_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
) if use_gemv else optimized_dequantize_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
)
# Fall back to the generic path for all unoptimized formats.
return generic_dequantize_gemm(
x,
codes,
codebooks,
scales,
output_partition_sizes,
bias,
)
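# A minimal sketch of the dispatch heuristic in apply() above: tiny
# batches take the fused gemv kernel, larger ones dequantize and run a
# dense matmul (the threshold is the A6000 timing noted above).
def _demo_gemv_dispatch() -> None:
    x = torch.randn(2, 3, 4096)  # [batch_size, seq_len, hidden]
    use_gemv = math.prod(x.shape[:-1]) <= 6
    assert use_gemv  # 2 * 3 = 6 tokens, right at the gemv cutoff.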

View File

@@ -0,0 +1,175 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class AWQConfig(QuantizationConfig):
"""Config class for AWQ.
Reference: https://arxiv.org/abs/2306.00978
"""
def __init__(
self,
weight_bits: int,
group_size: int,
zero_point: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.zero_point = zero_point
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"AWQ, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return (f"AWQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"zero_point={self.zero_point})")
def get_name(self) -> str:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
@staticmethod
def get_config_filenames() -> List[str]:
return [
"quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq
# E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
"quantize_config.json",
]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase):
return AWQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class AWQLinearMethod(LinearMethodBase):
"""Linear method for AWQ.
Args:
quant_config: The AWQ quantization config.
"""
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
qweight = Parameter(
torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
qzeros = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": 0,
"output_dim": 1,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
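# A minimal sketch of the AWQ packing arithmetic used above (shapes only,
# no kernel call): eight 4-bit values share one int32 along the output
# dim, so the unpacked output width is qweight.shape[-1] * pack_factor.
def _demo_awq_shapes() -> None:
    in_features, out_features, pack_factor = 64, 32, 32 // 4
    qweight = torch.empty(in_features, out_features // pack_factor,
                          dtype=torch.int32)
    x = torch.randn(5, in_features)
    out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor, )
    assert out_shape == (5, 32)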

View File

@@ -0,0 +1,97 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
import torch
from torch import nn
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module, *weight_args,
**extra_weight_attrs):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise NotImplementedError
@abstractmethod
def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise NotImplementedError
def process_weights_after_loading(self, layer: nn.Module) -> None:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class QuantizationConfig(ABC):
"""Base class for quantization configs."""
@abstractmethod
def get_name(self) -> str:
"""Name of the quantization method."""
raise NotImplementedError
@abstractmethod
def get_supported_act_dtypes(self) -> List[torch.dtype]:
"""List of supported activation dtypes."""
raise NotImplementedError
@abstractmethod
def get_min_capability(self) -> int:
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
"""
raise NotImplementedError
@staticmethod
@abstractmethod
def get_config_filenames() -> List[str]:
"""List of filenames to search for in the model directory."""
raise NotImplementedError
@classmethod
@abstractmethod
def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
"""Create a config class from the model's quantization config."""
raise NotImplementedError
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config."""
for key in keys:
if key in config:
return config[key]
raise ValueError(f"Cannot find any of {keys} in the model's "
"quantization config.")
@abstractmethod
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise NotImplementedError
@abstractmethod
def get_scaled_act_names(self) -> List[str]:
"""Returns the activation function names that should be post-scaled.
For now, this is only used by AWQ.
"""
raise NotImplementedError
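# A minimal sketch of a concrete config for a hypothetical "demo" scheme,
# showing the hooks a new quantization backend must implement.
class _DemoQuantConfig(QuantizationConfig):
    def get_name(self) -> str:
        return "demo"
    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]
    def get_min_capability(self) -> int:
        return 70
    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quant_config.json"]
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "_DemoQuantConfig":
        return cls()
    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
        return None  # No layer type is supported by this demo.
    def get_scaled_act_names(self) -> List[str]:
        return []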

View File

@@ -0,0 +1,265 @@
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
logger = init_logger(__name__)
class Fp8Config(QuantizationConfig):
"""Config class for FP8."""
def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
logger.warning("Detected fp8 checkpoint. Please note that the "
"format is experimental and subject to change.")
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
@classmethod
def get_name(cls) -> str:
return "fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.bfloat16, torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 89
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic
activation scaling. The weight scaling factor will be initialized after
the model weights are loaded.
Limitations:
1. Only supports per-tensor quantization due to torch._scaled_mm support.
2. Only supports the float8_e4m3fn data type due to the limitation of
torch._scaled_mm (https://github.com/pytorch/pytorch/blob/2e48b39603411a41c5025efbe52f89560b827825/aten/src/ATen/native/cuda/Blas.cpp#L854-L856)
Args:
quant_config: The quantization config.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def _create_scale_param(
self,
scale_name: str,
layer: torch.nn.Module,
output_partition_sizes: List[int],
**extra_weight_attrs,
) -> None:
scale = Parameter(torch.empty(len(output_partition_sizes),
dtype=torch.float32),
requires_grad=False)
layer.register_parameter(scale_name, scale)
set_weight_attrs(
scale, {
**extra_weight_attrs,
"fp8_scales_shard_indexer":
self.scales_shard_indexer,
})
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del input_size, output_size
output_size_per_partition = sum(output_partition_sizes)
layer.process_after_load = True
layer.logical_widths = output_partition_sizes
# WEIGHT
weight_dtype = (torch.float8_e4m3fn
if self.quant_config.is_checkpoint_fp8_serialized else
params_dtype)
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
dtype=weight_dtype),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
**extra_weight_attrs,
"input_dim": 1,
"output_dim": 0,
})
# If checkpoint is serialized fp8, load them.
# Otherwise, wait until process_weights_after_loading.
if self.quant_config.is_checkpoint_fp8_serialized:
# WEIGHT SCALE
self._create_scale_param(
scale_name="weight_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
# ACTIVATION SCALE
if self.quant_config.activation_scheme == "static":
self._create_scale_param(
scale_name="act_scale",
layer=layer,
output_partition_sizes=output_partition_sizes,
**extra_weight_attrs)
def scales_shard_indexer(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int]) -> Tuple[torch.Tensor, torch.Tensor]:
qkv_idxs = {"q": 0, "k": 1, "v": 2}
if isinstance(shard_id, int):
pass
elif isinstance(shard_id, str):
if shard_id not in qkv_idxs:
raise ValueError(f"Unknown shard_id: {shard_id}")
shard_id = qkv_idxs[shard_id]
else:
ValueError(f"Shard id must be int or str but got {type(shard_id)}")
return param[shard_id], loaded_weight
def process_weights_after_loading(self, layer: Module) -> None:
if (not hasattr(layer, "process_after_load")
or not layer.process_after_load):
return
# If checkpoint is fp/bf16 (not serialized fp8), quantize the weights.
if not self.quant_config.is_checkpoint_fp8_serialized:
qweight, weight_scale = ops.scaled_fp8_quant(layer.weight,
scale=None)
layer.weight = Parameter(qweight.t(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
layer.logical_widths = None
layer.act_scale = None
return
# If checkpoint is fp8, requantize the separately quantized logical
# weights into a single fp8 weight with a single weight scale.
else:
# WEIGHT_SCALE / WEIGHT
# Loop over logical weights, requantizing with single scale.
max_w_scale = layer.weight_scale.max()
start = 0
for idx, logical_width in enumerate(layer.logical_widths):
end = start + logical_width
weight_dq = per_tensor_dequantize(layer.weight[start:end, :],
layer.weight_scale[idx])
layer.weight[start:end, :] = per_tensor_quantize(
weight_dq, layer.weight_scale.max())
start = end
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# WEIGHT
# Transpose weight for passing to torch._scaled_mm
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
# ACT_SCALE
# Dynamic: set to None (required input to ops.scaled_fp8_quant).
# Static: set to max of the act_scales (since they are equal).
if self.quant_config.activation_scheme == "dynamic":
layer.act_scale = None
elif self.quant_config.activation_scheme == "static":
if not all_close_1d(layer.act_scale):
raise ValueError(
"All the act_scales for the logical weights of a layer "
f"must be equal. But got {layer.act_scale}")
layer.act_scale = Parameter(layer.act_scale.max(),
requires_grad=False)
else:
raise ValueError(
f"Unknown scheme {self.quant_config.activation_scheme}")
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale)
# Fused GEMM_DQ
output, _ = torch._scaled_mm(
qinput,
layer.weight,
out_dtype=x.dtype,
scale_a=x_scale,
scale_b=layer.weight_scale,
bias=bias,
)
return output
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
def per_tensor_quantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn)
qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max)
return qweight.to(torch.float8_e4m3fn)
def per_tensor_dequantize(tensor: torch.Tensor,
inv_scale: float) -> torch.Tensor:
fake_qweight = tensor.to(torch.float16)
dq_weight = fake_qweight * inv_scale
return dq_weight
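# A minimal sketch of the per-tensor round trip used in
# process_weights_after_loading (tolerances are illustrative): fp8-e4m3
# keeps roughly two significant digits, so the reconstruction is lossy.
def _demo_per_tensor_round_trip() -> None:
    w = torch.randn(4, 8, dtype=torch.float16)
    scale = w.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
    q = per_tensor_quantize(w, scale)
    assert q.dtype == torch.float8_e4m3fn
    w_dq = per_tensor_dequantize(q, scale)
    assert torch.allclose(w.float(), w_dq.float(), rtol=0.08, atol=1e-3)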

View File

@@ -0,0 +1,224 @@
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.pack_factor = Fraction(32, self.weight_bits)
if self.weight_bits not in [2, 3, 4, 8]:
raise ValueError(
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# TODO: confirm the minimum GPU capability this kernel requires.
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class ExllamaState(Enum):
UNUSED = enum.auto()
UNINITIALIZED = enum.auto()
READY = enum.auto()
class GPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ.
Args:
quant_config: The GPTQ quantization config.
"""
def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor.numerator
!= 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
# For act-order models, we cannot use Exllama for row-parallel layers.
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
g_idx = Parameter(
torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
qzeros = Parameter(
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
scale_and_zero_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("g_idx", g_idx)
set_weight_attrs(g_idx, extra_weight_attrs)
layer.register_parameter("qzeros", qzeros)
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.exllama_state = exllama_state
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# Exllama needs to shuffle the weight after it is loaded; we do the
# shuffle on the first forward pass.
if layer.exllama_state == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
else:
layer.g_idx.data = torch.empty((0, ),
device=layer.g_idx.device)
layer.exllama_state = ExllamaState.READY
ops.gptq_shuffle(layer.qweight, layer.g_idx,
self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, layer.qweight, layer.qzeros,
layer.scales, layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None:
output.add_(bias)
return output.reshape(out_shape)
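# A minimal sketch of the fractional pack factor (illustrative sizes):
# for 3-bit GPTQ, 32 bits hold 32/3 values, and Fraction keeps
# input_size_per_partition // pack_factor exact for aligned shapes.
def _demo_gptq_pack_factor() -> None:
    pack_factor = Fraction(32, 3)
    packed_rows = 3072 // pack_factor
    assert packed_rows == 288  # 3072 * 3 / 32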

View File

@@ -0,0 +1,438 @@
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits):
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def get_pack_factor(num_bits):
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def marlin_permute_scales(s, size_k, size_n, group_size, num_bits):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
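# A minimal shape check for marlin_permute_scales (illustrative sizes):
# permuting only reorders scale columns, so the
# [size_k // group_size, size_n] layout is preserved.
def _demo_marlin_permute_scales() -> None:
    size_k, size_n, group_size, num_bits = 256, 64, 128, 4
    s = torch.randn(size_k // group_size, size_n, dtype=torch.half)
    s_perm = marlin_permute_scales(s, size_k, size_n, group_size, num_bits)
    assert s_perm.shape == s.shape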
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Validate dtype
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_thread_n != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {self.quant_config.min_thread_n}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = Parameter(
torch.empty(
g_idx.shape,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.marlin_state = GPTQMarlinState.REPACK
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)

View File

@@ -0,0 +1,227 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""
def __init__(
self,
group_size: int,
) -> None:
# Group size for the quantization.
self.group_size = group_size
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}")
# 4-bit values packed into a 32-bit datatype.
self.pack_factor = 32 // 4
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 64
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size})"
@classmethod
def get_name(cls) -> str:
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# TODO: confirm the minimum GPU capability this kernel requires.
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return MarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
Args:
quant_config: The Marlin quantization config.
"""
def __init__(self, quant_config: MarlinConfig):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
scales = Parameter(
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B
scales = layer.s
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output

View File

@@ -0,0 +1,84 @@
"""
This file contains the Pydantic schemas for various quantization-related
parameters. When a relevant quantization technique is specified, these
parameters are loaded in the form of a JSON alongside the model weights
and augment the model with additional information needed for use of that
technique. The format of this JSON should be specified by one or more
schemas contained here.
For example, when the KV cache is quantized to FP8-E4M3 (currently only
possible on ROCm), the model can be optionally augmented with KV cache
scaling factors.
"""
from typing import Dict, Optional
from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
class KVCacheQuantSchema(BaseModel):
dtype: str
# Each key is a TP rank. Each value is a dictionary mapping a TP rank's
# layer indices to their per-tensor KV cache scaling factor.
# TODO: Consider pulling this and its validation methods out into its
# own schema class (tricky as its members are variable)
scaling_factor: Dict[int, Dict[int, float]]
@model_validator(mode="after")
def check_is_fp8(self) -> "KVCacheQuantSchema":
assert self.dtype == "float8_e4m3fn", (
"Loaded scaling factors intended for KV cache dtype = "
f"{self.dtype} rather than float8_e4m3fn!")
return self
@model_validator(mode="after")
def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
context = info.context
if context:
tp_size = context["tp_size"]
num_hidden_layers = context["num_hidden_layers"]
assert len(self.scaling_factor) == tp_size, (
f"Loaded dictionary has TP size {len(self.scaling_factor)} "
f"but LLM engine is currently running with TP size {tp_size}.")
for tp_rank, layer_maps in self.scaling_factor.items():
assert len(layer_maps) == num_hidden_layers, (
f"KV cache scales map for TP rank {tp_rank} is malformed. "
f"Expected {num_hidden_layers} layers, got "
f"{len(layer_maps)}.")
for i in range(tp_size):
assert i in self.scaling_factor, (
f"KV cache scales map for TP rank {i} not found.")
return self
@model_validator(mode="after")
def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
context = info.context
if context:
tp_rank = context["tp_rank"]
num_hidden_layers = context["num_hidden_layers"]
layer_scales_map = self.scaling_factor[tp_rank]
for i in range(num_hidden_layers):
assert i in layer_scales_map, (
f"Could not find KV cache scales for layer {i} in "
f"TP rank {tp_rank}.")
return self
class QuantParamSchema(BaseModel):
# TODO: Generalize and extend with more fields
# (e.g. weights/activations params) once functionality is enabled
model_config = ConfigDict(protected_namespaces=())
model_type: Optional[str]
kv_cache: KVCacheQuantSchema
@model_validator(mode="after")
def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
context = info.context
if context:
model_type = context.get("model_type", None)
if model_type is not None:
assert model_type == self.model_type, (
f"Model type is {model_type} but loaded "
f"scaling factors belonging to different "
f"model type {self.model_type}!")
return self
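# A minimal sketch of validating a loaded scales JSON (values here are
# illustrative), assuming a single TP rank and a two-layer model.
def _demo_validate_kv_cache_scales() -> None:
    raw = {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            "scaling_factor": {0: {0: 0.02, 1: 0.03}},
        },
    }
    context = {
        "model_type": "llama",
        "tp_size": 1,
        "tp_rank": 0,
        "num_hidden_layers": 2,
    }
    schema = QuantParamSchema.model_validate(raw, context=context)
    assert schema.kv_cache.scaling_factor[0][1] == 0.03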

View File

@@ -0,0 +1,137 @@
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_hip
class SqueezeLLMConfig(QuantizationConfig):
"""Config class for SqueezeLLM.
Reference: https://arxiv.org/pdf/2306.07629
"""
def __init__(
self,
weight_bits: int,
) -> None:
self.weight_bits = weight_bits
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"SqueezeLLM, but got {self.weight_bits} bits.")
self.pack_factor = 32 // self.weight_bits
def __repr__(self) -> str:
return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"
def get_name(self) -> str:
return "squeezellm"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
def get_min_capability(self) -> int:
return 70
@staticmethod
def get_config_filenames() -> List[str]:
return ["quant_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig":
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class SqueezeLLMLinearMethod(QuantizeMethodBase):
"""Linear method for SqueezeLLM.
Args:
quant_config: The SqueezeLLM quantization config.
"""
def __init__(self, quant_config: SqueezeLLMConfig):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
if input_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
output_size_per_partition = sum(output_partition_sizes)
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
lookup_table = Parameter(
torch.empty(
output_size,
2**self.quant_config.weight_bits,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(lookup_table, {
"output_dim": 0,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("lookup_table", lookup_table)
set_weight_attrs(lookup_table, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
lookup_table = layer.lookup_table
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
if is_hip():
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
out = out_f.to(dtype=torch.float16)
else:
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
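# A minimal pure-PyTorch sketch of the lookup-table dequantization the
# kernel performs (not the actual op): each 4-bit code indexes a
# per-output-channel table of 2**weight_bits centroids.
def _demo_squeezellm_lookup() -> None:
    weight_bits, out_features, num_codes = 4, 3, 5
    lookup_table = torch.randn(out_features, 2**weight_bits)
    codes = torch.randint(0, 2**weight_bits, (out_features, num_codes))
    dequantized = torch.gather(lookup_table, 1, codes)
    assert dequantized.shape == (out_features, num_codes)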

View File

@@ -0,0 +1,405 @@
from functools import cached_property
from typing import Optional, Tuple
import torch
import torch.jit
import torch.nn as nn
class RejectionSampler(nn.Module):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
"""
def __init__(self, strict_mode: bool = False):
"""Create a rejection sampler.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@property
def probs_dtype(self):
return torch.float32
@property
def token_id_dtype(self):
return torch.int64
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
one correct token will be emitted.
In the case where all draft tokens are accepted, a bonus token will be
accepted, as it's cheap to have the target model score this speculative
sequence.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: The probability distribution over token ids given
context according to the draft model.
shape = [batch_size, num_speculative_tokens, vocab_size]
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
bonus_token_ids,
draft_token_ids)
accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
target_probs,
draft_probs,
draft_token_ids,
)
output_token_ids = self._create_output(
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _batch_modified_rejection_sampling(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence are accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size, k, vocab_size = draft_probs.shape
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
# NOTE: the recovered_probs are overwritten by this method.
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
return accepted, recovered_token_ids
def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
:math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
same conditional probability according to the draft model, the token
is accepted with probability:
.. math::
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size, k, _ = draft_probs.shape
batch_indices = torch.arange(batch_size,
device=target_probs.device)[:, None]
        probs_indices = torch.arange(k, device=target_probs.device)
        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indices,
                                           draft_token_ids]
        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indices,
                                             draft_token_ids]
uniform_rand = torch.rand(batch_size,
k,
dtype=self.probs_dtype,
device=target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
accepted = uniform_rand < capped_ratio
return accepted
def _get_recovered_probs(
self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
) -> torch.Tensor:
r"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
:math:`x` given context :math:`x_1, \dots, x_n` according to the target
model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
according to the draft model:
.. math::
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
where :math:`(f(x))_+` is defined as:
.. math::
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This can cause
        division-by-zero errors, so we clamp with
        self._smallest_positive_value to avoid them. This introduces some
        drift to the distribution.
"""
_, k, _ = draft_probs.shape
# shape [batch_size, k, vocab_size]
difference = target_probs - draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f = torch.clamp(difference, min=self._smallest_positive_value)
# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
return recovered_probs
@cached_property
def _smallest_positive_value(self) -> float:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
        See _get_recovered_probs for more details.
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return torch.finfo(self.probs_dtype).tiny
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
recovered_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size, num_bonus_tokens]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via rejection sampling, all subsequent
token ids are set to -1 for the sequence.
shape = [batch_size, k + num_bonus_tokens]
"""
bonus_token_ids = bonus_token_ids.squeeze()
batch_size, k = recovered_token_ids.shape
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
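        # For example, accepted = [[True, False, True]] gives limits = [1]:
        # the draft token at index 0 is kept, index 1 receives a recovered
        # token, and every position after it is filled with -1.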
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
self.num_accepted_tokens += accepted.sum()
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k
return output_with_bonus_tokens
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
assert draft_token_ids_batch_size == draft_batch_size
assert num_draft_token_ids == num_draft_probs
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert all(probs.dtype == self.probs_dtype
for probs in [target_probs, draft_probs])
assert all(token_ids.dtype == self.token_id_dtype
for token_ids in [bonus_token_ids, draft_token_ids])
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
]
        assert all(devices[0] == device for device in devices)
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
bonus_token_ids: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
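# The trick: if E_i ~ Exp(1) independently, then argmax_i(p_i / E_i) equals i
# with probability p_i / sum_j p_j (an "exponential race"), which matches
# multinomial sampling without any normalization or host synchronization.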
@torch.jit.script
def _multinomial(
probs: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
if num_samples > 1:
        # This is equivalent to torch.repeat_interleave (which also
        # forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
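# A minimal usage sketch (illustrative only; assumes the enclosing sampler
# module, here called RejectionSampler, is constructed with its defaults so
# that probs_dtype is torch.float32 and token_id_dtype is torch.int64, and
# that any required counter initialization has been done):
#
#     sampler = RejectionSampler()
#     batch_size, k, vocab_size = 4, 3, 32000
#     target_probs = torch.rand(batch_size, k, vocab_size).softmax(dim=-1)
#     draft_probs = torch.rand(batch_size, k, vocab_size).softmax(dim=-1)
#     draft_token_ids = draft_probs.argmax(dim=-1)
#     bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1))
#     out = sampler(target_probs, bonus_token_ids, draft_probs,
#                   draft_token_ids)  # [batch_size, k + 1]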

View File

@@ -0,0 +1,531 @@
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rotary Positional Embeddings."""
import math
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from vllm import _custom_ops as ops
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat((-x2, x1), dim=-1)
def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., ::2]
x2 = x[..., 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2)
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
cache = self._compute_cos_sin_cache()
cache = cache.to(torch.get_default_dtype())
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
        # torch_musa does not support pow_scalar_out, so compute the power on
        # CPU and move the result back to the original device. The direct
        # form would be:
        # inv_freq = 1.0 / (base**(torch.arange(
        #     0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
exp = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
device = exp.device
inv_freq = 1.0 / (base**(exp.cpu() / self.rotary_dim))
return inv_freq.to(device)
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def _forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
query = query.flatten(-2)
key = key.flatten(-2)
return query, key
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
ops.batched_rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache,
self.is_neox_style, self.rotary_dim,
offsets)
else:
ops.rotary_embedding(positions, query, key, self.head_size,
self.cos_sin_cache, self.is_neox_style)
return query, key
def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}"
s += f", base={self.base}, is_neox_style={self.is_neox_style}"
return s
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling.
Credits to the Reddit user /u/kaiokendev
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factors: Union[List[float], float],
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors = scaling_factors
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.base)
cache_list = []
for scaling_factor in self.scaling_factors:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * scaling_factor
t = torch.arange(max_len, dtype=torch.float)
t = t / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
cache_list.append(cache)
return torch.cat(cache_list, dim=0)
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling.
Credits to the Reddit users /u/bloc97 and /u/emozilla
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_cos_sin_cache(self) -> torch.Tensor:
# NOTE(woosuk): self.max_position_embeddings is the original
# maximum length before applying the rope scaling.
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
base = self.base * (
(self.scaling_factor * max_len / self.max_position_embeddings) -
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
# Inverse dimension formula: find the rotary dimension index at which a given
# number of rotations occurs over the full context length.
def _yarn_find_correction_dim(num_rotations: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> float:
return (dim * math.log(max_position_embeddings /
(num_rotations * 2 * math.pi))) / (2 *
math.log(base))
# Find dim range bounds based on rotations
def _yarn_find_correction_range(
low_rot: int,
high_rot: int,
dim: int,
base: float = 10000,
max_position_embeddings: int = 2048) -> Tuple[int, int]:
low = math.floor(
_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings))
high = math.ceil(
_yarn_find_correction_dim(high_rot, dim, base,
max_position_embeddings))
return max(low, 0), min(high, dim - 1) # Clamp values just in case
def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
dtype: torch.dtype) -> torch.Tensor:
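    # Ramps linearly from 0 at index `low` to 1 at index `high`, clamped to
    # [0, 1] outside that range; e.g. low=2, high=6, dim=8 yields
    # [0, 0, 0, 0.25, 0.5, 0.75, 1, 1].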
if low == high:
high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
def _yarn_get_mscale(scale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
class YaRNScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
Credits to Peng et al. github.com/jquesnelle/yarn
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: int = 32,
beta_slow: int = 1,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(
_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
return cache
class Phi3SuScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
long_mscale: float = 1.225,
):
super().__init__()
if rotary_dim != head_size:
            raise ValueError(
                "`Phi3SuScaledRotaryEmbedding` does not support "
                f"rotary_dim != head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype())
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype())
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)
long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)
def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
return inv_freq
def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)
k = self.original_max_position_embeddings
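        # If any position exceeds the original max length k, shift every
        # index by k so lookups select rows from the long-factor half of the
        # concatenated [short; long] cos/sin cache.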
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache: torch.Tensor = (
self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)
query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin
return query.flatten(-2), key.flatten(-2)
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
def get_rope(
head_size: int,
rotary_dim: int,
max_position: int,
base: int,
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style,
rope_scaling_args)
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
if scaling_type != "su":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
is_neox_style,
scaling_factor)
elif scaling_type == "dynamic":
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style,
scaling_factor)
elif scaling_type == "yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow")
}
rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
base, is_neox_style,
scaling_factor,
**extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
return rotary_emb
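# A minimal usage sketch (illustrative only; the PyTorch-native _forward path
# is used so no custom ops are required, and positions are assumed to have
# shape [batch_size, seq_len] as the neox-style path expects):
#
#     rope = get_rope(head_size=128, rotary_dim=128, max_position=4096,
#                     base=10000, is_neox_style=True)
#     positions = torch.arange(8).unsqueeze(0)   # [1, 8]
#     query = torch.randn(1, 8, 32 * 128)        # 32 query heads
#     key = torch.randn(1, 8, 8 * 128)           # 8 KV heads
#     query, key = rope._forward(positions, query, key)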

File diff suppressed because it is too large

View File

@@ -0,0 +1,155 @@
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
    Adapted from torch.nn.Embedding. Note that we pad the vocabulary size to
    make sure it is divisible by the number of model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
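        # Only the first loaded_weight.shape[0] rows are written; on the last
        # rank, any remaining rows correspond to vocab padding and keep their
        # allocated (unused) values.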
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
s += f", embedding_dim={self.embedding_dim}"
s += f", org_vocab_size={self.org_vocab_size}"
s += f', num_embeddings_padded={self.num_embeddings_padded}'
s += f', tp_size={self.tp_size}'
return s
class ParallelLMHead(VocabParallelEmbedding):
"""Parallelized LM head.
    Holds the output-logits weight matrix used in the Sampler. The weight and
    bias tensors are padded so their vocab dimension is divisible by the
    number of model parallel GPUs.
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
"weight_loader": self.weight_loader
})
else:
self.register_parameter("bias", None)
def forward(self, input_):
del input_
raise RuntimeError("LMHead's weights should be used in the sampler.")
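# A minimal single-GPU usage sketch (illustrative only; assumes the
# distributed environment has been initialized so the tensor-parallel helpers
# report a world size of 1, in which case the mask/all-reduce path is a
# no-op):
#
#     emb = VocabParallelEmbedding(num_embeddings=32000, embedding_dim=4096)
#     token_ids = torch.randint(0, 32000, (2, 16))
#     hidden = emb(token_ids)   # [2, 16, 4096]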