This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

View File

View File

@@ -0,0 +1,71 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
    """
    Abstract foundation for Mamba-like layers that run on the v1 engine.

    Subclass this when implementing a custom state-carrying layer and
    provide the abstract hooks declared below.
    """

    # State cache of the layer (called "KV cache" by the engine). Its tensors
    # follow the shapes reported by `self.get_state_shape`.
    kv_cache: tuple[torch.Tensor, ...]

    @abstractmethod
    def get_state_shape(self) -> Iterable[tuple[int, ...]]:
        """
        Report the shape of each state tensor kept by the layer.

        Mamba layers usually keep a (conv_state, ssm_state) pair; in that
        case the result is (conv_state_shape, ssm_state_shape).
        """

    @property
    @abstractmethod
    def mamba_type(self) -> str:
        """Type tag that `get_kv_cache_spec` forwards as `mamba_type`."""

    @abstractmethod
    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Get the attention backend class for this Mamba layer."""

    @abstractmethod
    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Dtypes that `get_kv_cache_spec` forwards as `dtypes`."""

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        """
        Build the `MambaSpec` describing this layer's state cache.

        Raises:
            NotImplementedError: if speculative decoding is configured and
                the HF model type is not "qwen3_next".
        """
        spec_config = vllm_config.speculative_config
        if spec_config is not None:
            model_type = vllm_config.model_config.hf_config.model_type
            if model_type not in ["qwen3_next"]:
                raise NotImplementedError(
                    "Mamba with speculative decoding is not supported yet."
                )

        cache_config = vllm_config.cache_config
        return MambaSpec(
            shapes=self.get_state_shape(),
            dtypes=self.get_state_dtype(),
            block_size=cache_config.mamba_block_size,
            page_size_padded=cache_config.mamba_page_size_padded,
            mamba_type=self.mamba_type,
            # No speculative config means no extra blocks are requested.
            num_speculative_blocks=(
                spec_config.num_speculative_tokens if spec_config else 0
            ),
        )

View File

@@ -0,0 +1,402 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
lightning_attention,
linear_decode_forward_triton,
)
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
class MiniMaxText01RMSNormTP(CustomOp):
    """
    RMS normalisation whose weight vector is split across tensor-parallel
    ranks; each rank holds `hidden_size / tp_world` weight entries.
    """

    name = "MiniMaxText01RMSNormTP"

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        """Create the rank-local weight shard and remember `eps`."""
        super().__init__()
        self.tp_world = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        local_size = int(hidden_size / self.tp_world)
        self.weight = nn.Parameter(torch.ones(local_size))
        # Attach the custom loader so checkpoint loading picks the shard.
        self.weight.weight_loader = self.weight_loader
        self.variance_epsilon = eps

    @staticmethod
    def weight_loader(
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
    ) -> None:
        """Copy this rank's contiguous slice of `loaded_weight` into `param`."""
        world_size = get_tensor_model_parallel_world_size()
        rank = get_tensor_model_parallel_rank()
        rows_per_rank = loaded_weight.shape[0] // world_size
        begin = rank * rows_per_rank
        param.data.copy_(loaded_weight[begin : begin + rows_per_rank])

    def _forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Normalise `x` over its last dim in float32, then restore its dtype."""
        input_dtype = x.dtype
        x32 = x.to(torch.float32)
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
        if self.tp_world > 1:
            # Average the per-rank mean squares over all ranks.
            mean_sq = tensor_model_parallel_all_reduce(mean_sq) / self.tp_world
        normed = x32 * torch.rsqrt(mean_sq + self.variance_epsilon)
        return (normed * self.weight).to(input_dtype)

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Apply the norm; a non-None `residual` is rejected by assertion."""
        assert residual is None, "RMSNorm does not support residual connection."
        return self._forward(x)
class MiniMaxText01LinearKernel:
    """Namespace for the prefill-side linear-attention kernel wrapper."""

    @staticmethod
    def jit_linear_forward_prefix(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        kv_caches: torch.Tensor,
        slope_rate: torch.Tensor,
        block_size: int,
        layer_idx: int | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Run `lightning_attention` for one sequence and update its cache slot.

        3-D `q`/`k`/`v` get a leading batch dim of 1. `kv_caches` is
        overwritten in place with the last history slice returned by the
        kernel. `layer_idx` and `**kwargs` are accepted but unused here.

        Returns:
            The attention output rearranged from "h n d" to "n (h d)".
        """
        slope_rate = slope_rate.to(torch.float32)

        if q.dim() == 3:
            q, k, v = (t.unsqueeze(0) for t in (q, k, v))

        # Unpacking also enforces that q is 4-D at this point.
        _, num_heads, _, head_dim = q.shape

        history = kv_caches.reshape(1, num_heads, head_dim, head_dim).contiguous()
        output, history = lightning_attention(
            q, k, v, slope_rate, block_size=block_size, kv_history=history
        )

        latest = history[:, :, -1, :, :]
        kv_caches.copy_(latest.reshape(num_heads, head_dim, head_dim))

        assert output.shape[0] == 1, "batch size must be 1"
        return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
    """Type tag of this layer: always "linear_attention"."""
    type_tag = "linear_attention"
    return type_tag
def get_attn_backend(self) -> type["AttentionBackend"]:
    """Return the `LinearAttentionBackend` class (imported on first call)."""
    # Function-scope import, kept from the original implementation.
    from vllm.v1.attention.backends.linear_attn import (
        LinearAttentionBackend as backend_cls,
    )

    return backend_cls
def get_state_dtype(self) -> tuple[torch.dtype]:
    """
    Resolve the state dtype from the model dtype and `mamba_cache_dtype`.

    Both `model_config` and `cache_config` must have been supplied.
    """
    assert self.model_config is not None
    assert self.cache_config is not None
    model_dtype = self.model_config.dtype
    cache_dtype = self.cache_config.mamba_cache_dtype
    return MambaStateDtypeCalculator.linear_attention_state_dtype(
        model_dtype, cache_dtype
    )
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
    """Ask `MambaStateShapeCalculator` for the linear-attention state shape."""
    shape_args = {
        "num_heads": self.num_heads,
        "tp_size": self.tp_size,
        "head_dim": self.head_dim,
    }
    return MambaStateShapeCalculator.linear_attention_state_shape(**shape_args)
def __init__(
    self,
    hidden_size: int,
    hidden_inner_size: int,
    num_heads: int,
    head_dim: int,
    max_position: int,
    block_size: int,
    num_hidden_layer: int,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    quant_config: QuantizationConfig | None = None,
    layer_idx: int = 0,
    linear_layer_idx: int = 0,
    prefix: str = "linear_attn",
) -> None:
    """
    Build the projections, the norm and the per-head decay slopes, then
    register the layer in the static forward context under `prefix`.

    `max_position` and `linear_layer_idx` are accepted but not used here.

    Raises:
        ValueError: if a layer is already registered under `prefix`.
    """
    super().__init__()

    # Plain configuration.
    self.layer_idx = layer_idx
    self.BLOCK = block_size
    self.hidden_size = hidden_size
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.total_num_heads = num_heads
    self.hidden_inner_size = hidden_inner_size

    # Tensor-parallel layout: heads are divided evenly across ranks.
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    assert self.total_num_heads % self.tp_size == 0
    self.tp_heads = self.total_num_heads // self.tp_size
    self.qkv_size = self.num_heads * self.head_dim
    self.tp_hidden = self.head_dim * self.tp_heads

    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix

    # Sub-modules (creation order preserved).
    self.qkv_proj = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size * 3,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.qkv_proj",
    )
    self.output_gate = ColumnParallelLinear(
        hidden_size,
        self.hidden_inner_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.output_gate",
    )
    self.out_proj = RowParallelLinear(
        self.hidden_inner_size,
        hidden_size,
        bias=False,
        quant_config=quant_config,
        prefix=f"{prefix}.out_proj",
    )
    self.norm = MiniMaxText01RMSNormTP(
        self.hidden_inner_size,
        eps=1e-5,
    )

    # Per-head slopes, scaled by a factor that shrinks linearly with
    # layer depth (a single-layer stack keeps the factor at ~1).
    base_slopes = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads)
    if num_hidden_layer <= 1:
        layer_scale = 1 + 1e-5
    else:
        layer_scale = 1 - layer_idx / (num_hidden_layer - 1) + 1e-5
    self.slope_rate = base_slopes * layer_scale

    # Slice out the slopes of the heads owned by this rank.
    head_begin = self.tp_rank * self.tp_heads
    head_end = (self.tp_rank + 1) * self.tp_heads
    self.tp_slope = self.slope_rate[head_begin:head_end].contiguous()

    # Make the layer discoverable by name from the custom-op entry point.
    static_ctx = get_current_vllm_config().compilation_config.static_forward_context
    if prefix in static_ctx:
        raise ValueError(f"Duplicate layer name: {prefix}")
    static_ctx[prefix] = self
@staticmethod
def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    """Copy `loaded_weight` into `param` unchanged; sizes must match exactly."""
    expected_size = param.size()
    assert expected_size == loaded_weight.size()
    param.data.copy_(loaded_weight)
@staticmethod
def _build_slope_tensor(n_attention_heads: int):
    """
    Build per-head slopes as a float32 tensor of shape (n_heads, 1, 1).

    For a power-of-two head count the slopes form a geometric sequence.
    Otherwise the sequence for the largest power of two below the count is
    extended with every other slope of the next power of two.
    """

    def geometric_slopes(count):
        # First term and ratio are the same value.
        base = 2 ** (-(2 ** -(math.log2(count) - 3)))
        return [base * base**i for i in range(count)]

    log_heads = math.log2(n_attention_heads)
    if log_heads.is_integer():
        values = geometric_slopes(n_attention_heads)
    else:
        lower_pow2 = 2 ** math.floor(log_heads)
        missing = n_attention_heads - lower_pow2
        filler = geometric_slopes(2 * lower_pow2)[0::2][:missing]
        values = geometric_slopes(lower_pow2) + filler

    slopes = torch.tensor(values, dtype=torch.float32)
    return slopes.reshape(n_attention_heads, 1, 1)
def _prefill_and_mix_infer(
    self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
):
    """
    Run attention for a batch that contains prefill requests, optionally
    mixed with decode tokens.

    Each prefill request is processed on its own through
    `MiniMaxText01LinearKernel.jit_linear_forward_prefix`, which also
    updates that request's cache slot in place. Decode tokens, if any, are
    handled by `_decode_infer` and their output is placed first.

    Returns:
        The outputs concatenated along dim 0, or an empty
        `(0, q.size(-1))` tensor when nothing was produced.
    """
    # Prefill entries are located after `num_decode_tokens` entries in
    # `query_start_loc` / `state_indices_tensor`.
    offset = attn_metadata.num_decode_tokens
    hidden = []
    for prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
        query_start_loc = attn_metadata.query_start_loc
        req_idx = offset + prefill_idx
        # Guard the indices that are actually dereferenced below. Checking
        # the bare loop index (as before) let `req_idx + 1` / `req_idx`
        # run past the end and raise IndexError.
        if req_idx + 1 >= len(query_start_loc):
            break
        if req_idx >= len(state_indices_tensor):
            break

        start = query_start_loc[req_idx]
        end = query_start_loc[req_idx + 1]
        slot_id = state_indices_tensor[req_idx]

        # Move the token dim behind the head dim for the kernel.
        qs = q[start:end].transpose(0, 1).contiguous()
        ks = k[start:end].transpose(0, 1).contiguous()
        vs = v[start:end].transpose(0, 1).contiguous()
        slot_cache = kv_cache[slot_id, ...]

        out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
            qs,
            ks,
            vs,
            slot_cache,
            self.tp_slope,
            self.BLOCK,
            layer_idx=self.layer_idx,
        )
        hidden.append(out_slice.contiguous())

    if offset > 0:
        hidden_decode = self._decode_infer(
            q, k, v, kv_cache, state_indices_tensor, attn_metadata
        )
        # Decode output goes in front of the prefill outputs.
        hidden.insert(0, hidden_decode)

    if not hidden:
        return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)

    return torch.concat(hidden, dim=0).contiguous()
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
    """
    Run the Triton decode kernel on the leading decode tokens of the batch.

    The first `num_decode_tokens` rows of q/k/v get an extra dim at index 2
    and are paired with the first `num_decodes` cache slots.
    """
    token_count = attn_metadata.num_decode_tokens
    q_dec, k_dec, v_dec = (
        t[:token_count].unsqueeze(2).contiguous() for t in (q, k, v)
    )
    slots = state_indices_tensor[: attn_metadata.num_decodes]
    # The trailing 32 is passed through unchanged from the original call.
    return linear_decode_forward_triton(
        q_dec, k_dec, v_dec, kv_cache, self.tp_slope, slots, 32
    )
def forward(
    self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
    """
    Dispatch to the `linear_attention` custom op, which writes into `output`.

    The layer is identified by `self.prefix`.
    """
    layer_name = self.prefix
    torch.ops.vllm.linear_attention(hidden_states, output, positions, layer_name)
def _forward(
    self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
) -> None:
    """
    Compute linear attention for the batch and write it into `output`.

    When no attention metadata is available, the attention step is skipped
    and an uninitialised tensor of the right shape stands in for its result.
    `positions` is not used.
    """
    forward_context = get_forward_context()
    attn_metadata: AttentionMetadata = forward_context.attn_metadata
    if attn_metadata is None:
        num_actual_tokens = hidden_states.shape[0]
    else:
        # Metadata is a per-layer dict keyed by layer prefix.
        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, LinearAttentionMetadata)
        num_actual_tokens = (
            attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
        )

    # Fused QKV projection, SiLU in float32, then split per head.
    qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
    activated = torch.nn.functional.silu(qkv.to(torch.float32))
    activated = activated.view((qkv.shape[0], self.tp_heads, -1))
    q, k, v = torch.split(activated, [self.head_dim] * 3, dim=-1)

    if attn_metadata is None:
        hidden = torch.empty(
            (q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
        )
    else:
        kv_cache = self.kv_cache[forward_context.virtual_engine][0]
        state_indices_tensor = attn_metadata.state_indices_tensor

        num_prefills = getattr(attn_metadata, "num_prefills", 0)
        if num_prefills > 0:
            # Zero the cache slot of every prefill that starts from scratch
            # (its sequence length equals its query length).
            first_prefill = getattr(attn_metadata, "num_decode_tokens", 0)
            for i in range(num_prefills):
                req = first_prefill + i
                span_start = attn_metadata.query_start_loc[req]
                span_end = attn_metadata.query_start_loc[req + 1]
                query_len = span_end - span_start
                context_len = attn_metadata.seq_lens[req] - query_len
                if context_len == 0:
                    block_to_clear = state_indices_tensor[req]
                    kv_cache[block_to_clear, ...] = 0

        decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
        if decode_only:
            hidden = self._decode_infer(
                q, k, v, kv_cache, state_indices_tensor, attn_metadata
            )
        else:
            hidden = self._prefill_and_mix_infer(
                q, k, v, kv_cache, state_indices_tensor, attn_metadata
            )

    # Norm, sigmoid gating, then the output projection.
    hidden = self.norm._forward(hidden)
    gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
    hidden = (F.sigmoid(gate) * hidden).to(hidden_states.dtype)
    output[:num_actual_tokens], _ = self.out_proj(hidden)
def linear_attention(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom-op body: look up the layer registered as `layer_name` in the
    current forward context and run its `_forward`, which fills `output`.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    layer._forward(hidden_states=hidden_states, output=output, positions=positions)
def linear_attention_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    positions: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of the `linear_attention` op: does nothing."""
    return None
# Register `linear_attention` as a custom op at import time. `output` is
# declared as the mutated argument, and `linear_attention_fake` is supplied
# as the fake implementation.
# NOTE(review): presumably this is what makes the op callable as
# `torch.ops.vllm.linear_attention` — confirm against the helper.
direct_register_custom_op(
    op_name="linear_attention",
    op_func=linear_attention,
    mutates_args=["output"],
    fake_impl=linear_attention_fake,
)

View File

@@ -0,0 +1,535 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING, NamedTuple
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
from torch.nn.parameter import Parameter
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn,
selective_state_update,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
class MambaMixer(MambaBase, CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""
def __init__(
    self,
    hidden_size: int,
    ssm_state_size: int,
    conv_kernel_size: int,
    intermediate_size: int,
    time_step_rank: int,
    use_conv_bias: bool,
    use_bias: bool,
    use_rms_norm: bool,
    rms_norm_has_weight: bool = True,
    rms_norm_eps: float = 1e-5,
    activation="silu",
    is_lora_enabled: bool = False,
    model_config: ModelConfig | None = None,
    cache_config: CacheConfig | None = None,
    prefix: str = "",
):
    """
    Create the projections, SSM parameters and optional RMS norms, then
    register the layer in the static forward context under `prefix`.

    Raises:
        ValueError: if a layer is already registered under `prefix`.
    """
    super().__init__()
    self.time_step_rank = time_step_rank
    self.ssm_state_size = ssm_state_size
    self.use_rms_norm = use_rms_norm
    self.activation = activation
    self.is_lora_enabled = is_lora_enabled
    self.conv_kernel_size = conv_kernel_size
    self.intermediate_size = intermediate_size

    self.conv1d = ColumnParallelLinear(
        input_size=conv_kernel_size,
        output_size=intermediate_size,
        bias=use_conv_bias,
    )
    # The linear weight is given an extra dim so that it has the conv1d
    # weight layout. This cannot be done in `weight_loader`: the layer
    # already defines one and `set_weight_attrs` refuses to override it.
    self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

    self.in_proj = MergedColumnParallelLinear(
        hidden_size, [intermediate_size] * 2, bias=use_bias
    )

    # Selective projection: makes dt, B and C depend on the input.
    self.x_proj = RowParallelLinear(
        intermediate_size,
        time_step_rank + ssm_state_size * 2,
        bias=False,
    )
    # Time-step (discretisation) projection. The bias is skipped in the
    # forward pass because the selective-scan kernel adds it.
    self.dt_proj = ColumnParallelLinear(
        time_step_rank, intermediate_size, bias=True, skip_bias_add=True
    )

    def load_rank_shard(param: Parameter, loaded_weight: torch.Tensor):
        # Split dim 0 evenly across ranks and keep this rank's chunk.
        rank = get_tensor_model_parallel_rank()
        world = get_tensor_model_parallel_world_size()
        chunks = loaded_weight.data.split(loaded_weight.shape[0] // world, dim=0)
        param.data.copy_(chunks[rank])

    def load_negated_exp_shard(param: Parameter, loaded_weight: torch.Tensor):
        # A is stored as -exp(value), computed in float32 before sharding.
        load_rank_shard(param, -torch.exp(loaded_weight.float()))

    tp_size = get_tensor_model_parallel_world_size()
    local_inner = intermediate_size // tp_size
    self.A = nn.Parameter(
        torch.empty(
            local_inner,
            ssm_state_size,
            dtype=torch.float32,
        )
    )
    self.D = nn.Parameter(torch.ones(local_inner))

    set_weight_attrs(self.D, {"weight_loader": load_rank_shard})
    set_weight_attrs(self.A, {"weight_loader": load_negated_exp_shard})

    self.out_proj = RowParallelLinear(
        intermediate_size,
        hidden_size,
        bias=use_bias,
        input_is_parallel=True,
    )

    def optional_rms_norm(width: int):
        # The norms exist only when RMS normalisation is requested.
        if not use_rms_norm:
            return None
        return RMSNorm(
            width,
            eps=rms_norm_eps,
            has_weight=rms_norm_has_weight,
        )

    self.dt_layernorm = optional_rms_norm(time_step_rank)
    self.b_layernorm = optional_rms_norm(ssm_state_size)
    self.c_layernorm = optional_rms_norm(ssm_state_size)

    # Make the layer discoverable by name from the custom-op entry point.
    static_ctx = get_current_vllm_config().compilation_config.static_forward_context
    if prefix in static_ctx:
        raise ValueError(f"Duplicate layer name: {prefix}")
    static_ctx[prefix] = self

    # Placeholder state; the inner tuple is (conv_state, ssm_state).
    self.kv_cache = (torch.tensor([]), torch.tensor([]))

    self.model_config = model_config
    self.cache_config = cache_config
    self.prefix = prefix
def _ssm_transform(
    self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Project `x` into the input-dependent SSM parameters.

    Returns:
        `(discrete_time_step, B, C)`, where the time step has gone through
        `dt_proj` (bias not applied here) and has its last two dims swapped.
    """
    # The LoRA kernel requires a contiguous input tensor.
    proj_input = x.contiguous() if self.is_lora_enabled else x
    ssm_params = self.x_proj(proj_input)[0]

    split_sizes = [self.time_step_rank, self.ssm_state_size, self.ssm_state_size]
    time_step, B, C = torch.split(ssm_params, split_sizes, dim=-1)

    if self.use_rms_norm:
        assert self.dt_layernorm is not None
        assert self.b_layernorm is not None
        assert self.c_layernorm is not None
        time_step = self.dt_layernorm(time_step.contiguous())
        B = self.b_layernorm(B.contiguous())
        C = self.c_layernorm(C.contiguous())

    projected_dt = self.dt_proj(time_step)[0]
    return projected_dt.transpose(-2, -1), B, C
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
    """
    Dispatch to the `mamba_mixer` custom op, which writes into `output`.

    The layer is identified by `self.prefix`.
    """
    layer_name = self.prefix
    torch.ops.vllm.mamba_mixer(hidden_states, output, layer_name)
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
    """Native path placeholder: does nothing and leaves `output` untouched."""
    return None
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
    """
    Run the Mamba-1 SSM pipeline.
    Steps
    -----
    1. Apply the gated-MLP linear projection to the raw input.
    2. Pass the projected sequence through the convolutional mixing layer.
    3. Feed the result into the State-Space Model (SSM) blocks.
    4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
       to produce contextual representations.
    5. Project the contextualised sequence back
       to the output embedding dimension.
    Batch handling
    --------------
    Prefill and decode tokens are processed by dedicated CUDA
    kernels for both the convolutional (conv1d) and SSM stages.
    In the case of a mixed batch (containing both prefill and
    decode tokens), both sets of kernels are executed independently
    and their outputs are concatenated before the final output projection.

    The result is written into `output[:num_actual_tokens]`. The one
    exception is the metadata-less profile run, which returns the projected
    tensor and does not write to `output`.
    """
    forward_context: ForwardContext = get_forward_context()
    attn_metadata = forward_context.attn_metadata

    assert self.cache_config is not None
    mamba_block_size = self.cache_config.mamba_block_size
    prefix_caching_enabled = self.cache_config.enable_prefix_caching

    if attn_metadata is not None:
        # Metadata is a per-layer dict keyed by layer prefix.
        assert isinstance(attn_metadata, dict)
        attn_metadata = attn_metadata[self.prefix]
        assert isinstance(attn_metadata, Mamba1AttentionMetadata)
        query_start_loc_p = attn_metadata.query_start_loc_p
        state_indices_tensor = attn_metadata.state_indices_tensor
        self_kv_cache = self.kv_cache[forward_context.virtual_engine]
        # Index 0 is the conv state (used with its last two dims swapped),
        # index 1 is the SSM state.
        conv_state = self_kv_cache[0].transpose(-1, -2)
        ssm_state = self_kv_cache[1]
        has_initial_states_p = attn_metadata.has_initial_states_p
        num_padded_decodes = attn_metadata.num_padded_decodes

    # 1. Gated MLP's linear projection
    projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
    hidden_states_BC, gate = projected_states.chunk(2, dim=-2)

    # Drop the singleton dim that __init__ added to the conv1d weight.
    conv_weights = self.conv1d.weight.view(
        self.conv1d.weight.size(0), self.conv1d.weight.size(2)
    )

    if attn_metadata is None:
        # V1 profile run
        # NOTE(review): this path returns a tensor and never writes to
        # `output` — confirm that callers of the profile run expect that.
        hidden_states_BC = hidden_states_BC.contiguous()
        return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]

    num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
    num_decode_tokens = attn_metadata.num_decode_tokens
    num_prefills = attn_metadata.num_prefills  # request count
    num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
    has_prefill = num_prefill_tokens > 0
    has_decode = num_decode_tokens > 0
    num_actual_tokens = num_prefill_tokens + num_decode_tokens

    # Separate the decode part from the prefill part of the batch.
    prefill_decode_split = split_batch_to_prefill_and_decode(
        hidden_states_BC,
        gate,
        state_indices_tensor,
        num_prefill_tokens,
        num_prefills,
        num_padded_decodes,
    )
    hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
    hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
    gate_p = prefill_decode_split.gate_p
    gate_d = prefill_decode_split.gate_d
    state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
    state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d

    if prefix_caching_enabled:
        # Per-request block indices, split in the same decode-then-prefill
        # order as the batch itself.
        block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
            torch.split(
                attn_metadata.block_idx_last_computed_token,
                [num_decodes, num_prefills],
                dim=0,
            )
        )
        block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
            torch.split(
                attn_metadata.block_idx_last_scheduled_token,
                [num_decodes, num_prefills],
                dim=0,
            )
        )

        block_idx_first_scheduled_token_p = (
            attn_metadata.block_idx_first_scheduled_token_p
        )
        num_computed_tokens_p = attn_metadata.num_computed_tokens_p
    else:
        # Without prefix caching the kernels receive None for all of these.
        block_idx_last_computed_token_d = None
        block_idx_last_computed_token_p = None
        block_idx_last_scheduled_token_d = None
        block_idx_last_scheduled_token_p = None
        block_idx_first_scheduled_token_p = None
        num_computed_tokens_p = None

    ssm_outputs = []

    if has_prefill:
        # 2. Convolution sequence transformation
        conv_out_p = causal_conv1d_fn(
            hidden_states_BC_p,
            conv_weights,
            self.conv1d.bias,
            activation=self.activation,
            conv_states=conv_state,
            has_initial_state=has_initial_states_p,
            cache_indices=state_indices_tensor_p,
            query_start_loc=query_start_loc_p,
            block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
            initial_state_idx=block_idx_last_computed_token_p,
            num_computed_tokens=num_computed_tokens_p,
            block_size_to_align=mamba_block_size,
        )
        # 3. State Space Model sequence transformations.
        discrete_time_step_p, B_p, C_p = self._ssm_transform(
            conv_out_p.transpose(-2, -1)
        )
        time_proj_bias = self._time_proj_bias()

        # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
        scan_out_p = selective_scan_fn(
            conv_out_p,
            ssm_state,
            discrete_time_step_p,
            self.A,
            B_p.transpose(-2, -1),
            C_p.transpose(-2, -1),
            self.D.float(),
            gate_p,
            time_proj_bias,
            delta_softplus=True,
            cache_indices=state_indices_tensor_p,
            has_initial_state=has_initial_states_p,
            query_start_loc=query_start_loc_p,
            block_size=mamba_block_size,
            block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
            initial_state_idx=block_idx_last_computed_token_p,
        )
        ssm_outputs.append(scan_out_p)

    if has_decode:
        if prefix_caching_enabled:
            # Choose, per request, the state slot to read from (last
            # computed block) and the slot to write to (last scheduled
            # block).
            state_indices_tensor_d_input = state_indices_tensor_d.gather(
                1, block_idx_last_computed_token_d.unsqueeze(1)
            ).squeeze(1)
            state_indices_tensor_d_output = state_indices_tensor_d.gather(
                1, block_idx_last_scheduled_token_d.unsqueeze(1)
            ).squeeze(1)
        else:
            # Read and write use the same slot.
            state_indices_tensor_d_input = state_indices_tensor_d
            state_indices_tensor_d_output = state_indices_tensor_d
        # 2. Convolution sequence transformation
        conv_out_d = causal_conv1d_update(
            hidden_states_BC_d.transpose(0, 1),
            conv_state,
            conv_weights,
            self.conv1d.bias,
            self.activation,
            conv_state_indices=state_indices_tensor_d,
            block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
            initial_state_idx=block_idx_last_computed_token_d,
        ).transpose(0, 1)

        # 3. State Space Model sequence transformation.
        discrete_time_step_d, B_d, C_d = self._ssm_transform(
            conv_out_d.transpose(-2, -1)
        )
        time_proj_bias = self._time_proj_bias()

        # 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
        # The kernel writes its result into the preallocated `out` buffer.
        scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
        selective_state_update(
            ssm_state,
            conv_out_d.transpose(0, 1),
            discrete_time_step_d.transpose(0, 1),
            self.A,
            B_d,
            C_d,
            self.D,
            gate_d.transpose(0, 1),
            time_proj_bias,
            dt_softplus=True,
            state_batch_indices=state_indices_tensor_d_input,
            dst_state_batch_indices=state_indices_tensor_d_output,
            out=scan_outputs_d,
        )
        scan_outputs_d = scan_outputs_d.transpose(0, 1)

        # Decode output goes first, matching the batch's decode-first order.
        ssm_outputs.insert(0, scan_outputs_d)

    scan_outputs_combined = (
        ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
    )

    # 5. Final output projection
    if self.is_lora_enabled:  # Lora kernel requires contiguous tensor.
        scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
        out = self.out_proj(scan_outputs_combined)[0]
    else:
        out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]

    output[:num_actual_tokens] = out
def get_state_dtype(self) -> tuple[torch.dtype]:
    """
    Resolve the state dtypes from the model dtype and the two cache dtype
    settings (`mamba_cache_dtype`, `mamba_ssm_cache_dtype`).

    Both `model_config` and `cache_config` must have been supplied.
    """
    assert self.model_config is not None
    assert self.cache_config is not None
    model_dtype = self.model_config.dtype
    cache_dtype = self.cache_config.mamba_cache_dtype
    ssm_cache_dtype = self.cache_config.mamba_ssm_cache_dtype
    return MambaStateDtypeCalculator.mamba1_state_dtype(
        model_dtype, cache_dtype, ssm_cache_dtype
    )
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Ask `MambaStateShapeCalculator` for the Mamba-1 state shapes."""
    shape_args = {
        "tp_world_size": get_tensor_model_parallel_world_size(),
        "intermediate_size": self.intermediate_size,
        "state_size": self.ssm_state_size,
        "conv_kernel": self.conv_kernel_size,
    }
    return MambaStateShapeCalculator.mamba1_state_shape(**shape_args)
@property
def mamba_type(self) -> str:
    """Type tag of this layer: always "mamba1"."""
    type_tag = "mamba1"
    return type_tag
def get_attn_backend(self) -> type["AttentionBackend"]:
    """Return the `Mamba1AttentionBackend` class (imported on first call)."""
    # Function-scope import, kept from the original implementation.
    from vllm.v1.attention.backends.mamba1_attn import (
        Mamba1AttentionBackend as backend_cls,
    )

    return backend_cls
def _time_proj_bias(self) -> torch.Tensor | None:
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
return self.dt_proj.bias.float()
return None
class PrefillDecodeSplit(NamedTuple):
    """
    Result of `split_batch_to_prefill_and_decode`.

    Fields ending in `_p` hold the prefill part of the batch; fields ending
    in `_d` hold the decode part.
    """

    # Projected hidden states, split along the token dim.
    hidden_states_BC_p: torch.Tensor
    hidden_states_BC_d: torch.Tensor
    # Gate tensor, split the same way as the hidden states.
    gate_p: torch.Tensor
    gate_d: torch.Tensor
    # State indices, split along dim 0.
    state_indices_tensor_p: torch.Tensor
    state_indices_tensor_d: torch.Tensor
def split_batch_to_prefill_and_decode(
    hidden_states_BC: torch.Tensor,
    gate: torch.Tensor,
    state_indices_tensor: torch.Tensor,
    num_prefill_tokens: int,
    num_prefills: int,
    num_padded_decodes: int,
) -> PrefillDecodeSplit:
    """
    Split the batch tensors into a decode part and a prefill part.

    In v1 the decode tokens come first and the prefill tokens follow.
    `hidden_states_BC` and `gate` are split along their last dim; the state
    indices are split along dim 0.
    """
    # num_padded_decodes accounts for CUDA graph padding when applicable.
    token_sizes = [num_padded_decodes, num_prefill_tokens]
    request_sizes = [num_padded_decodes, num_prefills]
    token_total = num_prefill_tokens + num_padded_decodes

    bc_decode, bc_prefill = torch.split(
        hidden_states_BC[..., :token_total], token_sizes, dim=-1
    )
    gate_decode, gate_prefill = torch.split(
        gate[..., :token_total], token_sizes, dim=-1
    )
    idx_decode, idx_prefill = torch.split(
        state_indices_tensor[: num_padded_decodes + num_prefills],
        request_sizes,
        dim=0,
    )

    return PrefillDecodeSplit(
        hidden_states_BC_p=bc_prefill,
        hidden_states_BC_d=bc_decode,
        gate_p=gate_prefill,
        gate_d=gate_decode,
        state_indices_tensor_p=idx_prefill,
        state_indices_tensor_d=idx_decode,
    )
def mamba_mixer(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Custom-op body: look up the layer registered as `layer_name` in the
    current forward context and run its `forward_cuda`.
    """
    context: ForwardContext = get_forward_context()
    layer = context.no_compile_layers[layer_name]
    layer.forward_cuda(hidden_states=hidden_states, output=output)
def mamba_mixer_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of the `mamba_mixer` op: does nothing."""
    return None
# Register `mamba_mixer` as a custom op at import time. `output` is declared
# as the mutated argument, and `mamba_mixer_fake` is supplied as the fake
# implementation.
# NOTE(review): presumably this is what makes the op callable as
# `torch.ops.vllm.mamba_mixer` — confirm against the helper.
direct_register_custom_op(
    op_name="mamba_mixer",
    op_func=mamba_mixer,
    mutates_args=["output"],
    fake_impl=mamba_mixer_fake,
)

View File

@@ -0,0 +1,928 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from torch import nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined_varlen,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction,
composed_weight_loader,
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
@CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp):
def __init__(
    self,
    full_hidden_size: int,
    full_n_groups: int,
    use_rms_norm: bool = True,
    eps: float = 1e-6,
):
    """
    Record the tensor-parallel layout and group sizes, and create the
    rank-local norm weight when RMSNorm is enabled.

    `full_hidden_size` must be divisible by the tensor-parallel world size
    (checked by assertion at the end).
    """
    super().__init__()
    self.tp_size = get_tensor_model_parallel_world_size()
    self.tp_rank = get_tensor_model_parallel_rank()
    self.full_hidden_size = full_hidden_size
    self.group_size = full_hidden_size // full_n_groups
    self.per_rank_hidden_size = full_hidden_size // self.tp_size
    self.n_groups = full_hidden_size // self.group_size

    self.variance_epsilon = eps
    self.use_rms_norm = use_rms_norm
    if not self.use_rms_norm:
        # Skip the unused parameter to avoid a checkpoint mismatch.
        self.register_parameter("weight", None)
    else:
        # The norm weight exists only when RMSNorm is actually applied.
        local_weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
        self.weight = local_weight
        set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})

    assert self.full_hidden_size % self.tp_size == 0, (
        "Tensor parallel world size must divide hidden size."
    )
def forward_native(
self,
x: torch.Tensor,
gate: torch.Tensor,
):
# Three tensor-parallel cases:
# 1. n_groups is 1
# In this case we parallelize along the reduction dim.
# Each rank computes a local sum of squares followed by AllReduce
# 2. tp_size divides n_groups
# Each rank only reduces within its local group(s).
# No collective ops necessary.
# 3. The general case can be pretty complicated so we AllGather
# the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x.to(input_dtype)
if self.n_groups == 1:
if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = global_sums / count
else:
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
else:
redundant_tp: bool = self.n_groups % self.tp_size != 0
if redundant_tp:
# To handle the general case, redundantly apply the variance
x = tensor_model_parallel_all_gather(x, -1)
*prefix_dims, hidden_dim = x.shape
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
if redundant_tp:
start = self.per_rank_hidden_size * self.tp_rank
end = start + self.per_rank_hidden_size
x = x[..., start:end]
return self.weight * x.to(input_dtype)
def forward_cuda(
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)
return rms_norm_gated(
x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
)
def mamba_v2_sharded_weight_loader(
    shard_spec: list[tuple[int, int, float]],
    tp_size: int,
    tp_rank: int,
) -> LoaderFunction:
    """Build a weight loader for fused Mamba2 projections.

    The checkpoint tensor is a concatenation of several segments along dim 0
    (for example x, B and C). Each segment is sharded separately so that the
    per-rank parameter can later be split into the same pieces, and so that
    the groups belonging to a head shard end up next to it.

    Args:
        shard_spec: One ``(full_dim, extra, duplicate_groups)`` entry per
            segment, in concatenation order. ``full_dim`` is the segment size
            of the (possibly extended) model before TP; ``extra`` is how many
            of those rows exist only because of replication and are therefore
            absent from the checkpoint; ``duplicate_groups`` is used only for
            its truthiness: when truthy every rank reads from shard 0.
        tp_size: Tensor-parallel world size.
        tp_rank: Rank of the current process.

    Returns:
        A ``loader(param, loaded_weight)`` callable that copies this rank's
        rows of ``loaded_weight`` into ``param.data`` in place.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        # Write cursor into the sharded parameter and read cursor into the
        # checkpoint tensor; both advance one segment at a time.
        dst_cursor = 0
        src_cursor = 0

        for full_dim, extra, duplicate_groups in shard_spec:
            # Rows every rank owns for this segment.
            rows_per_rank = full_dim // tp_size

            # With replication all ranks read the same (first) shard.
            # NOTE: currently we only support duplication
            # in the case where num_groups == 1
            src_rank = 0 if duplicate_groups else tp_rank

            # Rows of this segment that really exist in the checkpoint.
            rows_in_checkpoint = full_dim - extra

            # Offset of this rank's shard inside the segment, and how many
            # rows can actually be taken from there.
            src_skip = src_rank * rows_per_rank
            n_rows = min(rows_per_rank, rows_in_checkpoint - src_skip)

            # Sharding always happens along dim 0.
            src_first = src_cursor + src_skip
            dst_rows = slice(dst_cursor, dst_cursor + n_rows)
            src_rows = slice(src_first, src_first + n_rows)
            param.data[dst_rows, ...] = loaded_weight[src_rows]

            # Advance to the next segment.
            dst_cursor += rows_per_rank
            src_cursor += rows_in_checkpoint

    return loader
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer2")
class MambaMixer2(MambaBase, CustomOp):
    """
    Compute ∆, A, B, C, and D the state space parameters and compute
    the `contextualized_states`. A, D are input independent
    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
    for why A isn't selective) ∆, B, C are input-dependent
    (this is a key difference between Mamba and the linear time
    invariant S4, and is why Mamba is called
    **selective** state spaces)

    The layer registers itself in the compilation config's
    `static_forward_context` under `prefix`; `forward` dispatches through
    the `torch.ops.vllm.mamba_mixer2` custom op, passing that prefix.
    """

    def __init__(
        self,
        hidden_size: int,
        ssm_state_size: int,
        conv_kernel_size: int,
        intermediate_size: int,
        use_conv_bias: bool,
        use_bias: bool,
        n_groups: int = 1,
        num_heads: int = 128,
        head_dim: int = 64,
        rms_norm_eps: float = 1e-5,
        activation: str = "silu",
        use_rms_norm: bool = True,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the projections, conv weights, SSM parameters and gated norm.

        Args:
            hidden_size: Input size of `in_proj` and output size of `out_proj`.
            ssm_state_size: Per-group state size; enters `conv_dim` and the
                state shape.
            conv_kernel_size: Used as `input_size` of the linear layer that
                holds the conv1d weights.
            intermediate_size: Size of the gate / hidden branches and input
                size of `out_proj`.
            use_conv_bias: Whether the conv1d layer has a bias.
            use_bias: Whether `in_proj` and `out_proj` have a bias.
            n_groups: Number of groups before TP. If the TP world size does
                not divide it, `self.n_groups` is extended with extra groups.
            num_heads: Number of heads before TP; must be divisible by the
                TP world size.
            head_dim: Per-head dimension.
            rms_norm_eps: Epsilon forwarded to the gated RMS norm.
            activation: Activation name forwarded to the causal-conv kernels.
            use_rms_norm: Forwarded to `Mixer2RMSNormGated`.
            model_config: Stored; read by `get_state_dtype`.
            cache_config: Stored; read by `forward_cuda` and `get_state_dtype`.
            quant_config: Forwarded to `in_proj` and `out_proj`; the conv1d
                layer is always built with `quant_config=None`.
            prefix: Layer name; used for sub-layer prefixes and as the key in
                `static_forward_context`.

        Raises:
            AssertionError: If the TP divisibility requirements checked
                below are not met.
            ValueError: If `prefix` is already registered in
                `static_forward_context`.
        """
        super().__init__()

        # For TP, the sharding plan is as follows:
        # - for the conv modules, since
        #   conv_dim = intermediate_size + 2 * n_groups * ssm_state_size,
        #   we shard intermediate_size and n_groups
        # - since intermediate_size = n_heads * head_dim, sharding on
        #   intermediate_size is achieved by sharding on n_heads.
        # - IF, world_size divides groups, then sharding
        #   (n_groups / world_size, n_heads / world_size)
        #   also maintains the invariant n_heads % n_groups == 0
        # - HOWEVER IF, world_size DOES NOT divide groups, then we need
        #   to allocate extra space in the shard, such that groups
        #   may be replicated to follow the head shard.
        # - NOTE: currently for the world size DOES NOT divide groups
        #   case, we only support the case when n_groups == 1
        self.tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()

        assert num_heads % self.tp_size == 0, (
            "Tensor parallel world size must divide num heads."
        )

        assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
            "If tensor parallel world size does not divide num_groups, "
            "then num_groups must equal 1."
        )

        assert (
            (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
        ), (
            "Tensor parallel currently supported for quantized models only "
            "if tensor parallel world size divides num groups."
        )

        self.ssm_state_size = ssm_state_size
        self.conv_kernel_size = conv_kernel_size
        self.activation = activation

        self.intermediate_size = intermediate_size
        self.head_dim = head_dim
        self.num_heads = num_heads

        self.n_groups = n_groups
        if n_groups % self.tp_size != 0:
            # - for TP we shard conv_dim by sharding on n_groups,
            # - but if tp_size does not divide n_groups, we need to
            #   extend some extra groups
            groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
                n_groups, self.tp_size
            )
            self.n_groups = n_groups + groups

        # Both sizes below use the (possibly extended) self.n_groups.
        self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
        self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size

        if n_groups % self.tp_size == 0:
            # Groups shard evenly: the merged layers shard each output
            # segment on their own.
            self.conv1d = MergedColumnParallelLinear(
                input_size=conv_kernel_size,
                output_sizes=[
                    intermediate_size,
                    self.groups_ssm_state_size,
                    self.groups_ssm_state_size,
                ],
                bias=use_conv_bias,
                quant_config=None,
                prefix=f"{prefix}.conv1d",
            )

            self.in_proj = MergedColumnParallelLinear(
                input_size=hidden_size,
                output_sizes=[
                    intermediate_size,
                    intermediate_size,
                    self.groups_ssm_state_size,
                    self.groups_ssm_state_size,
                    self.num_heads,
                ],
                bias=use_bias,
                quant_config=quant_config,
                prefix=f"{prefix}.in_proj",
            )
        else:
            # This is the n_groups == 1 case,
            # where we need to duplicate groups if TP>1.

            self.conv1d = ColumnParallelLinear(
                input_size=conv_kernel_size,
                output_size=self.conv_dim,
                bias=use_conv_bias,
                quant_config=None,
                prefix=f"{prefix}.conv1d",
            )

            self.in_proj = ColumnParallelLinear(
                input_size=hidden_size,
                output_size=intermediate_size + self.conv_dim + self.num_heads,
                bias=use_bias,
                quant_config=quant_config,
                prefix=f"{prefix}.in_proj",
            )

            # - because in_proj is a concatenation of 3 weights, we
            #   need to interleave them before sharding
            # - use the custom weight loader mamba_v2_sharded_weight_loader
            #   for conv1d.bias, conv1d.weight and in_proj.weight
            # - need to set these settings, to assign the groups
            #   to the head shards
            group_shard_settings = (
                self.groups_ssm_state_size,  # expected model size
                (self.n_groups - n_groups) * self.ssm_state_size,  # extra dims assigned
                n_groups == 1,  # if there was only one group
            )
            intermediate_settings = (intermediate_size, 0, False)
            head_settings = (self.num_heads, 0, False)

            # - the weight already has a "weight_loader" attribute
            #   which set_weight_attrs will raise if we do not
            #   delete before trying to override it
            # - ditto for the other two weights below
            # NOTE(review): if `use_conv_bias` is False and the layer then
            #   exposes `bias` as None, this delattr would presumably
            #   fail — confirm against ColumnParallelLinear.
            delattr(self.conv1d.bias, "weight_loader")
            set_weight_attrs(
                self.conv1d.bias,
                {
                    "weight_loader": mamba_v2_sharded_weight_loader(
                        [
                            intermediate_settings,
                            group_shard_settings,
                            group_shard_settings,
                        ],
                        self.tp_size,
                        tp_rank,
                    )
                },
            )

            delattr(self.conv1d.weight, "weight_loader")
            set_weight_attrs(
                self.conv1d.weight,
                {
                    "weight_loader": mamba_v2_sharded_weight_loader(
                        [
                            intermediate_settings,
                            group_shard_settings,
                            group_shard_settings,
                        ],
                        self.tp_size,
                        tp_rank,
                    )
                },
            )

            if quant_config is None:
                # - quant layers do not have a weight loader
                delattr(self.in_proj.weight, "weight_loader")
                set_weight_attrs(
                    self.in_proj.weight,
                    {
                        "weight_loader": mamba_v2_sharded_weight_loader(
                            [
                                intermediate_settings,  # for gate
                                intermediate_settings,
                                group_shard_settings,
                                group_shard_settings,
                                head_settings,  # for dt
                            ],
                            self.tp_size,
                            tp_rank,
                        )
                    },
                )

        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
        # and `set_weight_attrs` doesn't allow to override it
        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

        # - these are TPed by heads to reduce the size of the
        #   temporal shape
        self.A = nn.Parameter(
            torch.empty(
                divide(num_heads, self.tp_size),
                dtype=torch.float32,
            )
        )
        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
        self.use_rms_norm = use_rms_norm

        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
        # A is transformed at load time: the loaded tensor x becomes
        # -exp(x.float()), composed with the dim-0 sharded loader.
        a_weight_loader = composed_weight_loader(
            sharded_weight_loader(0), lambda x: -torch.exp(x.float())
        )
        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.out_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=use_bias,
            input_is_parallel=True,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        # The norm receives the original (non-extended) n_groups.
        self.norm = Mixer2RMSNormGated(
            intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
        )

        # Register this layer under its prefix so it can be looked up by
        # name; duplicate names are rejected.
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self
        # The tuple is (conv_state, ssm_state)
        # Initialized with empty placeholder tensors.
        self.kv_cache = (torch.tensor([]), torch.tensor([]))

        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mup_vector: torch.Tensor | None = None,
    ):
        """No native implementation: the body is empty and returns None."""
        pass

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mup_vector: torch.Tensor | None = None,
    ):
        """Dispatch to the `torch.ops.vllm.mamba_mixer2` custom op.

        The layer is identified to the op by `self.prefix`. Nothing is
        returned here; `forward_cuda` writes its result into `output`.
        """
        torch.ops.vllm.mamba_mixer2(
            hidden_states,
            output,
            self.prefix,
            mup_vector,
        )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mup_vector: torch.Tensor | None = None,
    ):
        """Run the mixer on a batch that may contain decode and prefill tokens.

        Steps: input projection, split into gate / (x, B, C) / dt, causal
        conv, SSM scan (prefill) and/or single-step state update (decode),
        gated norm, output projection. The result is written into
        `output[:num_actual_tokens]`.

        Args:
            hidden_states: Input activations fed to `in_proj`.
            output: Tensor that receives the final projection in place.
            mup_vector: Optional multiplier applied to the projected states.
        """
        forward_context = get_forward_context()
        # attn_metadata contains metadata necessary for the mamba2 triton
        # kernels to operate in continuous batching and in chunked prefill
        # modes; they are computed at top-level model forward since they
        # stay the same and reused for all mamba layers in the same iteration
        attn_metadata: AttentionMetadata = forward_context.attn_metadata

        assert self.cache_config is not None
        mamba_block_size = self.cache_config.mamba_block_size
        prefix_caching_enabled = self.cache_config.enable_prefix_caching
        if attn_metadata is not None:
            # The context holds a dict of per-layer metadata keyed by prefix.
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, Mamba2AttentionMetadata)
            self_kv_cache = self.kv_cache[forward_context.virtual_engine]
            # conv_state = (..., dim, width-1) yet contiguous along 'dim'
            conv_state = self_kv_cache[0].transpose(-1, -2)
            ssm_state = self_kv_cache[1]
            state_indices_tensor = attn_metadata.state_indices_tensor
            has_initial_states_p = attn_metadata.has_initial_states_p
            prep_initial_states = attn_metadata.prep_initial_states
            chunk_size = attn_metadata.chunk_size
            seq_idx_p = attn_metadata.seq_idx_p
            query_start_loc_p = attn_metadata.query_start_loc_p
            cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
            last_chunk_indices_p = attn_metadata.last_chunk_indices_p

        # 1. Gated MLP's linear projection
        projected_states, _ = self.in_proj(hidden_states)

        if mup_vector is not None:
            projected_states = projected_states * mup_vector

        # Per-rank split of the projection into gate, (x, B, C) and dt.
        gate, hidden_states_B_C, dt = torch.split(
            projected_states,
            [
                self.intermediate_size // self.tp_size,
                self.conv_dim // self.tp_size,
                self.num_heads // self.tp_size,
            ],
            dim=-1,
        )

        # Drop the singleton dim that __init__ inserted with unsqueeze(1).
        conv_weights = self.conv1d.weight.view(
            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
        )

        # - get hidden_states, B and C after depthwise convolution.
        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
            hidden_states_B_C,
            [
                self.intermediate_size // self.tp_size,
                self.groups_ssm_state_size // self.tp_size,
                self.groups_ssm_state_size // self.tp_size,
            ],
            dim=-1,
        )

        if attn_metadata is None:
            # profile run
            # No conv or SSM kernels are invoked on this path; only the
            # norm and the output projection run.
            # NOTE(review): this path returns the projection rather than
            #   writing into `output`, and `forward` does not return the op
            #   result — confirm this is intended for profiling only.
            hidden_states_B_C = (
                hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
            ).contiguous()
            hidden_states, _B, _C = split_hidden_states_B_C_fn(hidden_states_B_C)
            hidden_states = self.norm(hidden_states, gate)
            out, _ = self.out_proj(hidden_states)
            return out

        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0
        num_actual_tokens = num_prefill_tokens + num_decodes

        # Separate prefill and decode by splitting varlen input
        # Split along token dimension
        hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
            hidden_states_B_C[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        dt_d, dt_p = torch.split(
            dt[:num_actual_tokens],
            [num_decodes, num_prefill_tokens],
            dim=0,
        )
        # Split along batch dimension
        # NOTE(review): the split sizes are request counts while the slice
        #   bound is a token count; presumably the tensor has one row per
        #   request so the slice only acts as an upper bound — confirm.
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor[:num_actual_tokens],
            [num_decodes, num_prefills],
            dim=0,
        )

        if prefix_caching_enabled:
            # If prefix caching is enabled, retrieve the relevant variables
            # for prefill and decode
            block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
                torch.split(
                    attn_metadata.block_idx_last_computed_token,
                    [num_decodes, num_prefills],
                    dim=0,
                )
            )
            block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
                torch.split(
                    attn_metadata.block_idx_last_scheduled_token,
                    [num_decodes, num_prefills],
                    dim=0,
                )
            )
            # Prefill-only variables:
            block_idx_first_scheduled_token_p = (
                attn_metadata.block_idx_first_scheduled_token_p
            )
            num_computed_tokens_p = attn_metadata.num_computed_tokens_p
        else:
            # Without prefix caching none of the block-index tensors are
            # used; the kernels below receive None for them.
            block_idx_last_computed_token_d = None
            block_idx_last_computed_token_p = None
            block_idx_last_scheduled_token_d = None
            block_idx_last_scheduled_token_p = None
            block_idx_first_scheduled_token_p = None
            num_computed_tokens_p = None

        # Preallocate output tensor to avoid memcpy cost for merging prefill
        # and decode outputs
        preallocated_ssm_out = torch.empty(
            [
                num_prefill_tokens + num_decodes,
                (self.num_heads // self.tp_size) * self.head_dim,
            ],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # Decode rows come first, then prefill rows (same order as the
        # input splits above).
        preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
            preallocated_ssm_out,
            [num_decodes, num_prefill_tokens],
            dim=0,
        )

        # Process prefill requests
        if has_prefill:
            # 2. Convolution sequence transformation
            # - It will read the initial states for every sequence,
            #   that has "has_initial_states_p" == True,
            #   from "cache_indices", using "state_indices_tensor_p".
            # - It updates the "conv_state" cache in positions pointed
            #   to by "state_indices_tensor_p".
            #   In particular, it will always write the state at the
            #   sequence end.
            #   In addition, "block_idx_first_scheduled_token_p" and
            #   "block_idx_last_scheduled_token_p"
            #   are provided (which are pointers into
            #   "state_indices_tensor_p"), it will write additional cache
            #   states aligned at "block_size_to_align".
            x = hidden_states_B_C_p.transpose(
                0, 1
            )  # this is the form that causal-conv see
            hidden_states_B_C_p = causal_conv1d_fn(
                x,
                conv_weights,
                self.conv1d.bias,
                activation=self.activation,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
                block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
                initial_state_idx=block_idx_last_computed_token_p,
                num_computed_tokens=num_computed_tokens_p,
                block_size_to_align=mamba_block_size,
                metadata=attn_metadata,
                query_start_loc=query_start_loc_p,
            ).transpose(0, 1)[:num_prefill_tokens]

            hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)

            # 3. State Space Model sequence transformation
            initial_states = None
            if has_initial_states_p is not None and prep_initial_states:
                kernel_ssm_indices = state_indices_tensor_p
                if prefix_caching_enabled:
                    # Pick, per sequence, the cache slot of the block that
                    # holds the last computed token.
                    kernel_ssm_indices = state_indices_tensor_p.gather(
                        1, block_idx_last_computed_token_p.unsqueeze(1)
                    ).squeeze(1)
                # Sequences without an initial state start from zeros.
                initial_states = torch.where(
                    has_initial_states_p[:, None, None, None],
                    ssm_state[kernel_ssm_indices],
                    0,
                )

            # NOTE: final output is an in-place update of out tensor
            varlen_states = mamba_chunk_scan_combined_varlen(
                hidden_states_p.view(
                    num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
                ),
                dt_p,
                self.A,
                B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
                C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
                chunk_size=chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                seq_idx=seq_idx_p,
                cu_seqlens=query_start_loc_p,
                cu_chunk_seqlens=cu_chunk_seqlen_p,
                last_chunk_indices=last_chunk_indices_p,
                initial_states=initial_states,
                return_intermediate_states=prefix_caching_enabled,
                dt_softplus=True,
                dt_limit=(0.0, float("inf")),
                out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
                state_dtype=ssm_state.dtype,
            )

            if prefix_caching_enabled:
                # The chunk_stride is the number of chunks per mamba block
                # e.g., if mamba_block_size = 512 and chunk_size = 256,
                # then chunk_stride = 2
                chunk_stride = mamba_block_size // chunk_size

                # Save state for sequences with more than just final state
                for seq_idx in range(num_prefills):
                    # Block index for the first scheduled token
                    block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
                        seq_idx
                    ]

                    # Block index for the last scheduled token
                    block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
                        seq_idx
                    ]

                    # Number of blocks that need to be written
                    n_blocks_to_fill = (
                        block_idx_last_scheduled_token - block_idx_first_scheduled_token
                    )

                    # Skip sequences that don't have any blocks to fill
                    if n_blocks_to_fill == 0:
                        continue

                    # Look up the state indices
                    cache_blocks_to_fill = state_indices_tensor_p[
                        seq_idx,
                        block_idx_first_scheduled_token:block_idx_last_scheduled_token,
                    ]

                    # First chunk index for this sequence
                    if seq_idx == 0:
                        first_chunk = 0
                    else:
                        first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]

                    # First chunk that is aligned on the mamba block boundary
                    first_aligned_chunk = first_chunk + chunk_stride - 1

                    # Calculate the number of computed tokens that were not
                    # already cached
                    num_unaligned_computed_tokens = (
                        num_computed_tokens_p[seq_idx] % mamba_block_size
                    )

                    if num_unaligned_computed_tokens > 0:
                        # If the number of computed tokens is not block aligned,
                        # then we need to shift the index accordingly
                        first_aligned_chunk -= (
                            num_unaligned_computed_tokens // chunk_size
                        )

                    # Get states to write
                    # (one state every `chunk_stride` chunks)
                    from_where = varlen_states[
                        first_aligned_chunk : first_aligned_chunk
                        + n_blocks_to_fill * chunk_stride : chunk_stride
                    ]

                    # Write the states
                    ssm_state[cache_blocks_to_fill] = from_where

                # For all seqs, store the last state (note: might be partial):
                ssm_state[
                    state_indices_tensor_p.gather(
                        1, block_idx_last_scheduled_token_p.unsqueeze(1)
                    ).squeeze(1)
                ] = varlen_states[last_chunk_indices_p]

            else:
                # update ssm states
                # - varlen state is a (num_prefills, nheads, headdim, dstate)
                #   tensor
                ssm_state[state_indices_tensor_p] = varlen_states

        # Process decode requests
        if has_decode:
            if prefix_caching_enabled:
                # Read from the block of the last computed token and write
                # to the block of the last scheduled token.
                state_indices_tensor_d_input = state_indices_tensor_d.gather(
                    1, block_idx_last_computed_token_d.unsqueeze(1)
                ).squeeze(1)
                state_indices_tensor_d_output = state_indices_tensor_d.gather(
                    1, block_idx_last_scheduled_token_d.unsqueeze(1)
                ).squeeze(1)
                # for decode:
                #   block_idx_first_scheduled_token_d ==
                #       block_idx_last_scheduled_token_d
                # at block boundaries:
                #   block_idx_first_scheduled_token_d >
                #       block_idx_last_computed_token_d
            else:
                # Without caching, read and write in-place to the same blocks:
                state_indices_tensor_d_input = state_indices_tensor_d
                state_indices_tensor_d_output = state_indices_tensor_d

            # 2. Convolution sequence transformation
            hidden_states_B_C_d = causal_conv1d_update(
                hidden_states_B_C_d,
                conv_state,
                conv_weights,
                self.conv1d.bias,
                self.activation,
                conv_state_indices=state_indices_tensor_d,
                block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
                initial_state_idx=block_idx_last_computed_token_d,
            )

            hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)

            # 3. State Space Model sequence transformation
            n_groups = self.n_groups // self.tp_size
            # Expand the per-head parameters to the shapes passed to
            # selective_state_update: A to (heads, head_dim, state_size),
            # dt / dt_bias / D along head_dim.
            A_d = (
                self.A[:, None, ...][:, :, None]
                .expand(-1, self.head_dim, self.ssm_state_size)
                .to(dtype=torch.float32)
            )
            dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
            D_d = self.D[:, None, ...].expand(-1, self.head_dim)
            B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
            C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
            hidden_states_d = hidden_states_d.view(
                -1, self.num_heads // self.tp_size, self.head_dim
            )

            # - the hidden is reshaped into (bs, num_heads, head_dim)
            # - mamba_cache_params.ssm_state's slots will be selected
            #   using state_indices_tensor_d
            # NOTE: final output is an in-place update of out tensor
            selective_state_update(
                ssm_state,
                hidden_states_d,
                dt_d,
                A_d,
                B_d,
                C_d,
                D_d,
                z=None,
                dt_bias=dt_bias,
                dt_softplus=True,
                state_batch_indices=state_indices_tensor_d_input,
                dst_state_batch_indices=state_indices_tensor_d_output,
                out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
            )

        # 4. gated MLP
        # GatedRMSNorm internally applying SiLU to the gate
        # SiLU is applied internally before normalization, unlike standard
        # norm usage
        hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])

        # 5. Final linear projection
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

    def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
        """Return the state dtypes computed by `mamba2_state_dtype`.

        Requires both `model_config` and `cache_config` to have been given.
        """
        assert self.model_config is not None
        assert self.cache_config is not None
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            self.model_config.dtype,
            self.cache_config.mamba_cache_dtype,
            self.cache_config.mamba_ssm_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        """Return the state shapes computed by `mamba2_state_shape`.

        `self.n_groups` is the value stored in `__init__`, i.e. already
        extended when the TP world size does not divide the group count.
        """
        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=self.intermediate_size,
            tp_world_size=get_tensor_model_parallel_world_size(),
            n_groups=self.n_groups,
            num_heads=self.num_heads,
            head_dim=self.head_dim,
            state_size=self.ssm_state_size,
            conv_kernel=self.conv_kernel_size,
        )

    @property
    def mamba_type(self) -> str:
        """Identifier string for this layer type: always "mamba2"."""
        return "mamba2"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the `Mamba2AttentionBackend` class (imported lazily here)."""
        from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionBackend

        return Mamba2AttentionBackend
def mamba_mixer2(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    mup_vector: torch.Tensor | None = None,
) -> None:
    """Custom-op body: look up a layer by name and run its ``forward_cuda``.

    The layer object is fetched from the current forward context's
    ``no_compile_layers`` mapping using ``layer_name``; all tensor arguments
    are forwarded unchanged as keyword arguments. Nothing is returned.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer.forward_cuda(
        hidden_states=hidden_states,
        output=output,
        mup_vector=mup_vector,
    )
def mamba_mixer2_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    mup_vector: torch.Tensor | None = None,
) -> None:
    """Fake implementation of the ``mamba_mixer2`` op: does nothing.

    It takes the same arguments as ``mamba_mixer2``, touches none of them,
    and returns None.
    """
    return None
# Register `mamba_mixer2` as a custom op at import time. `output` is
# declared as mutated because the real implementation writes its result
# into that tensor; `mamba_mixer2_fake` is supplied as the fake
# implementation.
direct_register_custom_op(
    op_name="mamba_mixer2",
    op_func=mamba_mixer2,
    mutates_args=["output"],
    fake_impl=mamba_mixer2_fake,
)

View File

@@ -0,0 +1,225 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config.cache import MambaDType
from vllm.config.model import ModelDType
from vllm.distributed import divide
from vllm.utils.torch_utils import (
STR_DTYPE_TO_TORCH_DTYPE,
get_kv_cache_torch_dtype,
)
class MambaStateDtypeCalculator:
    """Class-method helpers returning the tuple of state dtypes per layer type.

    Every method resolves the cache dtype setting against the model dtype
    via ``get_kv_cache_torch_dtype`` and returns one dtype per state tensor.
    """

    @classmethod
    def linear_attention_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return a 1-tuple with the state dtype for linear attention.

        Raises:
            ValueError: If ``mamba_cache_dtype`` is ``"float32"``.
        """
        # TODO (tdoublep) requires testing
        fp32_requested = mamba_cache_dtype == "float32"
        if fp32_requested:
            raise ValueError("fp32 state for minimax is not yet supported")
        return (get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype),)

    @classmethod
    def mamba1_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return ``(conv_dtype, ssm_dtype)``; same rule as Mamba2."""
        dtypes = cls._mamba_state_dtype(
            model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
        )
        return dtypes

    @classmethod
    def mamba2_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return ``(conv_dtype, ssm_dtype)``; same rule as Mamba1."""
        dtypes = cls._mamba_state_dtype(
            model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
        )
        return dtypes

    @classmethod
    def _mamba_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
        mamba_ssm_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Shared rule for Mamba1/Mamba2.

        The conv state dtype is resolved from ``mamba_cache_dtype``. The SSM
        (temporal) state reuses it when ``mamba_ssm_cache_dtype`` is
        ``"auto"``; otherwise the string is looked up in
        ``STR_DTYPE_TO_TORCH_DTYPE``.
        """
        conv_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        ssm_dtype = (
            conv_dtype
            if mamba_ssm_cache_dtype == "auto"
            else STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]
        )
        return (conv_dtype, ssm_dtype)

    @classmethod
    def short_conv_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, ...]:
        """Return a 1-tuple with the conv state dtype."""
        return (get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype),)

    @classmethod
    def gated_delta_net_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ) -> tuple[torch.dtype, torch.dtype]:
        """Return a 2-tuple; both states share the resolved dtype."""
        resolved = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (resolved, resolved)

    @classmethod
    def kda_state_dtype(
        cls,
        model_dtype: ModelDType | torch.dtype,
        mamba_cache_dtype: MambaDType,
    ):
        """Return a 4-tuple: three resolved dtypes followed by float32."""
        resolved = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
        return (resolved, resolved, resolved, torch.float32)
class MambaStateShapeCalculator:
    """Class-method helpers returning per-rank state shapes per layer type.

    Sizes that are sharded across tensor-parallel ranks are divided by the
    TP world size. Conv state shapes are returned with the kernel-width
    axis first and the channel axis last.
    """

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """Return ``((num_heads // tp_size, head_dim, head_dim),)``."""
        heads_per_rank = num_heads // tp_size
        return ((heads_per_rank, head_dim, head_dim),)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for Mamba1."""
        channels = divide(intermediate_size, tp_world_size)
        # Conv state is laid out as (kernel - 1, channels).
        conv_shape = (conv_kernel - 1, channels)
        temporal_shape = (channels, state_size)
        return conv_shape, temporal_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for Mamba2."""
        # When the world size does not divide n_groups the shards are
        # extended so every head shard carries the groups it needs.
        total_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size
        )
        # heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * total_groups * state_size

        # contiguous along 'dim' axis
        conv_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
        temporal_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return conv_shape, temporal_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        """Return ``((conv_kernel - 1, intermediate_size / tp),)``."""
        channels = divide(intermediate_size, tp_world_size)
        return ((conv_kernel - 1, channels),)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards.

        Returns 0 when ``tp_size`` divides ``ngroups``; otherwise
        ``tp_size - ngroups`` (for ``ngroups == 1`` this is the number of
        replicas to add).
        """
        evenly_divisible = ngroups % tp_size == 0
        return 0 if evenly_divisible else tp_size - ngroups

    @classmethod
    def gated_delta_net_state_shape(
        cls,
        tp_world_size: int,
        num_k_heads: int,
        num_v_heads: int,
        head_k_dim: int,
        head_v_dim: int,
        conv_kernel_size: int,
        num_spec: int = 0,
    ):
        """Return ``(conv_state_shape, temporal_state_shape)``.

        The conv width axis has ``conv_kernel_size - 1 + num_spec`` entries.
        """
        conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
        channels = divide(conv_dim, tp_world_size)
        conv_shape = (conv_kernel_size - 1 + num_spec, channels)

        temporal_shape = (
            divide(num_v_heads, tp_world_size),
            head_k_dim,
            head_v_dim,
        )
        return conv_shape, temporal_shape

    @classmethod
    def kda_state_shape(
        cls,
        tp_world_size: int,
        num_heads: int,
        head_dim: int,
        num_k_heads: int | None = None,
        head_k_dim: int | None = None,
        conv_kernel_size: int = 4,
        num_spec: int = 0,
    ) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
        """Return four shapes: one conv state, the key conv state twice,
        and the recurrent state.

        ``num_k_heads`` and ``head_k_dim`` default to ``num_heads`` and
        ``head_dim``. ``num_spec`` is accepted but not used here.
        """
        k_heads = num_heads if num_k_heads is None else num_k_heads
        k_dim = head_dim if head_k_dim is None else head_k_dim

        width = conv_kernel_size - 1
        channels = divide(num_heads * head_dim, tp_world_size)
        k_channels = divide(k_heads * k_dim, tp_world_size)
        recurrent_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)

        conv_shape = (width, channels)
        conv_k_shape = (width, k_channels)
        return (
            conv_shape,
            conv_k_shape,
            conv_k_shape,
            recurrent_shape,
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,172 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
import torch
from vllm.triton_utils import tl, triton
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row: tl.int64,
    stride_y_row: tl.int64,
    stride_z_row: tl.int64,
    M: tl.int64,  # number of rows in X
    N: tl.int64,  # number of columns per program (the launcher passes group_size)
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_Z: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
    IS_RMS_NORM: tl.constexpr,
):
    # Single-pass (gated) LayerNorm / RMSNorm forward. Each program handles
    # one (row, group) pair: it loads up to BLOCK_N columns, computes the
    # statistics, normalizes, applies weight (and bias), and stores the result.
    # HAS_BIAS / HAS_Z are derived by the heuristics above from whether the
    # B / Z pointers are None.
    # When Z is given it gates via z * sigmoid(z), either before the
    # statistics (NORM_BEFORE_GATE False) or on the output (True).
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    # Advance every pointer to the start of this program's row and group.
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    if HAS_Z:
        Z += row * stride_z_row + group * N
    if not IS_RMS_NORM:
        Mean += group * M
    # Mean / Rstd are indexed as group * M + row.
    Rstd += group * M
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    # Columns at or beyond N are masked out and read as 0.
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_Z and not NORM_BEFORE_GATE:
        # Gate first: the statistics are computed on the gated input.
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        # RMS norm: no mean subtraction, variance is the mean of squares.
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if HAS_Z and NORM_BEFORE_GATE:
        # Gate last: applied to the normalized, scaled output.
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)
def _layer_norm_fwd(
    x,
    weight,
    bias,
    eps,
    z=None,
    out=None,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Validate inputs and launch the fused (gated) layer-norm forward kernel.

    Args:
        x: 2-D input of shape (M, N) with unit stride on the last dimension.
        weight: scale vector of shape (N,).
        bias: optional shift vector of shape (N,), or None.
        eps: value added to the variance before the reciprocal square root.
        z: optional gate tensor of shape (M, N), or None.
        out: optional output buffer with the same shape as ``x``; allocated
            with ``torch.empty_like(x)`` when None.
        group_size: number of columns normalized together; defaults to N.
            Must divide N.
        norm_before_gate: forwarded to the kernel as ``NORM_BEFORE_GATE``.
        is_rms_norm: forwarded to the kernel as ``IS_RMS_NORM``.

    Returns:
        ``(out, mean, rstd)``. ``mean`` and ``rstd`` are float32 vectors of
        length ``ngroups * M``; ``mean`` is None when ``is_rms_norm`` is set.

    Raises:
        RuntimeError: if one group does not fit in a single 64KB block.
    """
    n_rows, n_cols = x.shape
    cols_per_group = n_cols if group_size is None else group_size
    assert n_cols % cols_per_group == 0
    num_groups = n_cols // cols_per_group

    # Layout / shape contracts of the kernel.
    assert x.stride(-1) == 1
    if z is not None:
        assert z.stride(-1) == 1
        assert z.shape == (n_rows, n_cols)
    assert weight.shape == (n_cols,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (n_cols,)

    # Use the caller's buffer when given, otherwise allocate one.
    if out is None:
        out = torch.empty_like(x)
    else:
        assert out.shape == x.shape
    assert out.stride(-1) == 1

    # One statistic per (group, row) pair.
    stats_shape = (num_groups * n_rows,)
    if is_rms_norm:
        mean = None
    else:
        mean = torch.empty(stats_shape, dtype=torch.float32, device=x.device)
    rstd = torch.empty(stats_shape, dtype=torch.float32, device=x.device)

    # A whole group must fit in one block of at most 64KB.
    max_fused_elems = 65536 // x.element_size()
    block_n = min(max_fused_elems, triton.next_power_of_2(cols_per_group))
    if cols_per_group > block_n:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    # Warp-count heuristic: one warp per 256 columns, clamped to [1, 8].
    warps = max(1, min(8, block_n // 256))
    z_row_stride = 0 if z is None else z.stride(0)
    launch_grid = (n_rows, num_groups)
    with torch.cuda.device(x.device.index):
        _layer_norm_fwd_1pass_kernel[launch_grid](
            x,
            out,
            weight,
            bias,
            z,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            z_row_stride,
            n_rows,
            cols_per_group,
            eps,
            BLOCK_N=block_n,
            NORM_BEFORE_GATE=norm_before_gate,
            IS_RMS_NORM=is_rms_norm,
            num_warps=warps,
        )
    return out, mean, rstd
def rms_norm_gated(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    """Apply (optionally gated) RMS normalization over the last dimension.

    ``x`` (and ``z`` when given) are flattened to 2-D, made unit-stride on the
    last dimension if necessary, run through ``_layer_norm_fwd`` with
    ``is_rms_norm=True``, and the result is reshaped back to ``x``'s shape.

    Args:
        x: input tensor; the last dimension is the normalized one.
        weight: scale vector.
        bias: optional shift vector, or None.
        z: optional gate tensor with exactly the same shape as ``x``.
        eps: numerical-stability epsilon.
        group_size: columns normalized together (None means the full row).
        norm_before_gate: whether normalization precedes the gating.

    Returns:
        Tensor with the original shape of ``x``.
    """
    original_shape = x.shape

    # Collapse all leading dimensions into a single row dimension.
    x_rows = x.reshape(-1, original_shape[-1])
    if x_rows.stride(-1) != 1:
        x_rows = x_rows.contiguous()

    z_rows = None
    if z is not None:
        assert z.shape == original_shape
        z_rows = z.reshape(-1, z.shape[-1])
        if z_rows.stride(-1) != 1:
            z_rows = z_rows.contiguous()

    dense_weight = weight.contiguous()
    dense_bias = None if bias is None else bias.contiguous()

    normed, _, _ = _layer_norm_fwd(
        x_rows,
        dense_weight,
        dense_bias,
        eps,
        z=z_rows,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        is_rms_norm=True,
    )
    return normed.reshape(original_shape)

View File

@@ -0,0 +1,478 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
import torch
from packaging import version
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON, tl, triton
# Triton >= 3.0.0 gets the log(exp(x) + 1) formulation; older versions use
# log1p(exp(x)).
TRITON3 = HAS_TRITON and version.parse(triton.__version__) >= version.parse("3.0.0")

if TRITON3:

    @triton.jit
    def softplus(dt):
        # Above 20 the input is passed through unchanged.
        return tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)

else:

    @triton.jit
    def softplus(dt):
        # Above 20 the input is passed through unchanged.
        return tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
# Single-step selective-scan state update. The launch grid is
# (tiles over dim, batch, heads): each program updates a
# BLOCK_SIZE_M x BLOCK_SIZE_DSTATE tile of one head's state for one batch
# entry and writes the matching slice of the output.
# The HAS_* flags are derived from whether the corresponding pointer is None;
# BLOCK_SIZE_DSTATE is dstate rounded up to a power of two.
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {
        "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
        is not None
    }
)
@triton.heuristics(
    {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
)
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    dst_state_batch_indices_ptr,
    pad_slot_id,
    # Matrix dimensions
    batch,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_head,
    stride_out_dim,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
    # is taken from state_batch_indices_ptr (read side) and
    # dst_state_batch_indices_ptr (write side). Otherwise, the state coordinate
    # is the same as the batch id and the state is updated in place.
    if HAS_STATE_BATCH_INDICES:
        dst_state_batch_indices_ptr += pid_b
        dst_state_batch_idx = tl.load(dst_state_batch_indices_ptr).to(tl.int64)
        dst_state_ptr = state_ptr + (
            dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
        )
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
    else:
        dst_state_ptr = (
            state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head
        )
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
    # Move the remaining base pointers to this (batch, head).
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head
    A_ptr += pid_h * stride_A_head
    # B and C are shared by all heads of a group: head -> group index.
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    # offs_m indexes the head dimension (dim), offs_n the state dimension.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    dst_state_ptrs = dst_state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim
    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
    A_ptrs = A_ptr + (
        offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
    )
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate
    if HAS_D:
        D_ptrs = D_ptr + offs_m * stride_D_dim
    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim
    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
    if HAS_STATE_BATCH_INDICES:
        # Padded slots: do not touch the state cache, read zeros instead.
        mask &= state_batch_idx != pad_slot_id
    state = tl.load(state_ptrs, mask=mask, other=0.0)
    x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if not TIE_HDIM:
        # General case: dt is a vector over dim and A a (dim, dstate) matrix.
        dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(
            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
        dA = tl.exp(A * dt[:, None])
    else:
        # Tied case: a single dt / dt_bias / A value per head is loaded.
        dt = tl.load(dt_ptr).to(tl.float32)
        if HAS_DT_BIAS:
            dt += tl.load(dt_bias_ptr).to(tl.float32)
        if DT_SOFTPLUS:
            dt = softplus(dt)
        A = tl.load(A_ptr).to(tl.float32)
        dA = tl.exp(A * dt)  # scalar, not a matrix
    B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    if HAS_Z:
        z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
    # Recurrence: state <- state * exp(A * dt) + (B * dt) * x
    state = state * dA + dB * x[:, None]
    mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
    if HAS_STATE_BATCH_INDICES:
        mask &= state_batch_idx != pad_slot_id
    tl.store(dst_state_ptrs, state, mask=mask)
    # Output: contract the updated state with C over the state dimension,
    # then apply the optional D skip term and z * sigmoid(z) gate.
    out = tl.sum(state * C[None, :], axis=1)
    if HAS_D:
        out += x * D
    if HAS_Z:
        out *= z * tl.sigmoid(z)
    tl.store(out_ptrs, out, mask=offs_m < dim)
def selective_state_update(
    state,
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    state_batch_indices=None,
    dst_state_batch_indices=None,
    pad_slot_id=PAD_SLOT_ID,
    out=None,
):
    """Advance the SSM state by one step per batch entry and write the output.

    Inputs without a head dimension are given one (``nheads == 1``) before the
    kernel launch.

    Args:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate);
            updated in place by the kernel.
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: optional, (dim,) or (nheads, dim)
        z: optional, (batch, dim) or (batch, nheads, dim)
        dt_bias: optional, (dim,) or (nheads, dim)
        dt_softplus: apply softplus to ``dt`` (after adding ``dt_bias``).
        state_batch_indices: optional (batch,) tensor selecting, per batch
            entry, which slot of ``state`` to read.
        dst_state_batch_indices: optional (batch,) tensor selecting which slot
            of ``state`` to write; defaults to ``state_batch_indices``
            (in-place update).
        pad_slot_id: int
            if state_batch_indices is passed, lets the kernel identify padded
            entries whose state will not be read or written,
            for example: state_batch_indices = [pad_slot_id, 1, 20, pad_slot_id]
            in this case, the kernel will not touch the state for entries at
            indices 0 and 3
        out: Preallocated ssm output tensor with the same shape as ``x``,
            updated in place. When omitted, a new tensor is allocated.

    Returns:
        The output tensor (``out`` itself when it was supplied).
    """
    if out is None:
        # The default used to crash on ``out.dim()``; allocate instead.
        out = torch.empty_like(x)
    result = out  # handed back in the caller's original shape

    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)
    if out.dim() == 2:
        out = out.unsqueeze(1)

    _, nheads, dim, dstate = state.shape
    batch = x.shape[0]
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape
    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
    if state_batch_indices is not None:
        assert state_batch_indices.shape == (batch,)
    if dst_state_batch_indices is not None:
        assert dst_state_batch_indices.shape == (batch,)
    else:
        # revert to the default behavior of in-place state updates
        dst_state_batch_indices = state_batch_indices
    assert out.shape == x.shape

    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)

    # Absent optional tensors contribute zero strides. These are computed up
    # front because ``*(a, b) if cond else 0`` inside a call unpacks the whole
    # conditional expression and fails with ``*0`` when the tensor is None.
    z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
    dt_bias_strides = (
        (dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0)
    )
    D_strides = (D.stride(0), D.stride(1)) if D is not None else (0, 0)

    # We don't want autotune since it will overwrite the state
    # We instead tune by hand.
    if dstate <= 16:
        BLOCK_SIZE_M, num_warps = 32, 4
    elif dstate <= 32:
        BLOCK_SIZE_M, num_warps = 16, 4
    elif dstate <= 64:
        BLOCK_SIZE_M, num_warps = 8, 4
    elif dstate <= 128:
        BLOCK_SIZE_M, num_warps = 4, 4
    else:
        BLOCK_SIZE_M, num_warps = 4, 8

    # A, dt and dt_bias are "tied" when they are broadcast (stride 0) along the
    # head dimension; a missing dt_bias does not prevent tying.
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and (dt_bias is None or dt_bias.stride(-1) == 0)
    )
    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            dst_state_batch_indices,
            pad_slot_id,
            batch,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            dt_bias_strides[0],
            dt_bias_strides[1],
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            D_strides[0],
            D_strides[1],
            z_strides[0],
            z_strides[1],
            z_strides[2],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )
    return result
def selective_scan_fn(
    u,
    ssm_states,
    delta,
    A,
    B,
    C,
    D=None,
    z=None,
    delta_bias=None,
    delta_softplus=False,
    query_start_loc=None,
    cache_indices=None,
    has_initial_state=None,
    pad_slot_id=PAD_SLOT_ID,
    block_size=1024,
    block_idx_first_scheduled_token=None,
    block_idx_last_scheduled_token=None,
    initial_state_idx=None,
) -> torch.Tensor:
    """Run the selective-scan forward custom op and return its output tensor.

    Shapes ("varlen" means ``query_start_loc`` is given):
        u: (dim, total_length) for varlen or (batch, dim, seqlen)
        ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate);
            changes are applied in place.
        delta: (dim, total_length) for varlen or (batch, dim, seqlen)
        A: (dim, dstate)
        B, C: (ngroups, dstate, total_length) for varlen or
            (batch, ngroups, dstate, seqlen); a missing ngroups dimension
            is inserted here.
        D: (dim,)
        z: (dim, total_length) for varlen or (batch, dim, seqlen)
        delta_bias: (dim,)

    Args:
        delta_softplus: forwarded unchanged to the op.
        query_start_loc: (batch + 1) int32 cumulative sequence lengths,
            prepended with 0, used to index into the packed sequence, e.g.
            ``torch.Tensor([0, 10, 16, 17])`` with ``x.shape == (dim, 17)``.
        cache_indices: int32 indices of the ssm_state slots used for input
            and output.
            - Without APC: (batch,), a single state index per batch item.
            - With APC: (batch, max_positions), cache block indices for
              read/write; each non-zero value names a cache block to load
              from and/or write to.
        has_initial_state: (batch) bool; marks which entries should use the
            stored ssm_state as their initial state. When omitted, no entry
            has an initial state.
        pad_slot_id: when ``cache_indices`` is passed, entries equal to this
            value are treated as padding and are not processed, e.g. with
            ``cache_indices = [pad_slot_id, 1, 20, pad_slot_id]`` entries 0
            and 3 are skipped.
        block_size: block size the cached states are aligned to.
        block_idx_first_scheduled_token: (batch,) int32 position in
            ``cache_indices`` of the first cache block to be filled.
        block_idx_last_scheduled_token: (batch,) int32 position in
            ``cache_indices`` of the last cache block to be filled.
        initial_state_idx: (batch,) int32 position in ``cache_indices`` of
            the cache block holding the initial state.

    Returns:
        The output, shaped (dim, total_length) for varlen or
        (batch, dim, seqlen). It is written in place: into ``z`` when ``z``
        is given, otherwise into ``delta``.
    """
    # Ensure unit stride on the innermost dimension before calling the op.
    if u.stride(-1) != 1:
        u = u.contiguous()
    if delta.stride(-1) != 1:
        delta = delta.contiguous()
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if z is not None and z.stride(-1) != 1:
        z = z.contiguous()
    if D is not None:
        D = D.contiguous()

    # Insert the ngroups dimension into B / C when the caller left it out.
    if query_start_loc is None:
        if B.dim() == 3:
            B = B.unsqueeze(1)
        if C.dim() == 3:
            C = C.unsqueeze(1)
    else:
        if B.dim() == 2:
            B = B.unsqueeze(0)
        if C.dim() == 2:
            C = C.unsqueeze(0)

    ops.selective_scan_fwd(
        u,
        delta,
        A,
        B,
        C,
        D,
        z,
        delta_bias,
        delta_softplus,
        query_start_loc,
        cache_indices,
        has_initial_state,
        ssm_states,
        pad_slot_id,
        block_size,
        block_idx_first_scheduled_token,
        block_idx_last_scheduled_token,
        initial_state_idx,
    )

    # The op writes its result into z when present, otherwise into delta.
    return delta if z is None else z

View File

@@ -0,0 +1,211 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# ruff: noqa: E501,SIM102
import torch
from vllm.triton_utils import tl, triton
# Per-chunk batched matmul kernel: for every (chunk, group) it computes
# out = a_chunk @ b_chunk.T as a (chunk_size, chunk_size) matrix, contracting
# over K. Grid axis 0 enumerates (M tile, N tile) pairs, grid axis 1 enumerates
# (chunk, group) pairs.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    out_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    seqlen,
    chunk_size: tl.constexpr,
    K: tl.constexpr,
    ngroups: tl.constexpr,
    stride_a_seqlen: tl.int64,
    stride_a_head: tl.int64,
    stride_ak: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_bk: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_outm: tl.int64,
    stride_outn: tl.constexpr,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Decode grid axis 1 into (chunk, group) and axis 0 into (M tile, N tile).
    pid_ch = tl.program_id(axis=1).to(tl.int64)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    if IS_CAUSAL:
        # Skip tiles that lie entirely above the diagonal (every column index
        # greater than every row index); those outputs are left unwritten.
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return
    # Token range [start, end) covered by this chunk.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    # b is addressed as (K, N), i.e. already transposed for the dot product.
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    # Number of valid tokens in this chunk; positions past it load as zero.
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # compute a * b.T
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit)
            & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        ).to(dot_dtype)
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
            & (offs_n[None, :] < chunk_size_limit),
            other=0.0,
        ).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    out = acc.to(out_ptr.dtype.element_ty)
    out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
    # The store is bounded by chunk_size (not chunk_size_limit), so rows and
    # columns past the chunk's valid length are written as zeros.
    tl.store(
        out_ptrs,
        out,
        mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
    )
def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
    """Compute ``a_chunk @ b_chunk.T`` for every chunk and group.

    Args:
        a: (seqlen, ngroups, k)
        b: (seqlen, ngroups, k)
        chunk_size: int, edge length of each per-chunk output matrix.
        cu_chunk_seqlens: (nchunks+1,) cumulative token offsets of the chunks.
        causal: if True, the kernel skips tiles strictly above the diagonal,
            so out[i, j] for j > i may be arbitrary; only out[i, j] for
            i >= j are guaranteed to be correct.
        output_dtype: dtype of the result; defaults to ``a.dtype``.

    Returns:
        out: (nchunks, ngroups, chunk_size, chunk_size)
    """
    seqlen, ngroups, k = a.shape
    assert b.shape == a.shape

    # Only copy when neither the last nor the first dimension is unit-stride.
    if a.stride(-1) != 1 and a.stride(0) != 1:
        a = a.contiguous()
    if b.stride(-1) != 1 and b.stride(0) != 1:
        b = b.contiguous()

    nchunks = len(cu_chunk_seqlens) - 1

    # Allocates output.
    result_dtype = output_dtype if output_dtype is not None else a.dtype
    out = torch.empty(
        (nchunks, ngroups, chunk_size, chunk_size),
        device=a.device,
        dtype=result_dtype,
    )

    # Operand dtype for tl.dot: bf16 wins over fp16, which wins over fp32.
    operand_dtypes = (a.dtype, b.dtype)
    if torch.bfloat16 in operand_dtypes:
        dot_dtype = tl.bfloat16
    elif torch.float16 in operand_dtypes:
        dot_dtype = tl.float16
    else:
        dot_dtype = tl.float32

    def grid(META):
        # Axis 0: all (M tile, N tile) pairs; axis 1: all (chunk, group) pairs.
        tiles_m = triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(chunk_size, META["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, nchunks * ngroups)

    with torch.cuda.device(a.device.index):
        _bmm_chunk_fwd_kernel[grid](
            a_ptr=a,
            b_ptr=b,
            out_ptr=out,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            seqlen=seqlen,
            chunk_size=chunk_size,
            K=k,
            ngroups=ngroups,
            stride_a_seqlen=a.stride(0),
            stride_a_head=a.stride(1),
            stride_ak=a.stride(2),
            stride_b_seqlen=b.stride(0),
            stride_b_head=b.stride(1),
            stride_bk=b.stride(2),
            stride_out_chunk=out.stride(0),
            stride_out_head=out.stride(1),
            stride_outm=out.stride(-2),
            stride_outn=out.stride(-1),
            IS_CAUSAL=causal,
            dot_dtype=dot_dtype,
        )
    return out

View File

@@ -0,0 +1,456 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
# ruff: noqa: E501,SIM102
from packaging import version
from vllm.triton_utils import tl, triton
# True when the installed Triton is version 2.2.0 or newer; forwarded to the
# chunk-scan kernel as its IS_TRITON_22 constexpr.
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
# Chunked scan forward kernel. For one (chunk, head) it produces a
# BLOCK_SIZE_M x BLOCK_SIZE_N tile of the output (chunk positions x hdim) as
# the sum of
#   1. the carried-in state contribution: (C @ prev_states) * exp(dA_cumsum),
#   2. the intra-chunk contribution: (cb * decay * dt, causally masked) @ x,
#   3. the optional D skip term, followed by the optional z * sigmoid(z) gate.
# Grid axis 0 enumerates (M tile, N tile) pairs, axis 1 chunks, axis 2 heads.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
)
@triton.jit
def _chunk_scan_fwd_kernel(
    # Pointers to matrices
    cb_ptr,
    x_ptr,
    z_ptr,
    out_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    C_ptr,
    states_ptr,
    D_ptr,
    initstates_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    chunk_size: tl.constexpr,
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    seqlen,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_cb_chunk: tl.int64,
    stride_cb_head: tl.int64,
    stride_cb_csize_m: tl.int64,
    stride_cb_csize_k: tl.constexpr,
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_z_seqlen: tl.int64,
    stride_z_head: tl.int64,
    stride_z_hdim: tl.constexpr,
    stride_out_seqlen: tl.int64,
    stride_out_head: tl.int64,
    stride_out_hdim: tl.constexpr,
    stride_dt_chunk: tl.int64,
    stride_dt_head: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_seq_idx_chunk: tl.constexpr,
    stride_C_seqlen: tl.int64,
    stride_C_head: tl.int64,
    stride_C_dstate: tl.constexpr,
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_init_states_batch: tl.int64,
    stride_init_states_head: tl.int64,
    stride_init_states_hdim: tl.int64,
    stride_init_states_dstate: tl.constexpr,
    stride_D_head: tl.constexpr,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    pid_c = tl.program_id(axis=1).to(tl.int64)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # cb and C are stored per group: map the head to its group index.
    cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    # Token range [start, end) covered by this chunk.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
    x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += (
        chunk_seqlen_start * stride_C_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_C_head
    )

    # M-block offsets and prev states
    #  - logic in next block may override these if there is an active offset
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    # Sequence id of this chunk and of the previous one (-1 for chunk 0); a
    # difference means this chunk starts a new sequence.
    seq_idx_ptr += pid_c * stride_seq_idx_chunk
    seq_idx = tl.load(seq_idx_ptr)
    seq_idx_prev = tl.load(
        seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
    )

    if HAS_INITSTATES and (seq_idx != seq_idx_prev):
        # New sequence with provided initial states: read from initstates.
        prev_states_ptr = (
            initstates_ptr
            + seq_idx * stride_init_states_batch
            + pid_h * stride_init_states_head
        )
        prev_states_hdim = stride_init_states_hdim
        prev_states_dstate = stride_init_states_dstate
    else:
        # Otherwise read the state stored for the previous chunk.
        prev_states_ptr = (
            states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
        )
        prev_states_hdim = stride_states_hdim
        prev_states_dstate = stride_states_dstate

    # Number of valid tokens in this chunk.
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
    ).to(tl.float32)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k_dstate = tl.arange(
        0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
    )
    C_ptrs = C_ptr + (
        offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
    )

    scale_m = tl.exp(dA_cs_m)
    if BLOCK_SIZE_DSTATE <= 128:
        C = tl.load(
            C_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit)
            & (offs_k_dstate[None, :] < dstate),
            other=0.0,
        )

        if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
            # if no init states AND starting a new sequence, we need zeros
            prev_states = tl.zeros(
                (BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
            )
        else:
            # otherwise read the previous state
            prev_states_ptrs = (
                prev_states_ptr
                + offs_n[None, :] * prev_states_hdim
                + offs_k_dstate[:, None] * prev_states_dstate
            )
            prev_states = tl.load(
                prev_states_ptrs,
                mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            prev_states = prev_states.to(C_ptr.dtype.element_ty)

        acc = tl.dot(C, prev_states) * scale_m[:, None]

    else:
        prev_states_ptrs = (
            prev_states_ptr
            + offs_n[None, :] * prev_states_hdim
            + offs_k_dstate[:, None] * prev_states_dstate
        )
        for k in range(0, dstate, BLOCK_SIZE_K):
            C = tl.load(
                C_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit)
                & (offs_k_dstate[None, :] < dstate - k),
                other=0.0,
            )
            if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
                prev_states = tl.zeros(
                    (BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
                )
            else:
                prev_states = tl.load(
                    prev_states_ptrs,
                    mask=(offs_k_dstate[:, None] < dstate - k)
                    & (offs_n[None, :] < hdim),
                    other=0.0,
                )
                prev_states = prev_states.to(C_ptr.dtype.element_ty)
            acc += tl.dot(C, prev_states)
            # Step BLOCK_SIZE_K elements along dstate. The pointers were built
            # with the dstate strides, so the step must be scaled by them too
            # (a bare BLOCK_SIZE_K is only right when those strides are 1).
            C_ptrs += BLOCK_SIZE_K * stride_C_dstate
            prev_states_ptrs += BLOCK_SIZE_K * prev_states_dstate
        acc *= scale_m[:, None]

    # Intra-chunk contribution: iterate over source positions k in the chunk.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (
        offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
    )
    x_ptrs = x_ptr + (
        offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    # With causality, positions past the end of this M tile never contribute.
    K_MAX = (
        chunk_size_limit
        if not IS_CAUSAL
        else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    )
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        cb = tl.load(
            cb_ptrs,
            mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
            tl.float32
        )
        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
        # So we don't need masking wrt seq_idx here.
        cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        cb *= dt_k
        if IS_CAUSAL:
            # Keep only source positions at or before the target position.
            mask = offs_m[:, None] >= k + offs_k[None, :]
            cb = tl.where(mask, cb, 0.0)
        cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(
            x_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        acc += tl.dot(cb, x)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if HAS_D:
        # Skip connection: D is either per (head, hdim) or a scalar per head.
        if D_HAS_HDIM:
            D = tl.load(
                D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
            ).to(tl.float32)
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        x_residual = tl.load(
            x_ptr
            + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        acc += x_residual * D

    if HAS_Z:
        # Gate the result with z * sigmoid(z).
        z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
        z_ptrs = z_ptr + (
            stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
        )
        z = tl.load(
            z_ptrs,
            mask=(offs_out_m[:, None] < chunk_size_limit)
            & (offs_out_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        acc *= z * tl.sigmoid(z)

    # Only the chunk's valid token rows are written back.
    out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (
        stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
    )
    tl.store(
        out_ptrs,
        acc,
        mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
    )
def _chunk_scan_fwd(
    cb,
    x,
    dt,
    dA_cumsum,
    C,
    states,
    cu_chunk_seqlens,
    out,
    seq_idx,
    D=None,
    z=None,
    initial_states=None,
):
    """Check shapes and launch the chunked scan forward kernel.

    The result is written into ``out``; nothing is returned.

    Args:
        cb: (nchunks, ngroups, chunk_size, chunk_size)
        x: (seqlen, nheads, headdim)
        dt: (nheads, nchunks, chunk_size)
        dA_cumsum: (nheads, nchunks, chunk_size)
        C: (seqlen, ngroups, dstate)
        states: (nchunks, nheads, headdim, dstate)
        cu_chunk_seqlens: passed to the kernel, which reads entries ``c`` and
            ``c + 1`` as the token range of chunk ``c``.
        out: output buffer, written in place.
        seq_idx: (nchunks,) sequence id of each chunk; required.
        D: optional, (nheads, headdim) or (nheads,)
        z: optional gate with the same shape as ``x``.
        initial_states: optional 4-D tensor of per-sequence initial states.
    """
    assert seq_idx is not None, "this implementation requires seq_idx"

    seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = C.shape

    assert nheads % ngroups == 0
    assert C.shape == (seqlen, ngroups, dstate)
    assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if z is not None:
        assert z.shape == x.shape
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
    assert states.shape == (nchunks, nheads, headdim, dstate)
    assert seq_idx.shape == (nchunks,)

    def grid(META):
        # Axis 0: (chunk-position tile, headdim tile) pairs; then chunks, heads.
        tiles_m = triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(headdim, META["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, nchunks, nheads)

    # Optional tensors contribute zero strides when absent.
    if z is None:
        z_strides = (0, 0, 0)
    else:
        z_strides = (z.stride(0), z.stride(1), z.stride(2))

    if initial_states is None:
        initial_states_strides = (0, 0, 0, 0)
    else:
        initial_states_strides = (
            initial_states.stride(0),
            initial_states.stride(1),
            initial_states.stride(2),
            initial_states.stride(3),
        )

    has_d = D is not None
    d_head_stride = D.stride(0) if has_d else 0
    # Without D the flag is irrelevant; it defaults to True.
    d_has_hdim = D.dim() == 2 if has_d else True

    _chunk_scan_fwd_kernel[grid](
        cb_ptr=cb,
        x_ptr=x,
        z_ptr=z,
        out_ptr=out,
        dt_ptr=dt,
        dA_cumsum_ptr=dA_cumsum,
        seq_idx_ptr=seq_idx,
        C_ptr=C,
        states_ptr=states,
        D_ptr=D,
        initstates_ptr=initial_states,
        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
        chunk_size=chunk_size,
        hdim=headdim,
        dstate=dstate,
        seqlen=seqlen,
        nheads_ngroups_ratio=nheads // ngroups,
        stride_cb_chunk=cb.stride(0),
        stride_cb_head=cb.stride(1),
        stride_cb_csize_m=cb.stride(2),
        stride_cb_csize_k=cb.stride(3),
        stride_x_seqlen=x.stride(0),
        stride_x_head=x.stride(1),
        stride_x_hdim=x.stride(2),
        stride_z_seqlen=z_strides[0],
        stride_z_head=z_strides[1],
        stride_z_hdim=z_strides[2],
        stride_out_seqlen=out.stride(0),
        stride_out_head=out.stride(1),
        stride_out_hdim=out.stride(2),
        # dt and dA_cumsum are laid out (head, chunk, position).
        stride_dt_chunk=dt.stride(1),
        stride_dt_head=dt.stride(0),
        stride_dt_csize=dt.stride(2),
        stride_dA_cs_chunk=dA_cumsum.stride(1),
        stride_dA_cs_head=dA_cumsum.stride(0),
        stride_dA_cs_csize=dA_cumsum.stride(2),
        stride_seq_idx_chunk=seq_idx.stride(0),
        stride_C_seqlen=C.stride(0),
        stride_C_head=C.stride(1),
        stride_C_dstate=C.stride(2),
        stride_states_chunk=states.stride(0),
        stride_states_head=states.stride(1),
        stride_states_hdim=states.stride(2),
        stride_states_dstate=states.stride(3),
        stride_init_states_batch=initial_states_strides[0],
        stride_init_states_head=initial_states_strides[1],
        stride_init_states_hdim=initial_states_strides[2],
        stride_init_states_dstate=initial_states_strides[3],
        stride_D_head=d_head_stride,
        IS_CAUSAL=True,
        HAS_D=has_d,
        D_HAS_HDIM=d_has_hdim,
        HAS_Z=z is not None,
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        IS_TRITON_22=TRITON_22,
        HAS_INITSTATES=initial_states is not None,
    )

View File

@@ -0,0 +1,700 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
from .mamba_ssm import softplus
# Autotune the number of heads handled per program (BLOCK_SIZE_H); the selected
# config is keyed on chunk_size and nheads.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_H": 2}),
        triton.Config({"BLOCK_SIZE_H": 4}),
        triton.Config({"BLOCK_SIZE_H": 8}),
        triton.Config({"BLOCK_SIZE_H": 16}),
        triton.Config({"BLOCK_SIZE_H": 32}),
        triton.Config({"BLOCK_SIZE_H": 64}),
    ],
    key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
    # Pointers to matrices
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    dt_out_ptr,
    dA_cumsum_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimension
    seqlen,
    nheads: tl.constexpr,
    chunk_size: tl.constexpr,
    dt_min: tl.constexpr,
    dt_max: tl.constexpr,
    # Strides
    stride_dt_seqlen: tl.int64,
    stride_dt_head: tl.constexpr,
    stride_A_head: tl.constexpr,
    stride_dt_bias_head: tl.constexpr,
    stride_dt_out_head: tl.int64,
    stride_dt_out_chunk: tl.int64,
    stride_dt_out_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_CHUNK: tl.constexpr,
):
    # Per-chunk preprocessing of dt plus the in-chunk cumulative sum of dt * A.
    #
    # Launch layout: program axis 0 selects the chunk, axis 1 selects a block of
    # BLOCK_SIZE_H heads. For its (chunk, head-block) tile the program
    #   1. loads dt for the tokens cu_chunk_seqlens[c] .. cu_chunk_seqlens[c + 1],
    #   2. optionally adds dt_bias, applies softplus and clamps to [dt_min, dt_max],
    #   3. writes the processed dt to dt_out,
    #   4. writes cumsum(dt * A) along the chunk axis to dA_cumsum.
    # `seqlen` is accepted but not referenced in the kernel body.

    # if dt is long, may cause problems, so use 64 bit
    # https://github.com/triton-lang/triton/issues/1058
    pid_c = tl.program_id(axis=0).to(tl.int64)
    pid_h = tl.program_id(axis=1)

    # Token range [start, end) covered by this chunk.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)

    # Move the base pointers to this chunk: dt is addressed by token position,
    # the two outputs are addressed by chunk index.
    dt_ptr += chunk_seqlen_start * stride_dt_seqlen
    dt_out_ptr += pid_c * stride_dt_out_chunk
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk

    # Tile offsets: heads along dim 0, positions inside the chunk along dim 1.
    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    dt_ptrs = dt_ptr + (
        offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
    )
    A_ptrs = A_ptr + offs_h * stride_A_head
    dt_out_ptrs = dt_out_ptr + (
        offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
    )
    dA_cs_ptrs = dA_cumsum_ptr + (
        offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
    )

    # Number of tokens actually present in this chunk.
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    # Out-of-range heads / positions are read as 0.
    dt = tl.load(
        dt_ptrs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
        other=0.0,
    ).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(
            dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
        ).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        # softplus is only applied where dt <= 20; larger values pass through
        # unchanged.
        dt = tl.where(dt <= 20.0, softplus(dt), dt)

    dt = tl.clamp(dt, dt_min, dt_max)
    # Bias, softplus and clamp can turn the zero padding into non-zero values,
    # so padded slots are reset to 0 before anything is stored or accumulated.
    dt = tl.where(
        (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
    )
    # The store mask uses chunk_size (not chunk_size_limit), so the zero
    # padding is written out for the remainder of the chunk row as well.
    tl.store(
        dt_out_ptrs,
        dt,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
    )
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    dA = dt * A[:, None]
    # Padded slots contribute dA == 0, so the cumulative sum stays constant
    # after the last valid position of the chunk.
    dA_cs = tl.cumsum(dA, axis=1)
    tl.store(
        dA_cs_ptrs,
        dA_cs,
        mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
    )
# Autotune the (M, N, K) tile sizes together with num_stages / num_warps; the
# selected config is keyed on hdim, dstate and chunk_size.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    states_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    chunk_size: tl.constexpr,
    seqlen,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_b_dstate: tl.constexpr,
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_dt_head: tl.int64,
    stride_dt_chunk: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Computes, for one (chunk, head) pair, a (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of
    #   states[m, n] = sum_k x[k, m] * b[k, n] * exp(dA_cs_last - dA_cs[k]) * dt[k]
    # where k runs over the token positions of the chunk.
    #
    # Launch layout: axis 0 is the flattened (hdim tile, dstate tile) index,
    # axis 1 the chunk, axis 2 the head.
    # `seqlen` is accepted but not referenced in the kernel body.
    pid_c = tl.program_id(axis=1).to(tl.int64)
    pid_h = tl.program_id(axis=2)
    # Un-flatten axis 0 into the hdim tile (pid_m) and the dstate tile (pid_n).
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n

    # Token range [start, end) covered by this chunk.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)

    # b is indexed by group: nheads_ngroups_ratio consecutive heads share a group.
    b_ptr += (
        chunk_seqlen_start * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # x is addressed as an (hdim, position) tile and b as a (position, dstate)
    # tile, so tl.dot(x, b) contracts over the position axis.
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # Cumulative dA at the last slot of the chunk (index chunk_size - 1).
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
        tl.float32
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Number of tokens actually present in this chunk.
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Walk over the chunk's positions BLOCK_SIZE_K at a time; positions past
    # chunk_size_limit are masked to 0 and therefore add nothing.
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        # Per-position weight: decay from position k to the end of the chunk,
        # multiplied by dt at that position.
        scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
        b *= scale[:, None]
        # Cast the scaled b tile to x's element type before the dot product.
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)

        # Advance every pointer set to the next block of positions.
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # Convert the fp32 accumulator to the output dtype and store the tile at
    # states[pid_c, pid_h, offs_m, offs_n].
    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)
# Autotune the (M, N, K) tile sizes together with num_stages / num_warps; the
# selected config is keyed on hdim, dstate and chunk_size.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
            num_stages=3,
            num_warps=8,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=4,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=5,
            num_warps=2,
        ),
        triton.Config(
            {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
            num_stages=4,
            num_warps=2,
        ),
    ],
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    states_ptr,
    initstates_ptr,
    # Matrix dimensions
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    chunk_size: tl.constexpr,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_b_dstate: tl.constexpr,
    stride_dt_head: tl.int64,
    stride_dt_chunk: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_chunk_states_chunk: tl.int64,
    stride_chunk_states_head: tl.int64,
    stride_chunk_states_hdim: tl.int64,
    stride_chunk_states_dstate: tl.constexpr,
    stride_states_batch: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_init_states_batch: tl.int64,
    stride_init_states_head: tl.int64,
    stride_init_states_hdim: tl.int64,
    stride_init_states_dstate: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    # Computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of the final state of a
    # single sequence in a packed variable-length batch, for one head.
    #
    # Launch layout: axis 0 is the flattened (hdim tile, dstate tile) index,
    # axis 1 the sequence index in the batch, axis 2 the head.
    #
    # Only the fixed-size chunk holding the sequence's last token is processed:
    # the tokens of that chunk belonging to this sequence are accumulated, and
    # the state carried in from before (chunk_states or initstates) is then
    # added, decayed to the end of the sequence.
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    # Un-flatten axis 0 into the hdim tile (pid_m) and the dstate tile (pid_n).
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # end_idx is the exclusive end of this sequence in the packed token axis;
    # pid_c is the index of the chunk_size-sized chunk holding its last token.
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    pid_c = (end_idx - 1) // chunk_size
    # b is indexed by group: nheads_ngroups_ratio consecutive heads share a group.
    b_ptr += (
        pid_c * chunk_size * stride_b_seqlen
        + (pid_h // nheads_ngroups_ratio) * stride_b_head
    )
    x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += (
        pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
    )

    if HAS_INITSTATES:
        # if there are init states provided, we differentiate between states (which
        # are boundary conditions at a chunk boundary) and initstates (which are boundary
        # conditions when a new example in a cont batch starts)
        initstates_ptr += pid_h * stride_init_states_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # x is addressed as an (hdim, position) tile and b as a (position, dstate)
    # tile, so tl.dot(x, b) contracts over the position axis.
    x_ptrs = x_ptr + (
        offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
    )
    b_ptrs = b_ptr + (
        offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
    )
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # Cumulative dA at the sequence's last token (its position inside the chunk).
    dA_cs_last = tl.load(
        dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
    ).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # Valid positions inside the chunk for this sequence are
    # [start_idx_cur, chunk_size_limit).
    chunk_size_limit = end_idx - pid_c * chunk_size
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        # Positions before the sequence start or after its end are masked to 0.
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim)
            & (offs_k[None, :] < chunk_size_limit - k)
            & (offs_k[None, :] >= start_idx_cur - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k)
            & (offs_n[None, :] < dstate)
            & (offs_k[:, None] >= start_idx_cur - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(
            dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
        ).to(tl.float32)
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        # Per-position weight (decay to the sequence end times dt), forced to 0
        # outside the sequence's range.
        scale = tl.where(
            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
            tl.exp(dA_cs_last - dA_cs_k) * dt_k,
            0.0,
        )
        b *= scale[:, None]
        # Cast the scaled b tile to x's element type before the dot product.
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)

        # Advance every pointer set to the next block of positions.
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # If the sequence starts inside this (last) chunk and no initstates are
    # given, there is no carried-in state to add.
    # If HAS_INITSTATES==True we need to consider two possibilities
    # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
    # - if start_idx >= pid_c * chunk_size, then we need to insert initstates
    if (
        (start_idx < pid_c * chunk_size)  # sequence began in an earlier chunk
        or (HAS_INITSTATES)
    ):
        dA_cs_boundary = 0.0  # default

        if not HAS_INITSTATES:
            past_states_ptrs = chunk_states_ptr + (
                offs_m[:, None] * stride_chunk_states_hdim
                + offs_n[None, :] * stride_chunk_states_dstate
            )
        else:
            # - this seems repetitive, but it's to help the compiler
            if start_idx < pid_c * chunk_size:
                past_states_ptrs = chunk_states_ptr + (
                    offs_m[:, None] * stride_chunk_states_hdim
                    + offs_n[None, :] * stride_chunk_states_dstate
                )
            else:
                past_states_ptrs = initstates_ptr + (
                    pid_b * stride_init_states_batch
                    + offs_m[:, None] * stride_init_states_hdim
                    + offs_n[None, :] * stride_init_states_dstate
                )

                # need to adjust the boundary: when the sequence starts strictly
                # inside the chunk, the decay is measured from the cumulative
                # value just before its first token instead of from 0
                if start_idx > pid_c * chunk_size:
                    dA_cs_boundary = tl.load(
                        dA_cumsum_ptr
                        + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
                    ).to(tl.float32)

        past_states = tl.load(
            past_states_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)

        # Decay the carried-in state from the boundary to the sequence end.
        scale = tl.exp(dA_cs_last - dA_cs_boundary)
        acc += past_states * scale

    # Convert the fp32 accumulator to the output dtype and store the tile at
    # states[pid_b, pid_h, offs_m, offs_n].
    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)
def _chunk_cumsum_fwd(
    dt,
    A,
    chunk_size,
    cu_chunk_seqlens,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    """Launch `_chunk_cumsum_fwd_kernel` and return its two outputs.

    Args:
        dt: (seqlen, nheads) tensor.
        A: (nheads,) tensor.
        chunk_size: width of the last axis of both outputs.
        cu_chunk_seqlens: (nchunks + 1,) tensor of chunk boundaries.
        dt_bias: optional (nheads,) tensor.
        dt_softplus: forwarded to the kernel as ``DT_SOFTPLUS``.
        dt_limit: ``(dt_min, dt_max)`` forwarded to the kernel.

    Returns:
        ``(dA_cumsum, dt_out)``, both float32 of shape
        (nheads, nchunks, chunk_size) on ``dt``'s device.
    """
    seqlen, nheads = dt.shape
    per_head_shape = (nheads,)
    assert A.shape == per_head_shape
    has_dt_bias = dt_bias is not None
    if has_dt_bias:
        assert dt_bias.shape == per_head_shape

    # cu_chunk_seqlens holds one boundary more than there are chunks.
    nchunks = cu_chunk_seqlens.shape[0] - 1
    out_shape = (nheads, nchunks, chunk_size)
    dt_out = torch.empty(out_shape, device=dt.device, dtype=torch.float32)
    dA_cumsum = torch.empty(out_shape, device=dt.device, dtype=torch.float32)

    def launch_grid(meta):
        # One program per (chunk, block of BLOCK_SIZE_H heads).
        return (nchunks, triton.cdiv(nheads, meta["BLOCK_SIZE_H"]))

    bias_stride = dt_bias.stride(0) if has_dt_bias else 0
    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[launch_grid](
            # tensors
            dt_ptr=dt,
            A_ptr=A,
            dt_bias_ptr=dt_bias,
            dt_out_ptr=dt_out,
            dA_cumsum_ptr=dA_cumsum,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            # sizes and limits
            seqlen=seqlen,
            nheads=nheads,
            chunk_size=chunk_size,
            dt_min=dt_limit[0],
            dt_max=dt_limit[1],
            # strides
            stride_dt_seqlen=dt.stride(0),
            stride_dt_head=dt.stride(1),
            stride_A_head=A.stride(0),
            stride_dt_bias_head=bias_stride,
            stride_dt_out_head=dt_out.stride(0),
            stride_dt_out_chunk=dt_out.stride(1),
            stride_dt_out_csize=dt_out.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            # meta-parameters
            DT_SOFTPLUS=dt_softplus,
            HAS_DT_BIAS=has_dt_bias,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out
def _chunk_state_fwd(
    B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True
):
    """Launch `_chunk_state_fwd_kernel` to fill the per-chunk states.

    Args:
        B: (seqlen, ngroups, dstate) tensor.
        x: (seqlen, nheads, headdim) tensor.
        dt: (nheads, nchunks, chunk_size) tensor.
        dA_cumsum: tensor with the same shape as ``dt``.
        cu_chunk_seqlens: chunk boundaries, forwarded to the kernel.
        states: optional preallocated (nchunks, nheads, headdim, dstate)
            output; allocated here when omitted.
        states_in_fp32: when allocating, use float32 if true, else ``B.dtype``.

    Returns:
        The ``states`` tensor written by the kernel.
    """
    seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape

    states_shape = (nchunks, nheads, headdim, dstate)
    if states is None:
        alloc_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty(states_shape, device=x.device, dtype=alloc_dtype)
    else:
        assert states.shape == states_shape

    def launch_grid(meta):
        # Axis 0 enumerates (headdim tile, dstate tile) pairs; axes 1 and 2 are
        # the chunk and the head.
        tiles_m = triton.cdiv(headdim, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(dstate, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[launch_grid](
            # tensors
            x_ptr=x,
            b_ptr=B,
            states_ptr=states,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            # sizes
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            seqlen=seqlen,
            nheads_ngroups_ratio=nheads // ngroups,
            # strides
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_states_chunk=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
        )
    return states
def chunk_state_varlen(
    B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
):
    """Launch `_chunk_state_varlen_kernel` to get one state per sequence.

    Args:
        B: (total_seqlen, ngroups, dstate) tensor.
        x: (total_seqlen, nheads, headdim) tensor.
        dt: (nheads, nchunks, chunk_size) tensor.
        dA_cumsum: tensor with the same shape as ``dt``.
        cu_seqlens: (batch + 1,) tensor of sequence boundaries; made
            contiguous before the launch.
        chunk_states: (nchunks, nheads, headdim, dstate) tensor.
        initial_states: optional (batch, nheads, headdim, dstate) tensor.

    Returns:
        A new (batch, nheads, headdim, dstate) tensor with the dtype and
        device of ``chunk_states``.
    """
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    has_initial_states = initial_states is not None

    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    if has_initial_states:
        assert initial_states.shape == (batch, nheads, headdim, dstate)

    states = torch.empty(
        (batch, nheads, headdim, dstate),
        dtype=chunk_states.dtype,
        device=chunk_states.device,
    )

    # The kernel always takes four init-state strides; pass zeros when there is
    # no initial_states tensor to read them from.
    if has_initial_states:
        init_strides = tuple(initial_states.stride(d) for d in range(4))
    else:
        init_strides = (0, 0, 0, 0)

    def launch_grid(meta):
        # Axis 0 enumerates (headdim tile, dstate tile) pairs; axes 1 and 2 are
        # the sequence and the head.
        tiles_m = triton.cdiv(headdim, meta["BLOCK_SIZE_M"])
        tiles_n = triton.cdiv(dstate, meta["BLOCK_SIZE_N"])
        return (tiles_m * tiles_n, batch, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[launch_grid](
            # tensors
            x_ptr=x,
            b_ptr=B,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            chunk_states_ptr=chunk_states,
            cu_seqlens_ptr=cu_seqlens,
            states_ptr=states,
            initstates_ptr=initial_states,
            # sizes
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            nheads_ngroups_ratio=nheads // ngroups,
            # strides
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_chunk_states_chunk=chunk_states.stride(0),
            stride_chunk_states_head=chunk_states.stride(1),
            stride_chunk_states_hdim=chunk_states.stride(2),
            stride_chunk_states_dstate=chunk_states.stride(3),
            stride_states_batch=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_init_states_batch=init_strides[0],
            stride_init_states_head=init_strides[1],
            stride_init_states_hdim=init_strides[2],
            stride_init_states_dstate=init_strides[3],
            # meta-parameters
            HAS_INITSTATES=has_initial_states,
        )
    return states

View File

@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
# ruff: noqa: E501
import torch
from einops import rearrange
from packaging import version
from vllm.triton_utils import triton
from .ssd_bmm import _bmm_chunk_fwd
from .ssd_chunk_scan import _chunk_scan_fwd
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
from .ssd_state_passing import _state_passing_fwd
# True when the installed Triton version is at least 2.2.0.
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
def is_int_pow_2(n):
    """Return True when ``n`` is a positive ``int`` that is a power of two."""
    if not isinstance(n, int):
        return False
    if n <= 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit
    # leaves zero.
    return n & (n - 1) == 0
def _mamba_chunk_scan_combined_fwd(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    return_intermediate_states=False,
    seq_idx=None,
    cu_seqlens=None,
    cu_chunk_seqlens=None,
    last_chunk_indices=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    state_dtype=None,
):
    """Run the chunked Mamba scan over a packed variable-length batch.

    The result of the scan is written in place into ``out``; the return value
    is the SSM state tensor.

    Args:
        x: (seqlen, nheads, headdim)
        dt: (seqlen, nheads)
        A: (nheads,)
        B: (seqlen, ngroups, dstate)
        C: same shape as ``B``
        chunk_size: positive integer power of two
        out: preallocated output tensor, updated in place
        D: optional, (nheads, headdim) or (nheads,)
        z: optional, same shape as ``x``
        dt_bias: optional, (nheads,)
        initial_states: optional, (len(cu_seqlens) - 1, nheads, headdim, dstate)
        return_intermediate_states: return the states of every chunk instead
            of only those selected by ``last_chunk_indices``
        seq_idx: required, (nchunks,) with nchunks = len(cu_chunk_seqlens) - 1
        cu_seqlens: required sequence boundaries
        cu_chunk_seqlens: required chunk boundaries, (nchunks + 1,)
        last_chunk_indices: index into the chunk axis of the states; required
            unless ``return_intermediate_states`` is true
        dt_softplus: whether softplus is applied to ``dt``
        dt_limit: ``(min, max)`` clamp applied to ``dt``
        state_dtype: dtype of the returned states; defaults to ``C.dtype``

    Returns:
        ``states`` of shape (nchunks, nheads, headdim, dstate) when
        ``return_intermediate_states`` is true, otherwise
        ``states[last_chunk_indices]``.
    """
    assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"

    # The varlen metadata defaults to None in the signature but is required in
    # practice. Check it up front: a missing seq_idx / cu_chunk_seqlens would
    # otherwise fail with an AttributeError deep inside a helper, and a missing
    # last_chunk_indices would silently return `states[None]` (an extra leading
    # dimension) after all kernels have already run.
    assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
    assert cu_chunk_seqlens is not None, "cu_chunk_seqlens must be provided"
    assert seq_idx is not None, "seq_idx must be provided"
    assert return_intermediate_states or last_chunk_indices is not None, (
        "last_chunk_indices must be provided unless return_intermediate_states=True"
    )

    seqlen, nheads, headdim = x.shape
    _, ngroups, dstate = B.shape
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (seqlen, nheads)
    assert A.shape == (nheads,)
    assert C.shape == B.shape
    if z is not None:
        assert z.shape == x.shape
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
    if B.stride(-1) != 1:
        B = B.contiguous()
    if C.stride(-1) != 1:
        C = C.contiguous()
    if (
        x.stride(-1) != 1 and x.stride(0) != 1
    ):  # Either M or K dimension should be contiguous
        x = x.contiguous()
    if (
        z is not None and z.stride(-1) != 1 and z.stride(0) != 1
    ):  # Either M or K dimension should be contiguous
        z = z.contiguous()
    if D is not None and D.stride(-1) != 1:
        D = D.contiguous()

    if initial_states is not None:
        assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)

    # This function executes 5 sub-functions for computing mamba
    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
    #   which has a minimal implementation to understand the below operations
    # - as explained by the blog, mamba is a special case of causal attention
    # - the idea is to chunk the attention matrix and compute each
    #   submatrix separately using different optimizations.
    # - see the blog and paper for a visualization of the submatrices
    #   which we refer to in the comments below

    # 1. Compute chunked cumsum of A * dt
    # - here dt may go through a softplus activation
    # - from here on `dt` is the processed (nheads, nchunks, chunk_size) tensor
    dA_cumsum, dt = _chunk_cumsum_fwd(
        dt,
        A,
        chunk_size,
        cu_chunk_seqlens,
        dt_bias=dt_bias,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
    )

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    states = _chunk_state_fwd(
        B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
    )

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    # - for handling chunked prefill, this requires i) initial_states and
    #   ii) seq_idx to be all specified.
    # - When a new seq_idx is detected, we will stop passing the prev_state
    #   and switch accordingly to the init_state corresponding to the new seq_idx.
    states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum,  # (nheads, nchunks, chunk_size)
        cu_chunk_seqlens,
        initial_states=rearrange(initial_states, "... p n -> ... (p n)")
        if initial_states is not None
        else None,  # (batch, nheads, headdim*dstate)
        seq_idx=seq_idx,
        out_dtype=state_dtype if state_dtype is not None else C.dtype,
    )
    # Restore the separate (headdim, dstate) axes flattened for state passing.
    states = rearrange(states, "... (p n) -> ... p n", n=dstate)

    # 4. Compute batched matrix multiply for C_j^T B_i terms
    CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)

    # 5. Scan and compute the diagonal blocks, taking into
    #    account past causal states.
    # - if initial states are provided, then states information will be
    #   augmented with initial_states.
    # - to do this properly, we need to account for example changes in
    #   the continuous batch, therefore we introduce pseudo chunks, which is
    #   a chunk that is split up each time an example changes.
    # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
    #   a seq_idx change, in which case we take states information from
    #   init_states.
    _chunk_scan_fwd(
        CB,
        x,
        dt,
        dA_cumsum,
        C,
        states,
        cu_chunk_seqlens,
        out,  # in-place update
        seq_idx,
        D=D,
        z=z,
        initial_states=initial_states,
    )

    if return_intermediate_states:
        return states
    else:
        return states[last_chunk_indices]
def mamba_chunk_scan_combined_varlen(
    x,
    dt,
    A,
    B,
    C,
    chunk_size,
    cu_seqlens,
    cu_chunk_seqlens,
    last_chunk_indices,
    seq_idx,
    out,
    D=None,
    z=None,
    dt_bias=None,
    initial_states=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_intermediate_states=False,
    state_dtype=None,
):
    """
    Public entry point for the chunked Mamba scan on packed varlen input.

    Checks that the varlen metadata is present and forwards everything to
    `_mamba_chunk_scan_combined_fwd`; `out` is filled in place.

    Argument:
        x: (seqlen, nheads, headdim)
        dt: (seqlen, nheads)
        A: (nheads)
        B: (seqlen, ngroups, dstate)
        C: (seqlen, ngroups, dstate)
        chunk_size: int
        cu_seqlens: (batch + 1,)
        cu_chunk_seqlens: (nchunks + 1,)
        last_chunk_indices: (batch,)
        seq_idx: (nchunks,)
        out: (seqlen, nheads, headdim) preallocated output tensor
        D: (nheads, headdim) or (nheads,)
        z: (seqlen, nheads, headdim)
        dt_bias: (nheads,)
        initial_states: (batch, nheads, headdim, dstate)
        dt_softplus: Whether to apply softplus to dt
        dt_limit: (min, max) clamp applied to dt
        return_intermediate_states: return the states of all chunks instead
            of only the ones selected by last_chunk_indices
        state_dtype: The data type of the ssm state
    Return:
        varlen_states: (batch, nheads, headdim, dstate), or
            (nchunks, nheads, headdim, dstate) when
            return_intermediate_states is set
    """
    assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
    assert seq_idx is not None

    # Everything is passed by keyword so the differing parameter order of the
    # two functions cannot cause a mix-up.
    return _mamba_chunk_scan_combined_fwd(
        x=x,
        dt=dt,
        A=A,
        B=B,
        C=C,
        chunk_size=chunk_size,
        out=out,
        D=D,
        z=z,
        dt_bias=dt_bias,
        initial_states=initial_states,
        return_intermediate_states=return_intermediate_states,
        seq_idx=seq_idx,
        cu_seqlens=cu_seqlens,
        cu_chunk_seqlens=cu_chunk_seqlens,
        last_chunk_indices=last_chunk_indices,
        dt_softplus=dt_softplus,
        dt_limit=dt_limit,
        state_dtype=state_dtype,
    )

View File

@@ -0,0 +1,157 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Copyright (c) 2024, Tri Dao, Albert Gu.
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
# ruff: noqa: E501
import torch
from vllm.triton_utils import tl, triton
# Autotune the number of state elements handled per program (BLOCK_SIZE); the
# selected config is keyed on dim.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 64}),
        triton.Config({"BLOCK_SIZE": 128}),
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
    ],
    key=["dim"],
)
@triton.jit
def _state_passing_fwd_kernel(
    # Pointers to matrices
    states_ptr,
    out_ptr,
    dA_cs_ptr,
    initstates_ptr,
    seq_idx_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    dim: tl.constexpr,
    nchunks,
    seqlen,
    chunk_size: tl.constexpr,
    # Strides
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_dim: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_out_dim: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_initstates_batch: tl.int64,
    stride_initstates_head: tl.int64,
    stride_initstates_dim: tl.constexpr,
    stride_seq_idx_chunk: tl.constexpr,
    # Meta-parameters
    HAS_INITSTATES: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Sequential recurrence over chunks for one head and one block of the
    # flattened state dimension:
    #   running = exp(dA_cs at the chunk's last slot) * running + states[c]
    # with `running` written to out[c] after every chunk. Whenever seq_idx
    # changes between consecutive chunks, `running` is first reset to that
    # sequence's initial state (or to zeros without initstates).
    #
    # Launch layout: axis 0 is the block over the state dimension, axis 1 the head.
    # `cu_chunk_seqlens_ptr` and `seqlen` are accepted but not referenced in
    # the kernel body.
    pid_h = tl.program_id(axis=1)
    pid_m = tl.program_id(axis=0)

    states_ptr += pid_h * stride_states_head
    # Point at the last slot (chunk_size - 1) of the first chunk's cumulative
    # sum; the loop below advances this pointer one chunk at a time.
    dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
    out_ptr += pid_h * stride_out_head

    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    states_ptrs = states_ptr + offs_m * stride_states_dim
    out_ptrs = out_ptr + offs_m * stride_out_dim

    if HAS_INITSTATES:
        # No batch offset is added here, so the recurrence starts from the
        # initial state at batch index 0 (matching prev_seq_idx = 0 below).
        initstates_ptrs = (
            initstates_ptr
            + pid_h * stride_initstates_head
            + offs_m * stride_initstates_dim
        )

        states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
    else:
        states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

    prev_seq_idx = 0
    for c in range(nchunks):
        new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
        dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
        seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
        # we have started a new sequence
        if prev_seq_idx != seq_idx:
            if HAS_INITSTATES:
                # Restart from the initial state of the new sequence.
                initstates_ptrs = (
                    initstates_ptr
                    + seq_idx * stride_initstates_batch
                    + pid_h * stride_initstates_head
                    + offs_m * stride_initstates_dim
                )
                states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
                    tl.float32
                )
            else:
                states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)

        prev_seq_idx = seq_idx
        # Decay the carried state across this chunk and add the chunk's own
        # contribution.
        states = tl.exp(dA_cs) * states + new_states
        tl.store(out_ptrs, states, mask=offs_m < dim)

        # Advance to the next chunk.
        states_ptrs += stride_states_chunk
        dA_cs_ptr += stride_dA_cs_chunk
        out_ptrs += stride_out_chunk
def _state_passing_fwd(
    states,
    dA_cumsum,
    cu_chunk_seqlens,
    seq_idx,
    initial_states=None,
    out_dtype=None,
):
    """Launch `_state_passing_fwd_kernel` to propagate states across chunks.

    Args:
        states: (nchunks, nheads, dim) per-chunk states.
        dA_cumsum: (nheads, nchunks, chunk_size) tensor.
        cu_chunk_seqlens: chunk boundaries, forwarded to the kernel.
        seq_idx: required tensor with one sequence index per chunk; the kernel
            reads one entry for each of the ``nchunks`` chunks.
        initial_states: optional 3-D tensor indexed as (batch, head, dim).
        out_dtype: dtype of the result; defaults to ``states.dtype``.

    Returns:
        A new (nchunks, nheads, dim) tensor on ``states``'s device.
    """
    # seq_idx is dereferenced unconditionally below, so it cannot be None;
    # fail with a clear message instead of an AttributeError.
    assert seq_idx is not None, "seq_idx must be provided"

    nchunks, nheads, dim = states.shape
    chunk_size = dA_cumsum.shape[-1]
    assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
    # Length of seq_idx's last axis; forwarded to the kernel's `seqlen` argument.
    seqlen = seq_idx.shape[-1]
    out_dtype = states.dtype if out_dtype is None else out_dtype
    out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)

    # The kernel always takes three init-state strides; pass zeros when there
    # is no initial_states tensor to read them from.
    initial_states_strides = (
        (initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
        if initial_states is not None
        else (0, 0, 0)
    )

    # One program per (block of the state dimension, head).
    grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
    with torch.cuda.device(states.device.index):
        _state_passing_fwd_kernel[grid](
            states_ptr=states,
            out_ptr=out,
            dA_cs_ptr=dA_cumsum,
            initstates_ptr=initial_states,
            seq_idx_ptr=seq_idx,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            dim=dim,
            nchunks=nchunks,
            seqlen=seqlen,
            chunk_size=chunk_size,
            stride_states_chunk=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_dim=states.stride(2),
            stride_out_chunk=out.stride(0),
            stride_out_head=out.stride(1),
            stride_out_dim=out.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_initstates_batch=initial_states_strides[0],
            stride_initstates_head=initial_states_strides[1],
            stride_initstates_dim=initial_states_strides[2],
            stride_seq_idx_chunk=seq_idx.stride(0),
            HAS_INITSTATES=initial_states is not None,
        )

    return out

View File

@@ -0,0 +1,264 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator,
MambaStateShapeCalculator,
)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
@CustomOp.register("short_conv")
class ShortConv(MambaBase, CustomOp):
    """
    Gated short causal convolution layer with a cached convolution state.

    The input is projected into three equal chunks ``B``, ``C`` and ``x``.
    A causal 1-D convolution is applied to ``B * x``, the result is gated
    by ``C`` and then passed through the output projection.
    """

    def __init__(
        self,
        config,
        dim: int,
        layer_idx: int,
        model_config: ModelConfig | None = None,
        cache_config: CacheConfig | None = None,
        prefix: str = "",
    ):
        """
        Build the projections and register the layer under `prefix`.

        Args:
            config: Object providing `conv_L_cache` and `conv_bias`.
            dim: Size used for the input, output and convolution dimensions.
            layer_idx: Index of this layer; stored as-is.
            model_config: Optional model config, read by `get_state_dtype`.
            cache_config: Optional cache config, read by `get_state_dtype`.
            prefix: Layer name; used for sub-layer prefixes and as the key
                in the static forward context.

        Raises:
            ValueError: If `prefix` is already present in the static
                forward context.
        """
        super().__init__()

        # Plain (non-module) attributes.
        self.config = config
        self.layer_idx = layer_idx
        self.conv_dim = dim
        self.L_cache = config.conv_L_cache
        self.bias = config.conv_bias
        self.model_config = model_config
        self.cache_config = cache_config
        self.prefix = prefix

        # Sub-modules are created in the order conv -> in_proj -> out_proj.
        self.conv = ColumnParallelLinear(
            input_size=self.L_cache,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.conv1d",
        )
        # Insert a singleton axis so the linear weight takes the shape of a
        # conv1d weight. This cannot live in `weight_loader`: one already
        # exists on `ColumnParallelLinear` and `set_weight_attrs` does not
        # allow overriding it.
        conv_weight = self.conv.weight
        conv_weight.data = conv_weight.data.unsqueeze(1)

        self.in_proj = MergedColumnParallelLinear(
            input_size=dim,
            output_sizes=[dim] * 3,
            bias=self.bias,
            prefix=f"{prefix}.in_proj",
        )
        self.out_proj = RowParallelLinear(
            input_size=dim,
            output_size=dim,
            bias=self.bias,
            prefix=f"{prefix}.out_proj",
        )

        # Make the layer discoverable by name; names must be unique.
        static_context = (
            get_current_vllm_config().compilation_config.static_forward_context
        )
        if prefix in static_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_context[prefix] = self

        # Starts out as a one-element tuple holding an empty tensor.
        # NOTE(review): presumably replaced with the real state cache by the
        # engine before `forward_cuda` runs - confirm against the caller.
        self.kv_cache = (torch.tensor([]),)

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """No-op: returns None without touching either tensor."""
        return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """Dispatch to the `short_conv` custom op, identified by `self.prefix`."""
        torch.ops.vllm.short_conv(hidden_states, output, self.prefix)

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
    ):
        """
        Run the gated short convolution for the current forward context.

        With attention metadata present, the result is written into
        `output[:num_actual_tokens]` and nothing is returned. Without it
        (V1 profile run), the convolution kernels are skipped, the projected
        tensor is returned and `output` is left untouched.

        Args:
            hidden_states: Input fed to `in_proj`.
            output: Tensor that receives the result in place.
        """
        forward_context = get_forward_context()
        # ShortConvAttentionMetadata contains metadata necessary for the
        # short_conv triton kernels to operate in continuous batching and in
        # chunked prefill modes; it is computed at top-level model forward
        # since it stays the same and is reused for all mamba layers in the
        # same iteration.
        attn_metadata: AttentionMetadata = forward_context.attn_metadata
        if attn_metadata is not None:
            # Metadata is keyed by layer name; pick this layer's entry.
            assert isinstance(attn_metadata, dict)
            attn_metadata = attn_metadata[self.prefix]
            assert isinstance(attn_metadata, ShortConvAttentionMetadata)
            layer_cache = self.kv_cache[forward_context.virtual_engine]
            conv_state = layer_cache[0].transpose(-1, -2)
            state_indices_tensor = attn_metadata.state_indices_tensor
            has_initial_states_p = attn_metadata.has_initial_states_p

        projected, _ = self.in_proj(hidden_states)
        B, C, x = projected.chunk(3, dim=-1)

        # Drop the singleton axis added in `__init__`.
        weight = self.conv.weight
        conv_weights = weight.view(weight.size(0), weight.size(2))

        if attn_metadata is None:
            # V1 profile run
            gated = C * (B * x).contiguous()
            profile_out, _ = self.out_proj(gated)
            return profile_out

        num_prefills = attn_metadata.num_prefills  # request count
        num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        num_actual_tokens = num_decodes + num_prefill_tokens

        # NOTE: V1 puts decode before prefill
        # Separate prefill and decode by splitting the varlen input along
        # the token dimension.
        token_sizes = [num_decodes, num_prefill_tokens]
        B_d, B_p = torch.split(B[:num_actual_tokens], token_sizes, dim=0)
        C_d, C_p = torch.split(C[:num_actual_tokens], token_sizes, dim=0)
        x_d, x_p = torch.split(x[:num_actual_tokens], token_sizes, dim=0)
        # State indices are split along the batch dimension instead.
        state_indices_tensor_d, state_indices_tensor_p = torch.split(
            state_indices_tensor,
            [num_decodes, num_prefills],
            dim=0,
        )

        prefill_out = None
        if num_prefills > 0:
            # Keep only the prefill entries and shift them by the number of
            # decode tokens that precede the prefill tokens.
            query_start_loc_p = (
                attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
            )
            conv_p = causal_conv1d_fn(
                (B_p * x_p).transpose(0, 1),
                conv_weights,
                self.conv.bias,
                activation=None,
                conv_states=conv_state,
                has_initial_state=has_initial_states_p,
                cache_indices=state_indices_tensor_p,
                metadata=attn_metadata,
                query_start_loc=query_start_loc_p,
            ).transpose(0, 1)[:num_prefill_tokens]
            prefill_out = C_p * conv_p

        decode_out = None
        if num_decodes > 0:
            conv_d = causal_conv1d_update(
                (B_d * x_d).contiguous(),
                conv_state,
                conv_weights,
                self.conv.bias,
                activation=None,
                conv_state_indices=state_indices_tensor_d,
            )
            decode_out = C_d * conv_d

        # Merge the outputs with decode first, matching the input layout.
        merged = torch.vstack(
            [part for part in (decode_out, prefill_out) if part is not None]
        )

        # Final linear projection
        output[:num_actual_tokens], _ = self.out_proj(merged)

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        """Return the state dtype(s) derived from the model and cache configs.

        Both `model_config` and `cache_config` must have been provided.
        """
        model_config = self.model_config
        cache_config = self.cache_config
        assert model_config is not None
        assert cache_config is not None
        return MambaStateDtypeCalculator.short_conv_state_dtype(
            model_config.dtype,
            cache_config.mamba_cache_dtype,
        )

    def get_state_shape(self) -> tuple[tuple[int, ...]]:
        """Return the state shape(s) for the current tensor-parallel world size."""
        tp_size = get_tensor_model_parallel_world_size()
        return MambaStateShapeCalculator.short_conv_state_shape(
            tp_world_size=tp_size,
            intermediate_size=self.conv_dim,
            conv_kernel=self.L_cache,
        )

    @property
    def mamba_type(self) -> str:
        """Identifier string of this layer type."""
        return "short_conv"

    def get_attn_backend(self) -> type["AttentionBackend"]:
        """Return the attention backend class used by this layer."""
        # Imported at call time rather than at module import time.
        from vllm.v1.attention.backends.short_conv_attn import (
            ShortConvAttentionBackend as backend_cls,
        )

        return backend_cls
def short_conv(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """
    Run the layer registered under `layer_name` on the given tensors.

    The layer is looked up in the current forward context's
    `no_compile_layers` mapping and its `forward_cuda` is invoked; any value
    returned by `forward_cuda` is discarded.
    """
    layer = get_forward_context().no_compile_layers[layer_name]
    layer.forward_cuda(hidden_states=hidden_states, output=output)
def short_conv_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake counterpart of `short_conv`: does no work and returns None."""
    return None
# Register `short_conv` as a custom op (it is invoked through
# `torch.ops.vllm.short_conv` in `ShortConv.forward`). The op is declared as
# mutating its `output` argument, and `short_conv_fake` is supplied as the
# fake implementation.
direct_register_custom_op(
    fake_impl=short_conv_fake,
    mutates_args=["output"],
    op_func=short_conv,
    op_name="short_conv",
)