Sync from v0.13
This commit is contained in:
0
vllm/model_executor/layers/mamba/__init__.py
Normal file
0
vllm/model_executor/layers/mamba/__init__.py
Normal file
68
vllm/model_executor/layers/mamba/abstract.py
Normal file
68
vllm/model_executor/layers/mamba/abstract.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import get_mamba_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
|
||||
|
||||
|
||||
class MambaBase(AttentionLayerBase):
|
||||
"""
|
||||
Base class for Mamba-like layers which support the v1 engine.
|
||||
Inherit from this class if you implement a custom layer.
|
||||
"""
|
||||
|
||||
# Contains the KV cache (mamba state) for the layer
|
||||
# in the shape specified by `self.get_state_shape`.
|
||||
kv_cache: tuple[torch.Tensor, ...]
|
||||
|
||||
@abstractmethod
|
||||
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
|
||||
"""
|
||||
Defines the shape of the state.
|
||||
For mamba layers this is usually a (conv_state, ssm_state) tuple.
|
||||
In this case, returns (conv_state_shape, ssm_state_shape).
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def mamba_type(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
|
||||
pass
|
||||
|
||||
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
|
||||
if (
|
||||
vllm_config.speculative_config is not None
|
||||
and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
|
||||
):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet."
|
||||
)
|
||||
mamba_block_size = vllm_config.cache_config.mamba_block_size
|
||||
page_size_padded = vllm_config.cache_config.mamba_page_size_padded
|
||||
return MambaSpec(
|
||||
shapes=self.get_state_shape(),
|
||||
dtypes=self.get_state_dtype(),
|
||||
block_size=mamba_block_size,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=self.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
vllm_config.speculative_config.num_speculative_tokens
|
||||
if vllm_config.speculative_config
|
||||
else 0
|
||||
),
|
||||
)
|
||||
|
||||
def get_attn_backend(self) -> type[AttentionBackend]:
|
||||
"""Get the attention backend class for this Mamba layer."""
|
||||
return get_mamba_attn_backend(self.mamba_type)
|
||||
388
vllm/model_executor/layers/mamba/linear_attn.py
Normal file
388
vllm/model_executor/layers/mamba/linear_attn.py
Normal file
@@ -0,0 +1,388 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
lightning_attention,
|
||||
linear_decode_forward_triton,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||
|
||||
|
||||
class MiniMaxText01RMSNormTP(CustomOp):
|
||||
name = "MiniMaxText01RMSNormTP"
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.tp_world = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.weight = nn.Parameter(torch.ones(int(hidden_size / self.tp_world)))
|
||||
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
self.variance_epsilon = eps
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def weight_loader(
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> None:
|
||||
tp_world = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = loaded_weight.shape[0] // tp_world
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
return
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
|
||||
if self.tp_world > 1:
|
||||
variance = tensor_model_parallel_all_reduce(variance) / self.tp_world
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = (x * self.weight).to(orig_dtype)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class MiniMaxText01LinearKernel:
|
||||
@staticmethod
|
||||
def jit_linear_forward_prefix(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
block_size: int,
|
||||
layer_idx: int | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
slope_rate = slope_rate.to(torch.float32)
|
||||
should_pad_dim = q.dim() == 3
|
||||
if should_pad_dim:
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
b, h, n, d = q.shape
|
||||
e = d
|
||||
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
|
||||
output, kv_history = lightning_attention(
|
||||
q, k, v, slope_rate, block_size=block_size, kv_history=kv_history
|
||||
)
|
||||
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
|
||||
assert output.shape[0] == 1, "batch size must be 1"
|
||||
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
||||
|
||||
|
||||
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "linear_attention"
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.linear_attention_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
|
||||
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||
num_heads=self.num_heads, tp_size=self.tp_size, head_dim=self.head_dim
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
hidden_inner_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
max_position: int,
|
||||
block_size: int,
|
||||
num_hidden_layer: int,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
layer_idx: int = 0,
|
||||
linear_layer_idx: int = 0,
|
||||
prefix: str = "linear_attn",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.BLOCK = block_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.total_num_heads = num_heads
|
||||
self.hidden_inner_size = hidden_inner_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
assert self.total_num_heads % self.tp_size == 0
|
||||
self.tp_heads = self.total_num_heads // self.tp_size
|
||||
self.qkv_size = self.num_heads * self.head_dim
|
||||
self.tp_hidden = self.head_dim * self.tp_heads
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.hidden_inner_size * 3,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.output_gate = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.hidden_inner_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output_gate",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.hidden_inner_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.norm = MiniMaxText01RMSNormTP(
|
||||
self.hidden_inner_size,
|
||||
eps=1e-5,
|
||||
)
|
||||
|
||||
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(self.num_heads)
|
||||
if num_hidden_layer <= 1:
|
||||
self.slope_rate = slope_rate * (1 + 1e-5)
|
||||
else:
|
||||
self.slope_rate = slope_rate * (
|
||||
1 - layer_idx / (num_hidden_layer - 1) + 1e-5
|
||||
)
|
||||
self.tp_slope = self.slope_rate[
|
||||
self.tp_rank * self.tp_heads : (self.tp_rank + 1) * self.tp_heads
|
||||
].contiguous()
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
@staticmethod
|
||||
def weight_direct_load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _build_slope_tensor(n_attention_heads: int):
|
||||
def get_slopes(n):
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(n))
|
||||
return (
|
||||
get_slopes_power_of_2(closest_power_of_2)
|
||||
+ get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
|
||||
)
|
||||
|
||||
slopes = torch.tensor(
|
||||
get_slopes(n_attention_heads), dtype=torch.float32
|
||||
).reshape(n_attention_heads, 1, 1)
|
||||
return slopes
|
||||
|
||||
def _prefill_and_mix_infer(
|
||||
self, q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
):
|
||||
hidden = []
|
||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
|
||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||
qs,
|
||||
ks,
|
||||
vs,
|
||||
slice_layer_cache,
|
||||
self.tp_slope,
|
||||
self.BLOCK,
|
||||
layer_idx=self.layer_idx,
|
||||
)
|
||||
hidden.append(out_slice.contiguous())
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden_decode = self._decode_infer(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
|
||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||
return hidden
|
||||
|
||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor, attn_metadata):
|
||||
q = q[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[: attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[: attn_metadata.num_decodes]
|
||||
hidden = linear_decode_forward_triton(
|
||||
q, k, v, kv_cache, self.tp_slope, slot_id, 32
|
||||
)
|
||||
return hidden
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
|
||||
) -> None:
|
||||
torch.ops.vllm.linear_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
positions,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def _forward(
|
||||
self, hidden_states: torch.Tensor, output: torch.Tensor, positions: torch.Tensor
|
||||
) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
num_actual_tokens = (
|
||||
attn_metadata.num_prefill_tokens + attn_metadata.num_decode_tokens
|
||||
)
|
||||
else:
|
||||
num_actual_tokens = hidden_states.shape[0]
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
|
||||
qkv32 = qkv.to(torch.float32)
|
||||
qkvact = torch.nn.functional.silu(qkv32)
|
||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens", 0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx
|
||||
]
|
||||
q_end = attn_metadata.query_start_loc[
|
||||
num_decode_tokens + prefill_idx + 1
|
||||
]
|
||||
query_len = q_end - q_start
|
||||
context_len = (
|
||||
attn_metadata.seq_lens[num_decode_tokens + prefill_idx]
|
||||
- query_len
|
||||
)
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[
|
||||
num_decode_tokens + prefill_idx
|
||||
]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
|
||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||
if attn_metadata is None:
|
||||
hidden = torch.empty(
|
||||
(q.shape[0], q.shape[1] * q.shape[2]), device=q.device, dtype=q.dtype
|
||||
)
|
||||
else:
|
||||
if not decode_only:
|
||||
hidden = self._prefill_and_mix_infer(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
else:
|
||||
hidden = self._decode_infer(
|
||||
q, k, v, kv_cache, state_indices_tensor, attn_metadata
|
||||
)
|
||||
hidden = self.norm._forward(hidden)
|
||||
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
|
||||
hidden = F.sigmoid(gate) * hidden
|
||||
hidden = hidden.to(hidden_states.dtype)
|
||||
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden)
|
||||
|
||||
|
||||
def linear_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self._forward(hidden_states=hidden_states, output=output, positions=positions)
|
||||
|
||||
|
||||
def linear_attention_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="linear_attention",
|
||||
op_func=linear_attention,
|
||||
mutates_args=["output"],
|
||||
fake_impl=linear_attention_fake,
|
||||
)
|
||||
526
vllm/model_executor/layers/mamba/mamba_mixer.py
Normal file
526
vllm/model_executor/layers/mamba/mamba_mixer.py
Normal file
@@ -0,0 +1,526 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn,
|
||||
selective_state_update,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
@CustomOp.register("mamba_mixer")
|
||||
class MambaMixer(MambaBase, CustomOp):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||
for why A isn't selective) ∆, B, C are input-dependent
|
||||
(this is a key difference between Mamba and the linear time
|
||||
invariant S4, and is why Mamba is called
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
time_step_rank: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
use_rms_norm: bool,
|
||||
rms_norm_has_weight: bool = True,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu",
|
||||
is_lora_enabled: bool = False,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.time_step_rank = time_step_rank
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.use_rms_norm = use_rms_norm
|
||||
self.activation = activation
|
||||
self.is_lora_enabled = is_lora_enabled
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=intermediate_size,
|
||||
bias=use_conv_bias,
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2, bias=use_bias
|
||||
)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
self.x_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
time_step_rank + ssm_state_size * 2,
|
||||
bias=False,
|
||||
)
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(
|
||||
time_step_rank, intermediate_size, bias=True, skip_bias_add=True
|
||||
)
|
||||
|
||||
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
param.data.copy_(
|
||||
loaded_weight.data.split(loaded_weight.shape[0] // tp_size, dim=0)[
|
||||
tp_rank
|
||||
]
|
||||
)
|
||||
|
||||
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
weight_loader(param, -torch.exp(loaded_weight.float()))
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
intermediate_size // tp_size,
|
||||
ssm_state_size,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
self.dt_layernorm = (
|
||||
RMSNorm(
|
||||
time_step_rank,
|
||||
eps=rms_norm_eps,
|
||||
has_weight=rms_norm_has_weight,
|
||||
)
|
||||
if use_rms_norm
|
||||
else None
|
||||
)
|
||||
|
||||
self.b_layernorm = (
|
||||
RMSNorm(
|
||||
ssm_state_size,
|
||||
eps=rms_norm_eps,
|
||||
has_weight=rms_norm_has_weight,
|
||||
)
|
||||
if use_rms_norm
|
||||
else None
|
||||
)
|
||||
|
||||
self.c_layernorm = (
|
||||
RMSNorm(
|
||||
ssm_state_size,
|
||||
eps=rms_norm_eps,
|
||||
has_weight=rms_norm_has_weight,
|
||||
)
|
||||
if use_rms_norm
|
||||
else None
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
def _ssm_transform(
|
||||
self, x: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.is_lora_enabled:
|
||||
# Lora kernel requires contiguous tensor.
|
||||
ssm_params = self.x_proj(x.contiguous())[0]
|
||||
else:
|
||||
ssm_params = self.x_proj(x)[0]
|
||||
time_step, B, C = torch.split(
|
||||
ssm_params,
|
||||
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
||||
dim=-1,
|
||||
)
|
||||
if self.use_rms_norm:
|
||||
assert self.dt_layernorm is not None
|
||||
assert self.b_layernorm is not None
|
||||
assert self.c_layernorm is not None
|
||||
time_step = self.dt_layernorm(time_step.contiguous())
|
||||
B = self.b_layernorm(B.contiguous())
|
||||
C = self.c_layernorm(C.contiguous())
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
return discrete_time_step, B, C
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
torch.ops.vllm.mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_native(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
pass
|
||||
|
||||
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
"""
|
||||
Run the Mamba-1 SSM pipeline.
|
||||
|
||||
Steps
|
||||
-----
|
||||
1. Apply the gated-MLP linear projection to the raw input.
|
||||
2. Pass the projected sequence through the convolutional mixing layer.
|
||||
3. Feed the result into the State-Space Model (SSM) blocks.
|
||||
4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||
to produce contextual representations.
|
||||
5. Project the contextualised sequence back
|
||||
to the output embedding dimension.
|
||||
|
||||
Batch handling
|
||||
--------------
|
||||
Prefill and decode tokens are processed by dedicated CUDA
|
||||
kernels for both the convolutional (conv1d) and SSM stages.
|
||||
In the case of a mixed batch (containing both prefill and
|
||||
decode tokens), both sets of kernels are executed independently
|
||||
and their outputs are concatenated before the final output projection.
|
||||
"""
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
|
||||
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
hidden_states_BC = hidden_states_BC.contiguous()
|
||||
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
has_prefill = num_prefill_tokens > 0
|
||||
has_decode = num_decode_tokens > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decode_tokens
|
||||
|
||||
prefill_decode_split = split_batch_to_prefill_and_decode(
|
||||
hidden_states_BC,
|
||||
gate,
|
||||
state_indices_tensor,
|
||||
num_prefill_tokens,
|
||||
num_prefills,
|
||||
num_decode_tokens,
|
||||
)
|
||||
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
|
||||
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
|
||||
gate_p = prefill_decode_split.gate_p
|
||||
gate_d = prefill_decode_split.gate_d
|
||||
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
|
||||
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
|
||||
|
||||
if prefix_caching_enabled:
|
||||
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||
torch.split(
|
||||
attn_metadata.block_idx_last_computed_token,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
|
||||
torch.split(
|
||||
attn_metadata.block_idx_last_scheduled_token,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
|
||||
block_idx_first_scheduled_token_p = (
|
||||
attn_metadata.block_idx_first_scheduled_token_p
|
||||
)
|
||||
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
|
||||
else:
|
||||
block_idx_last_computed_token_d = None
|
||||
block_idx_last_computed_token_p = None
|
||||
block_idx_last_scheduled_token_d = None
|
||||
block_idx_last_scheduled_token_p = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
num_computed_tokens_p = None
|
||||
|
||||
ssm_outputs = []
|
||||
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
conv_out_p = causal_conv1d_fn(
|
||||
hidden_states_BC_p,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
query_start_loc=query_start_loc_p,
|
||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||
initial_state_idx=block_idx_last_computed_token_p,
|
||||
num_computed_tokens=num_computed_tokens_p,
|
||||
block_size_to_align=mamba_block_size,
|
||||
)
|
||||
# 3. State Space Model sequence transformations.
|
||||
discrete_time_step_p, B_p, C_p = self._ssm_transform(
|
||||
conv_out_p.transpose(-2, -1)
|
||||
)
|
||||
time_proj_bias = self._time_proj_bias()
|
||||
|
||||
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||
scan_out_p = selective_scan_fn(
|
||||
conv_out_p,
|
||||
ssm_state,
|
||||
discrete_time_step_p,
|
||||
self.A,
|
||||
B_p.transpose(-2, -1),
|
||||
C_p.transpose(-2, -1),
|
||||
self.D.float(),
|
||||
gate_p,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
has_initial_state=has_initial_states_p,
|
||||
query_start_loc=query_start_loc_p,
|
||||
block_size=mamba_block_size,
|
||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||
initial_state_idx=block_idx_last_computed_token_p,
|
||||
)
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
state_indices_tensor_d_output = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_scheduled_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
else:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d
|
||||
state_indices_tensor_d_output = state_indices_tensor_d
|
||||
# 2. Convolution sequence transformation
|
||||
conv_out_d = causal_conv1d_update(
|
||||
hidden_states_BC_d.transpose(0, 1),
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
|
||||
initial_state_idx=block_idx_last_computed_token_d,
|
||||
).transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation.
|
||||
discrete_time_step_d, B_d, C_d = self._ssm_transform(
|
||||
conv_out_d.transpose(-2, -1)
|
||||
)
|
||||
time_proj_bias = self._time_proj_bias()
|
||||
|
||||
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||
scan_outputs_d = torch.empty_like(hidden_states_BC_d.transpose(0, 1))
|
||||
selective_state_update(
|
||||
ssm_state,
|
||||
conv_out_d.transpose(0, 1),
|
||||
discrete_time_step_d.transpose(0, 1),
|
||||
self.A,
|
||||
B_d,
|
||||
C_d,
|
||||
self.D,
|
||||
gate_d.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=scan_outputs_d,
|
||||
)
|
||||
scan_outputs_d = scan_outputs_d.transpose(0, 1)
|
||||
|
||||
ssm_outputs.insert(0, scan_outputs_d)
|
||||
|
||||
scan_outputs_combined = (
|
||||
ssm_outputs[0] if len(ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
|
||||
)
|
||||
|
||||
# 5. Final output projection
|
||||
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
|
||||
scan_outputs_combined = scan_outputs_combined.transpose(-2, -1).contiguous()
|
||||
out = self.out_proj(scan_outputs_combined)[0]
|
||||
else:
|
||||
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
|
||||
|
||||
output[:num_actual_tokens] = out
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.mamba1_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
self.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.mamba1_state_shape(
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
intermediate_size=self.intermediate_size,
|
||||
state_size=self.ssm_state_size,
|
||||
conv_kernel=self.conv_kernel_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "mamba1"
|
||||
|
||||
def _time_proj_bias(self) -> torch.Tensor | None:
|
||||
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
|
||||
return self.dt_proj.bias.float()
|
||||
return None
|
||||
|
||||
|
||||
class PrefillDecodeSplit(NamedTuple):
|
||||
hidden_states_BC_p: torch.Tensor
|
||||
hidden_states_BC_d: torch.Tensor
|
||||
gate_p: torch.Tensor
|
||||
gate_d: torch.Tensor
|
||||
state_indices_tensor_p: torch.Tensor
|
||||
state_indices_tensor_d: torch.Tensor
|
||||
|
||||
|
||||
def split_batch_to_prefill_and_decode(
|
||||
hidden_states_BC: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
num_prefill_tokens: int,
|
||||
num_prefills: int,
|
||||
num_decode_tokens: int,
|
||||
) -> PrefillDecodeSplit:
|
||||
num_actual_tokens = num_prefill_tokens + num_decode_tokens
|
||||
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_decode_tokens, num_prefill_tokens],
|
||||
dim=-1,
|
||||
)
|
||||
gate_d, gate_p = torch.split(
|
||||
gate[..., :num_actual_tokens], [num_decode_tokens, num_prefill_tokens], dim=-1
|
||||
)
|
||||
|
||||
# num_decode_tokens accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[: num_decode_tokens + num_prefills],
|
||||
[num_decode_tokens, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
return PrefillDecodeSplit(
|
||||
hidden_states_BC_p=hidden_states_BC_p,
|
||||
hidden_states_BC_d=hidden_states_BC_d,
|
||||
gate_p=gate_p,
|
||||
gate_d=gate_d,
|
||||
state_indices_tensor_p=state_indices_tensor_p,
|
||||
state_indices_tensor_d=state_indices_tensor_d,
|
||||
)
|
||||
|
||||
|
||||
def mamba_mixer(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def mamba_mixer_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mamba_mixer",
|
||||
op_func=mamba_mixer,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer_fake,
|
||||
)
|
||||
930
vllm/model_executor/layers/mamba/mamba_mixer2.py
Normal file
930
vllm/model_executor/layers/mamba/mamba_mixer2.py
Normal file
@@ -0,0 +1,930 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import selective_state_update
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined_varlen,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
LoaderFunction,
|
||||
composed_weight_loader,
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
|
||||
@CustomOp.register("mixer2_gated_rms_norm")
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.full_hidden_size = full_hidden_size
|
||||
self.group_size = full_hidden_size // full_n_groups
|
||||
self.per_rank_hidden_size = full_hidden_size // self.tp_size
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert self.full_hidden_size % self.tp_size == 0, (
|
||||
"Tensor parallel world size must divide hidden size."
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
):
|
||||
# Three tensor-parallel cases:
|
||||
# 1. n_groups is 1
|
||||
# In this case we parallelize along the reduction dim.
|
||||
# Each rank computes a local sum of squares followed by AllReduce
|
||||
# 2. tp_size divides n_groups
|
||||
# Each rank only reduces within its local group(s).
|
||||
# No collective ops necessary.
|
||||
# 3. The general case can be pretty complicated so we AllGather
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x.to(input_dtype)
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
# Compute local sum and then reduce to obtain global sum
|
||||
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
else:
|
||||
redundant_tp: bool = self.n_groups % self.tp_size != 0
|
||||
if redundant_tp:
|
||||
# To handle the general case, redundantly apply the variance
|
||||
x = tensor_model_parallel_all_gather(x, -1)
|
||||
|
||||
*prefix_dims, hidden_dim = x.shape
|
||||
group_count = hidden_dim // self.group_size
|
||||
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
|
||||
variance = x_grouped.pow(2).mean(-1, keepdim=True)
|
||||
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x_grouped.view(*prefix_dims, hidden_dim)
|
||||
|
||||
if redundant_tp:
|
||||
start = self.per_rank_hidden_size * self.tp_rank
|
||||
end = start + self.per_rank_hidden_size
|
||||
x = x[..., start:end]
|
||||
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
|
||||
|
||||
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
return rms_norm_gated(
|
||||
x,
|
||||
self.weight.data,
|
||||
bias=None,
|
||||
z=gate,
|
||||
eps=self.variance_epsilon,
|
||||
norm_before_gate=False,
|
||||
)
|
||||
|
||||
|
||||
def mamba_v2_sharded_weight_loader(
|
||||
shard_spec: list[tuple[int, int, float]],
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
) -> LoaderFunction:
|
||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||
are correctly sharded so that they can be split into x, B, C. It also
|
||||
ensures that all the groups corresponding to a head shard is placed
|
||||
together with it.
|
||||
"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
# - track boundary of (sharded) param, and loaded_weight, respectively
|
||||
boundary, loaded_boundary = 0, 0
|
||||
|
||||
# - iterate over the shard specs
|
||||
for full_dim, extra, duplicate_groups in shard_spec:
|
||||
# - full dim is the model dim (before TP).
|
||||
# - extra > 0, means there is expected overall increase
|
||||
# of dimensions. This is so because of replication.
|
||||
# - ratio is used map the tp_rank to the actual shard
|
||||
# rank. This is useful when there is replication of
|
||||
# groups to accompany head shards.
|
||||
|
||||
# - size of the loaded shard
|
||||
shard_size = full_dim // tp_size
|
||||
|
||||
# - compute the rank into the loaded shard.
|
||||
# - if there is replication, different TP shards will
|
||||
# take from the same rank.
|
||||
# NOTE: currently we only support duplication
|
||||
# in the case where num_groups == 1
|
||||
rank = 0 if duplicate_groups else tp_rank
|
||||
|
||||
# - leftmost boundary index into loaded weight.
|
||||
loaded_skip = rank * shard_size
|
||||
loaded_start_idx = loaded_boundary + loaded_skip
|
||||
|
||||
# - take these many dims from the loaded weight.
|
||||
take = min(shard_size, full_dim - extra - loaded_skip)
|
||||
|
||||
# - always shard on dim 0
|
||||
# - the ignore is for a mundane mypy error as it does not
|
||||
# seem to handle slices well.
|
||||
# https://github.com/python/mypy/issues/2410
|
||||
param.data[
|
||||
boundary : (boundary + take), ... # type: ignore[misc]
|
||||
] = loaded_weight[
|
||||
loaded_start_idx : (
|
||||
loaded_start_idx + take
|
||||
) # type: ignore[misc]
|
||||
] # type: ignore[misc]
|
||||
|
||||
# move indexing boundaries
|
||||
boundary += shard_size
|
||||
loaded_boundary += full_dim - extra
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
@CustomOp.register("mamba_mixer2")
|
||||
class MambaMixer2(MambaBase, CustomOp):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||
for why A isn't selective) ∆, B, C are input-dependent
|
||||
(this is a key difference between Mamba and the linear time
|
||||
invariant S4, and is why Mamba is called
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# For TP, the sharding plan is as follows:
|
||||
# - for the conv modules, since
|
||||
# conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
|
||||
# we shard intermediate_size and n_groups
|
||||
# - since intermediate_size = n_heads * head_dim, sharding on
|
||||
# intermediate_size is achieved by sharding on n_heads.
|
||||
# - IF, world_size divides groups, then sharding
|
||||
# (n_groups / world_size, n_heads / world_size)
|
||||
# also maintains the invariant n_heads % n_groups == 0
|
||||
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
|
||||
# to allocate extra space in the shard, such that groups
|
||||
# may be replicated to follow the head shard.
|
||||
# - NOTE: currently for the world size DOES NOT divide groups
|
||||
# case, we only support the case when n_groups == 1
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
assert num_heads % self.tp_size == 0, (
|
||||
"Tensor parallel world size must divide num heads."
|
||||
)
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
|
||||
"If tensor parallel world size does not divide num_groups, "
|
||||
"then num_groups must equal 1."
|
||||
)
|
||||
|
||||
assert (
|
||||
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
|
||||
), (
|
||||
"Tensor parallel currently supported for quantized models only "
|
||||
"if tensor parallel world size divides num groups."
|
||||
)
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.activation = activation
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.n_groups = n_groups
|
||||
if n_groups % self.tp_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size
|
||||
)
|
||||
self.n_groups = n_groups + groups
|
||||
|
||||
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
|
||||
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
|
||||
|
||||
if n_groups % self.tp_size == 0:
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
else:
|
||||
# This is the n_groups == 1 case,
|
||||
# where we need to duplicate groups if TP>1.
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
# - use the custom weight loader mamba_v2_sharded_weight_loader
|
||||
# for conv1d.bias, covn1d.weight and in_proj.weight
|
||||
# - need to set these settings, to assign the groups
|
||||
# to the head shards
|
||||
group_shard_settings = (
|
||||
self.groups_ssm_state_size, # expected model size
|
||||
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
|
||||
n_groups == 1, # if there was only one group
|
||||
)
|
||||
intermediate_settings = (intermediate_size, 0, False)
|
||||
head_settings = (self.num_heads, 0, False)
|
||||
|
||||
# - the weight already has a "weight_loader" attribute
|
||||
# which set_weight_attrs will raise if we do not
|
||||
# delete before trying to override it
|
||||
# - ditto for the other two weights below
|
||||
delattr(self.conv1d.bias, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.bias,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
# - quant layers do not have a weight loader
|
||||
delattr(self.in_proj.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.in_proj.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings, # for gate
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
head_settings, # for dt
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
|
||||
# and `set_weight_attrs` doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
self.register_buffer("conv_weights", conv_weights, persistent=False)
|
||||
|
||||
# - these are TPed by heads to reduce the size of the
|
||||
# temporal shape
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
divide(num_heads, self.tp_size),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.use_rms_norm = use_rms_norm
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||
a_weight_loader = composed_weight_loader(
|
||||
sharded_weight_loader(0), lambda x: -torch.exp(x.float())
|
||||
)
|
||||
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
||||
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.norm = Mixer2RMSNormGated(
|
||||
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
|
||||
)
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
# Pre-compute sizes for forward pass
|
||||
self.tped_intermediate_size = self.intermediate_size // self.tp_size
|
||||
self.tped_conv_size = self.conv_dim // self.tp_size
|
||||
self.tped_dt_size = self.num_heads // self.tp_size
|
||||
|
||||
self.split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.tped_intermediate_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
mup_vector: torch.Tensor | None = None,
|
||||
):
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
if mup_vector is not None:
|
||||
projected_states = projected_states * mup_vector
|
||||
|
||||
# 2. Prepare inputs for conv + SSM
|
||||
ssm_output = torch.empty(
|
||||
[
|
||||
hidden_states.shape[0],
|
||||
(self.num_heads // self.tp_size) * self.head_dim,
|
||||
],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# 3. conv + SSM
|
||||
# (split `projected_states` into hidden_states_B_C, dt in the custom op to
|
||||
# ensure it is not treated as an intermediate tensor by torch compile)
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
projected_states,
|
||||
ssm_output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
gate = projected_states[..., : self.tped_intermediate_size]
|
||||
hidden_states = self.norm(ssm_output, gate)
|
||||
|
||||
# 5. Final linear projection
|
||||
output, _ = self.out_proj(hidden_states)
|
||||
|
||||
return output
|
||||
|
||||
def conv_ssm_forward(
|
||||
self,
|
||||
projected_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
hidden_states_B_C, dt = torch.split(
|
||||
projected_states[..., self.tped_intermediate_size :],
|
||||
[self.tped_conv_size, self.tped_dt_size],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
forward_context = get_forward_context()
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
assert self.cache_config is not None
|
||||
mamba_block_size = self.cache_config.mamba_block_size
|
||||
prefix_caching_enabled = self.cache_config.enable_prefix_caching
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx_p
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
cu_chunk_seqlen_p = attn_metadata.cu_chunk_seqlen_p
|
||||
last_chunk_indices_p = attn_metadata.last_chunk_indices_p
|
||||
|
||||
if attn_metadata is None:
|
||||
# profile run
|
||||
hidden_states_B_C = (
|
||||
hidden_states_B_C.transpose(0, 1).clone().transpose(0, 1)
|
||||
).contiguous()
|
||||
hidden_states, _B, _C = self.split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
return hidden_states
|
||||
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decodes
|
||||
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
dt_d, dt_p = torch.split(
|
||||
dt[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_actual_tokens],
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
# If prefix caching is enabled, retrieve the relevant variables
|
||||
# for prefill and decode
|
||||
block_idx_last_computed_token_d, block_idx_last_computed_token_p = (
|
||||
torch.split(
|
||||
attn_metadata.block_idx_last_computed_token,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
block_idx_last_scheduled_token_d, block_idx_last_scheduled_token_p = (
|
||||
torch.split(
|
||||
attn_metadata.block_idx_last_scheduled_token,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# Prefill-only variables:
|
||||
block_idx_first_scheduled_token_p = (
|
||||
attn_metadata.block_idx_first_scheduled_token_p
|
||||
)
|
||||
num_computed_tokens_p = attn_metadata.num_computed_tokens_p
|
||||
else:
|
||||
block_idx_last_computed_token_d = None
|
||||
block_idx_last_computed_token_p = None
|
||||
block_idx_last_scheduled_token_d = None
|
||||
block_idx_last_scheduled_token_p = None
|
||||
block_idx_first_scheduled_token_p = None
|
||||
num_computed_tokens_p = None
|
||||
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
output[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
# - It will read the initial states for every sequence,
|
||||
# that has "has_initial_states_p" == True,
|
||||
# from "cache_indices", using "state_indices_tensor_p".
|
||||
# - It updates the "conv_state" cache in positions pointed
|
||||
# to by "state_indices_tensor_p".
|
||||
# In particular, it will always write the state at the
|
||||
# sequence end.
|
||||
# In addition, "block_idx_first_scheduled_token_p" and
|
||||
# "block_idx_last_scheduled_token_p"
|
||||
# are provided (which are pointers into
|
||||
# "state_indices_tensor_p"), it will write additional cache
|
||||
# states aligned at "block_size_to_align".
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1
|
||||
) # this is the form that causal-conv see
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
x,
|
||||
self.conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
block_idx_first_scheduled_token=block_idx_first_scheduled_token_p,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_p,
|
||||
initial_state_idx=block_idx_last_computed_token_p,
|
||||
num_computed_tokens=num_computed_tokens_p,
|
||||
block_size_to_align=mamba_block_size,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p,
|
||||
).transpose(0, 1)[:num_prefill_tokens]
|
||||
|
||||
hidden_states_p, B_p, C_p = self.split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_p
|
||||
)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if has_initial_states_p is not None and prep_initial_states:
|
||||
kernel_ssm_indices = state_indices_tensor_p
|
||||
if prefix_caching_enabled:
|
||||
kernel_ssm_indices = state_indices_tensor_p.gather(
|
||||
1, block_idx_last_computed_token_p.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[kernel_ssm_indices],
|
||||
0,
|
||||
)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_states = mamba_chunk_scan_combined_varlen(
|
||||
hidden_states_p.view(
|
||||
num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
|
||||
),
|
||||
dt_p,
|
||||
self.A,
|
||||
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
seq_idx=seq_idx_p,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
cu_chunk_seqlens=cu_chunk_seqlen_p,
|
||||
last_chunk_indices=last_chunk_indices_p,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=prefix_caching_enabled,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1, self.head_dim),
|
||||
state_dtype=ssm_state.dtype,
|
||||
)
|
||||
|
||||
if prefix_caching_enabled:
|
||||
# The chunk_stride is the number of chunks per mamba block
|
||||
# e.g., if mamba_block_size = 512 and chunk_size = 256,
|
||||
# then chunk_stride = 2
|
||||
chunk_stride = mamba_block_size // chunk_size
|
||||
|
||||
# Save state for sequences with more than just final state
|
||||
for seq_idx in range(num_prefills):
|
||||
# Block index for the first scheduled token
|
||||
block_idx_first_scheduled_token = block_idx_first_scheduled_token_p[
|
||||
seq_idx
|
||||
]
|
||||
|
||||
# Block index for the last scheduled token
|
||||
block_idx_last_scheduled_token = block_idx_last_scheduled_token_p[
|
||||
seq_idx
|
||||
]
|
||||
|
||||
# Number of blocks that need to be written
|
||||
n_blocks_to_fill = (
|
||||
block_idx_last_scheduled_token - block_idx_first_scheduled_token
|
||||
)
|
||||
|
||||
# Skip sequences that don't have any blocks to fill
|
||||
if n_blocks_to_fill == 0:
|
||||
continue
|
||||
|
||||
# Look up the state indices
|
||||
cache_blocks_to_fill = state_indices_tensor_p[
|
||||
seq_idx,
|
||||
block_idx_first_scheduled_token:block_idx_last_scheduled_token,
|
||||
]
|
||||
|
||||
# First chunk index for this sequence
|
||||
if seq_idx == 0:
|
||||
first_chunk = 0
|
||||
else:
|
||||
first_chunk = 1 + last_chunk_indices_p[seq_idx - 1]
|
||||
|
||||
# First chunk that is aligned on the mamba block boundary
|
||||
first_aligned_chunk = first_chunk + chunk_stride - 1
|
||||
|
||||
# Calculate the number of computed tokens that were not
|
||||
# already cached
|
||||
num_unaligned_computed_tokens = (
|
||||
num_computed_tokens_p[seq_idx] % mamba_block_size
|
||||
)
|
||||
|
||||
if num_unaligned_computed_tokens > 0:
|
||||
# If the number of computed tokens is not block aligned,
|
||||
# then we need to shift the index accordingly
|
||||
first_aligned_chunk -= (
|
||||
num_unaligned_computed_tokens // chunk_size
|
||||
)
|
||||
|
||||
# Get states to write
|
||||
from_where = varlen_states[
|
||||
first_aligned_chunk : first_aligned_chunk
|
||||
+ n_blocks_to_fill * chunk_stride : chunk_stride
|
||||
]
|
||||
|
||||
# Write the states
|
||||
ssm_state[cache_blocks_to_fill] = from_where
|
||||
|
||||
# For all seqs, store the last state (note: might be partial):
|
||||
ssm_state[
|
||||
state_indices_tensor_p.gather(
|
||||
1, block_idx_last_scheduled_token_p.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
] = varlen_states[last_chunk_indices_p]
|
||||
|
||||
else:
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate)
|
||||
# tensor
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
if prefix_caching_enabled:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_computed_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
state_indices_tensor_d_output = state_indices_tensor_d.gather(
|
||||
1, block_idx_last_scheduled_token_d.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
# for decode:
|
||||
# block_idx_first_scheduled_token_d ==
|
||||
# block_idx_last_scheduled_token_d
|
||||
# at block boundaries:
|
||||
# block_idx_first_scheduled_token_d >
|
||||
# block_idx_last_computed_token_d
|
||||
else:
|
||||
# Without caching, read and write in-place to the same blocks:
|
||||
state_indices_tensor_d_input = state_indices_tensor_d
|
||||
state_indices_tensor_d_output = state_indices_tensor_d
|
||||
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
conv_state,
|
||||
self.conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
block_idx_last_scheduled_token=block_idx_last_scheduled_token_d,
|
||||
initial_state_idx=block_idx_last_computed_token_d,
|
||||
)
|
||||
|
||||
hidden_states_d, B_d, C_d = self.split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_d
|
||||
)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
A_d = (
|
||||
self.A[:, None, ...][:, :, None]
|
||||
.expand(-1, self.head_dim, self.ssm_state_size)
|
||||
.to(dtype=torch.float32)
|
||||
)
|
||||
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
|
||||
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
|
||||
hidden_states_d = hidden_states_d.view(
|
||||
-1, self.num_heads // self.tp_size, self.head_dim
|
||||
)
|
||||
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# using state_indices_tensor_d
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
selective_state_update(
|
||||
ssm_state,
|
||||
hidden_states_d,
|
||||
dt_d,
|
||||
A_d,
|
||||
B_d,
|
||||
C_d,
|
||||
D_d,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
self.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.mamba2_state_shape(
|
||||
intermediate_size=self.intermediate_size,
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
n_groups=self.n_groups,
|
||||
num_heads=self.num_heads,
|
||||
head_dim=self.head_dim,
|
||||
state_size=self.ssm_state_size,
|
||||
conv_kernel=self.conv_kernel_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "mamba2"
|
||||
|
||||
|
||||
def mamba_mixer2(
|
||||
projected_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.conv_ssm_forward(projected_states=projected_states, output=output)
|
||||
|
||||
|
||||
def mamba_mixer2_fake(
|
||||
projected_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mamba_mixer2",
|
||||
op_func=mamba_mixer2,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer2_fake,
|
||||
)
|
||||
225
vllm/model_executor/layers/mamba/mamba_utils.py
Normal file
225
vllm/model_executor/layers/mamba/mamba_utils.py
Normal file
@@ -0,0 +1,225 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config.cache import MambaDType
|
||||
from vllm.config.model import ModelDType
|
||||
from vllm.distributed import divide
|
||||
from vllm.utils.torch_utils import (
|
||||
STR_DTYPE_TO_TORCH_DTYPE,
|
||||
get_kv_cache_torch_dtype,
|
||||
)
|
||||
|
||||
|
||||
class MambaStateDtypeCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
# TODO (tdoublep) requires testing
|
||||
if mamba_cache_dtype == "float32":
|
||||
raise ValueError("fp32 state for minimax is not yet supported")
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype,)
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(
|
||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(
|
||||
model_dtype, mamba_cache_dtype, mamba_ssm_cache_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _mamba_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
if mamba_ssm_cache_dtype == "auto":
|
||||
temporal_state_dtype = conv_state_dtype
|
||||
else:
|
||||
temporal_state_dtype = STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype]
|
||||
|
||||
return (conv_state_dtype, temporal_state_dtype)
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (conv_state_dtype,)
|
||||
|
||||
@classmethod
|
||||
def gated_delta_net_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype)
|
||||
|
||||
@classmethod
|
||||
def kda_state_dtype(
|
||||
cls,
|
||||
model_dtype: ModelDType | torch.dtype,
|
||||
mamba_cache_dtype: MambaDType,
|
||||
):
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype, state_dtype, torch.float32)
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
@classmethod
|
||||
def linear_attention_state_shape(
|
||||
cls,
|
||||
num_heads: int,
|
||||
tp_size: int,
|
||||
head_dim: int,
|
||||
) -> tuple[tuple[int, int, int], ...]:
|
||||
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||
return (state_shape,)
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
|
||||
|
||||
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
n_groups: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
|
||||
# heads and n_groups are TP-ed
|
||||
conv_dim = intermediate_size + 2 * n_groups * state_size
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
||||
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
conv_dim = divide(intermediate_size, tp_world_size)
|
||||
conv_state_shape = (conv_kernel - 1, conv_dim)
|
||||
return (conv_state_shape,)
|
||||
|
||||
@classmethod
|
||||
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngoups % tp_size == 0, this will be zero
|
||||
if ngroups % tp_size == 0:
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
|
||||
@classmethod
|
||||
def gated_delta_net_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
num_k_heads: int,
|
||||
num_v_heads: int,
|
||||
head_k_dim: int,
|
||||
head_v_dim: int,
|
||||
conv_kernel_size: int,
|
||||
num_spec: int = 0,
|
||||
):
|
||||
conv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, tp_world_size),
|
||||
conv_kernel_size - 1 + num_spec,
|
||||
)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
temporal_state_shape = (
|
||||
divide(num_v_heads, tp_world_size),
|
||||
head_k_dim,
|
||||
head_v_dim,
|
||||
)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def kda_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
num_k_heads: int | None = None,
|
||||
head_k_dim: int | None = None,
|
||||
conv_kernel_size: int = 4,
|
||||
num_spec: int = 0,
|
||||
) -> tuple[tuple[int, int], tuple[int, int], tuple[int, int], tuple[int, int, int]]:
|
||||
if num_k_heads is None:
|
||||
num_k_heads = num_heads
|
||||
if head_k_dim is None:
|
||||
head_k_dim = head_dim
|
||||
|
||||
proj_size = num_heads * head_dim
|
||||
proj_k_size = num_k_heads * head_k_dim
|
||||
|
||||
conv_state_shape = (divide(proj_size, tp_world_size), conv_kernel_size - 1)
|
||||
conv_state_k_shape = (divide(proj_k_size, tp_world_size), conv_kernel_size - 1)
|
||||
recurrent_state_shape = (divide(num_heads, tp_world_size), head_dim, head_dim)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
conv_state_k_shape = conv_state_k_shape[1], conv_state_k_shape[0]
|
||||
return (
|
||||
conv_state_shape,
|
||||
conv_state_k_shape,
|
||||
conv_state_k_shape,
|
||||
recurrent_state_shape,
|
||||
)
|
||||
0
vllm/model_executor/layers/mamba/ops/__init__.py
Normal file
0
vllm/model_executor/layers/mamba/ops/__init__.py
Normal file
1240
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
1240
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
File diff suppressed because it is too large
Load Diff
172
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
172
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row: tl.int64,
|
||||
stride_y_row: tl.int64,
|
||||
stride_z_row: tl.int64,
|
||||
M: tl.int64, # number of rows in X
|
||||
N: tl.int64, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, _, _ = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
|
||||
return y.reshape(x_shape_og)
|
||||
586
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
586
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
|
||||
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
|
||||
|
||||
if TRITON3:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
|
||||
is not None
|
||||
}
|
||||
)
|
||||
@triton.heuristics(
|
||||
{"IS_SPEC_DECODING": lambda args: args["num_accepted_tokens_ptr"] is not None}
|
||||
)
|
||||
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens_ptr"] is not None})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
|
||||
)
|
||||
@triton.jit(do_not_specialize=["N"])
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
A_ptr,
|
||||
B_ptr,
|
||||
C_ptr,
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
dst_state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
num_accepted_tokens_ptr,
|
||||
cu_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
N,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch,
|
||||
stride_state_head,
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_bias_dim,
|
||||
stride_A_head,
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
stride_state_indices_batch,
|
||||
stride_state_indices_T,
|
||||
stride_dst_state_indices_batch,
|
||||
stride_dst_state_indices_T,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
IS_SPEC_DECODING: tl.constexpr,
|
||||
IS_VARLEN: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
|
||||
if IS_VARLEN:
|
||||
bos = tl.load(cu_seqlens_ptr + pid_b).to(tl.int64)
|
||||
eos = tl.load(cu_seqlens_ptr + pid_b + 1).to(tl.int64)
|
||||
seq_len = eos - bos
|
||||
|
||||
if seq_len == 0:
|
||||
return
|
||||
else:
|
||||
bos = pid_b
|
||||
seq_len = 1
|
||||
|
||||
state_ptr_base = state_ptr
|
||||
|
||||
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
if IS_SPEC_DECODING:
|
||||
num_accepted = tl.load(num_accepted_tokens_ptr + pid_b).to(tl.int64)
|
||||
init_token_idx = tl.maximum(num_accepted - 1, 0)
|
||||
else:
|
||||
init_token_idx = 0
|
||||
|
||||
dst_state_batch_indices_ptr += pid_b * stride_dst_state_indices_batch
|
||||
if not IS_SPEC_DECODING:
|
||||
dst_state_batch_idx = tl.load(
|
||||
dst_state_batch_indices_ptr
|
||||
+ init_token_idx * stride_dst_state_indices_T
|
||||
).to(tl.int64)
|
||||
dst_state_ptr = state_ptr + (
|
||||
dst_state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
|
||||
state_batch_indices_ptr += (
|
||||
pid_b * stride_state_indices_batch + init_token_idx * stride_state_indices_T
|
||||
)
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
else:
|
||||
dst_state_ptr = (
|
||||
state_ptr + pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
)
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += bos * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += bos * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += bos * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += bos * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += bos * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += bos * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
if not IS_SPEC_DECODING:
|
||||
dst_state_ptrs = dst_state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0).to(tl.float32)
|
||||
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
A_ptrs = A_ptr + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
|
||||
|
||||
for i_t in range(seq_len):
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(
|
||||
A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
|
||||
if IS_SPEC_DECODING:
|
||||
dst_idx_ptr = dst_state_batch_indices_ptr + i_t * stride_dst_state_indices_T
|
||||
token_dst_idx = tl.load(dst_idx_ptr).to(tl.int64)
|
||||
if token_dst_idx != pad_slot_id:
|
||||
token_dst_ptrs = (
|
||||
state_ptr_base
|
||||
+ token_dst_idx * stride_state_batch
|
||||
+ pid_h * stride_state_head
|
||||
+ offs_m[:, None] * stride_state_dim
|
||||
+ offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
tl.store(
|
||||
token_dst_ptrs, state.to(token_dst_ptrs.dtype.element_ty), mask=mask
|
||||
)
|
||||
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
x_ptr += stride_x_batch
|
||||
dt_ptr += stride_dt_batch
|
||||
B_ptr += stride_B_batch
|
||||
C_ptr += stride_C_batch
|
||||
out_ptr += stride_out_batch
|
||||
if HAS_Z:
|
||||
z_ptr += stride_z_batch
|
||||
|
||||
if not IS_SPEC_DECODING:
|
||||
tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask)
|
||||
|
||||
|
||||
def selective_state_update(
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
dst_state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None,
|
||||
num_accepted_tokens=None,
|
||||
cu_seqlens=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
num_accepted_tokens: (batch,)
|
||||
number of accepted tokens from previous verification step,
|
||||
tells the kernel which initial state to use
|
||||
cu_seqlens: (batch,)
|
||||
length per sequence, for variable length in speculative decoding cases
|
||||
"""
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
if num_accepted_tokens is not None:
|
||||
assert state_batch_indices is not None and state_batch_indices.dim() == 2
|
||||
assert dst_state_batch_indices is None or dst_state_batch_indices.dim() == 2
|
||||
if state_batch_indices is not None and state_batch_indices.dim() == 1:
|
||||
state_batch_indices = state_batch_indices.unsqueeze(1)
|
||||
if dst_state_batch_indices is not None and dst_state_batch_indices.dim() == 1:
|
||||
dst_state_batch_indices = dst_state_batch_indices.unsqueeze(1)
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
if cu_seqlens is not None:
|
||||
N = len(cu_seqlens) - 1
|
||||
# Only used to verify the shape of
|
||||
# state_batch_indices and dst_state_batch_indices
|
||||
max_seqlen = (
|
||||
state_batch_indices.size(-1) if state_batch_indices is not None else 1
|
||||
)
|
||||
else:
|
||||
N = batch
|
||||
max_seqlen = 1
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape[0] >= N
|
||||
assert state_batch_indices.shape[1] >= max_seqlen
|
||||
if dst_state_batch_indices is not None:
|
||||
assert dst_state_batch_indices.shape[0] >= N
|
||||
assert dst_state_batch_indices.shape[1] >= max_seqlen
|
||||
else:
|
||||
# revert to the default behavior of in-place state updates
|
||||
dst_state_batch_indices = state_batch_indices
|
||||
assert out.shape == x.shape
|
||||
if num_accepted_tokens is not None:
|
||||
assert num_accepted_tokens.shape == (N,)
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), N, nheads)
|
||||
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
||||
state_batch_indices_strides = (
|
||||
(state_batch_indices.stride(0), state_batch_indices.stride(1))
|
||||
if state_batch_indices is not None
|
||||
else (0, 0)
|
||||
)
|
||||
dst_state_batch_indices_strides = (
|
||||
(dst_state_batch_indices.stride(0), dst_state_batch_indices.stride(1))
|
||||
if dst_state_batch_indices is not None
|
||||
else (0, 0)
|
||||
)
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = (
|
||||
(32, 4)
|
||||
if dstate <= 16
|
||||
else (
|
||||
(16, 4)
|
||||
if dstate <= 32
|
||||
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
|
||||
)
|
||||
)
|
||||
tie_hdim = (
|
||||
A.stride(-1) == 0
|
||||
and A.stride(-2) == 0
|
||||
and dt.stride(-1) == 0
|
||||
and dt_bias.stride(-1) == 0
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
dt_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
dst_state_batch_indices,
|
||||
pad_slot_id,
|
||||
num_accepted_tokens,
|
||||
cu_seqlens,
|
||||
N,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads // ngroups,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(2),
|
||||
state.stride(3),
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
state_batch_indices_strides[0],
|
||||
state_batch_indices_strides[1],
|
||||
dst_state_batch_indices_strides[0],
|
||||
dst_state_batch_indices_strides[1],
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
def selective_scan_fn(
|
||||
u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
block_size=1024,
|
||||
block_idx_first_scheduled_token=None,
|
||||
block_idx_last_scheduled_token=None,
|
||||
initial_state_idx=None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
applies changes in place.
|
||||
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
applies changes in place.
|
||||
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
dt_bias: (dim,) or (dim)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended with 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
A tensor with each cell is a correspondent
|
||||
input and output ssm_state indices
|
||||
- Without APC: (batch,) - single state index per batch item
|
||||
- With APC: (batch, max_positions) - cache block indices for read/write
|
||||
Each non-zero value indicates a cache block to load from and/or write to.
|
||||
has_initial_state: (batch) bool
|
||||
A tensor populated with ones and zeros,
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
block_size: int
|
||||
The block size to align the cached states to
|
||||
block_idx_first_scheduled_token: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the first
|
||||
cache block to be filled is located.
|
||||
block_idx_last_scheduled_token: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the last cache block
|
||||
to be filled is located.
|
||||
initial_state_idx: (batch,), dtype int32
|
||||
The pointer into cache_indices, where the cache block
|
||||
containing the initial state is located.
|
||||
returns
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
supports inplace replacement
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3 and query_start_loc is None:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 2 and query_start_loc is not None:
|
||||
B = B.unsqueeze(0)
|
||||
if C.dim() == 3 and query_start_loc is None:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 2 and query_start_loc is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
ops.selective_scan_fwd(
|
||||
u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
delta_softplus,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
block_size,
|
||||
block_idx_first_scheduled_token,
|
||||
block_idx_last_scheduled_token,
|
||||
initial_state_idx,
|
||||
)
|
||||
|
||||
if z is None:
|
||||
return delta # output written inplace to delta
|
||||
else:
|
||||
return z # output written inplace to z
|
||||
211
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
211
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["chunk_size", "K", "IS_CAUSAL"],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
ngroups: tl.constexpr,
|
||||
stride_a_seqlen: tl.int64,
|
||||
stride_a_head: tl.int64,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_outm: tl.int64,
|
||||
stride_outn: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_ch = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
|
||||
a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# compute a * b.T
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
|
||||
& (offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (seqlen, ngroups, k)
|
||||
b: (seqlen, ngroups, k)
|
||||
chunk_size: int
|
||||
cu_chunk_seq_lens: (nchunks+1,)
|
||||
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
||||
guaranteed to be correct.
|
||||
Return:
|
||||
out: (nchunks, ngroups, chunk_size, chunk_size)
|
||||
"""
|
||||
seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
if a.stride(-1) != 1 and a.stride(0) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(0) != 1:
|
||||
b = b.contiguous()
|
||||
|
||||
nchunks = len(cu_chunk_seqlens) - 1
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty(
|
||||
(nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype
|
||||
)
|
||||
dot_dtype = (
|
||||
tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
|
||||
else (
|
||||
tl.float16
|
||||
if a.dtype == torch.float16 or b.dtype == torch.float16
|
||||
else tl.float32
|
||||
)
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
|
||||
nchunks * ngroups,
|
||||
)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
out_ptr=out,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
seqlen=seqlen,
|
||||
chunk_size=chunk_size,
|
||||
K=k,
|
||||
ngroups=ngroups,
|
||||
stride_a_seqlen=a.stride(0),
|
||||
stride_a_head=a.stride(1),
|
||||
stride_ak=a.stride(2),
|
||||
stride_b_seqlen=b.stride(0),
|
||||
stride_b_head=b.stride(1),
|
||||
stride_bk=b.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_outm=out.stride(-2),
|
||||
stride_outn=out.stride(-1),
|
||||
IS_CAUSAL=causal,
|
||||
dot_dtype=dot_dtype,
|
||||
)
|
||||
return out
|
||||
456
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
456
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
@@ -0,0 +1,456 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
cb_ptr,
|
||||
x_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
C_ptr,
|
||||
states_ptr,
|
||||
D_ptr,
|
||||
initstates_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
chunk_size: tl.constexpr,
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_cb_chunk: tl.int64,
|
||||
stride_cb_head: tl.int64,
|
||||
stride_cb_csize_m: tl.int64,
|
||||
stride_cb_csize_k: tl.constexpr,
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_z_seqlen: tl.int64,
|
||||
stride_z_head: tl.int64,
|
||||
stride_z_hdim: tl.constexpr,
|
||||
stride_out_seqlen: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_hdim: tl.constexpr,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
stride_C_seqlen: tl.int64,
|
||||
stride_C_head: tl.int64,
|
||||
stride_C_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
stride_D_head: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
D_HAS_HDIM: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
IS_TRITON_22: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
C_ptr += (
|
||||
chunk_seqlen_start * stride_C_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
|
||||
)
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
|
||||
seq_idx_ptr += pid_c * stride_seq_idx_chunk
|
||||
seq_idx = tl.load(seq_idx_ptr)
|
||||
seq_idx_prev = tl.load(
|
||||
seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1
|
||||
)
|
||||
|
||||
if HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
prev_states_ptr = (
|
||||
initstates_ptr
|
||||
+ seq_idx * stride_init_states_batch
|
||||
+ pid_h * stride_init_states_head
|
||||
)
|
||||
prev_states_hdim = stride_init_states_hdim
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
else:
|
||||
prev_states_ptr = (
|
||||
states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
|
||||
)
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(
|
||||
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
||||
).to(tl.float32)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
||||
)
|
||||
C_ptrs = C_ptr + (
|
||||
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
|
||||
)
|
||||
|
||||
scale_m = tl.exp(dA_cs_m)
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k_dstate[None, :] < dstate),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
# if no init states AND starting a new sequence, we need zeros
|
||||
prev_states = tl.zeros(
|
||||
(BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
|
||||
)
|
||||
else:
|
||||
# otherwise read the previous state
|
||||
prev_states_ptrs = (
|
||||
prev_states_ptr
|
||||
+ offs_n[None, :] * prev_states_hdim
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
)
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
|
||||
else:
|
||||
prev_states_ptrs = (
|
||||
prev_states_ptr
|
||||
+ offs_n[None, :] * prev_states_hdim
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
)
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0,
|
||||
)
|
||||
if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
|
||||
prev_states = tl.zeros(
|
||||
(BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty
|
||||
)
|
||||
else:
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate - k)
|
||||
& (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
cb_ptrs = cb_ptr + (
|
||||
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
|
||||
)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
K_MAX = (
|
||||
chunk_size_limit
|
||||
if not IS_CAUSAL
|
||||
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
|
||||
)
|
||||
for k in range(0, K_MAX, BLOCK_SIZE_K):
|
||||
cb = tl.load(
|
||||
cb_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
mask = offs_m[:, None] >= k + offs_k[None, :]
|
||||
cb = tl.where(mask, cb, 0.0)
|
||||
cb = cb.to(x_ptr.dtype.element_ty)
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
acc += tl.dot(cb, x)
|
||||
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
if HAS_D:
|
||||
if D_HAS_HDIM:
|
||||
D = tl.load(
|
||||
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
|
||||
).to(tl.float32)
|
||||
else:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
||||
x_residual = tl.load(
|
||||
x_ptr
|
||||
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
|
||||
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
acc += x_residual * D
|
||||
|
||||
if HAS_Z:
|
||||
z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
|
||||
z_ptrs = z_ptr + (
|
||||
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
|
||||
)
|
||||
z = tl.load(
|
||||
z_ptrs,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit)
|
||||
& (offs_out_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
acc *= z * tl.sigmoid(z)
|
||||
|
||||
out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (
|
||||
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
|
||||
)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
acc,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
|
||||
)
|
||||
|
||||
|
||||
def _chunk_scan_fwd(
|
||||
cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
cu_chunk_seqlens,
|
||||
out,
|
||||
seq_idx,
|
||||
D=None,
|
||||
z=None,
|
||||
initial_states=None,
|
||||
):
|
||||
assert seq_idx is not None, "this implementation requires seq_idx"
|
||||
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = C.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert C.shape == (seqlen, ngroups, dstate)
|
||||
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
assert seq_idx.shape == (nchunks,)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
||||
nchunks,
|
||||
nheads,
|
||||
)
|
||||
|
||||
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
||||
initial_states_strides = (
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0, 0)
|
||||
)
|
||||
|
||||
_chunk_scan_fwd_kernel[grid](
|
||||
cb_ptr=cb,
|
||||
x_ptr=x,
|
||||
z_ptr=z,
|
||||
out_ptr=out,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seq_idx_ptr=seq_idx,
|
||||
C_ptr=C,
|
||||
states_ptr=states,
|
||||
D_ptr=D,
|
||||
initstates_ptr=initial_states,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
chunk_size=chunk_size,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_cb_chunk=cb.stride(0),
|
||||
stride_cb_head=cb.stride(1),
|
||||
stride_cb_csize_m=cb.stride(2),
|
||||
stride_cb_csize_k=cb.stride(3),
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_z_seqlen=z_strides[0],
|
||||
stride_z_head=z_strides[1],
|
||||
stride_z_hdim=z_strides[2],
|
||||
stride_out_seqlen=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_hdim=out.stride(2),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
stride_C_seqlen=C.stride(0),
|
||||
stride_C_head=C.stride(1),
|
||||
stride_C_dstate=C.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
stride_D_head=D.stride(0) if D is not None else 0,
|
||||
IS_CAUSAL=True,
|
||||
HAS_D=D is not None,
|
||||
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
|
||||
HAS_Z=z is not None,
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
IS_TRITON_22=TRITON_22,
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return
|
||||
700
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
700
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
@@ -0,0 +1,700 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_SIZE_H": 2}),
|
||||
triton.Config({"BLOCK_SIZE_H": 4}),
|
||||
triton.Config({"BLOCK_SIZE_H": 8}),
|
||||
triton.Config({"BLOCK_SIZE_H": 16}),
|
||||
triton.Config({"BLOCK_SIZE_H": 32}),
|
||||
triton.Config({"BLOCK_SIZE_H": 64}),
|
||||
],
|
||||
key=["chunk_size", "nheads"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
dt_ptr,
|
||||
A_ptr,
|
||||
dt_bias_ptr,
|
||||
dt_out_ptr,
|
||||
dA_cumsum_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimension
|
||||
seqlen,
|
||||
nheads: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
dt_min: tl.constexpr,
|
||||
dt_max: tl.constexpr,
|
||||
# Strides
|
||||
stride_dt_seqlen: tl.int64,
|
||||
stride_dt_head: tl.constexpr,
|
||||
stride_A_head: tl.constexpr,
|
||||
stride_dt_bias_head: tl.constexpr,
|
||||
stride_dt_out_head: tl.int64,
|
||||
stride_dt_out_chunk: tl.int64,
|
||||
stride_dt_out_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr,
|
||||
BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
):
|
||||
# if dt is long, may cause problems, so use 64 bit
|
||||
# https://github.com/triton-lang/triton/issues/1058
|
||||
pid_c = tl.program_id(axis=0).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=1)
|
||||
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
|
||||
dt_ptr += chunk_seqlen_start * stride_dt_seqlen
|
||||
dt_out_ptr += pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
dt_ptrs = dt_ptr + (
|
||||
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
||||
)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
dt_out_ptrs = dt_out_ptr + (
|
||||
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
||||
)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (
|
||||
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
||||
)
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
dt = tl.load(
|
||||
dt_ptrs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(
|
||||
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
||||
).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
|
||||
dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.where(
|
||||
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
||||
)
|
||||
tl.store(
|
||||
dt_out_ptrs,
|
||||
dt,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
||||
)
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dA = dt * A[:, None]
|
||||
dA_cs = tl.cumsum(dA, axis=1)
|
||||
tl.store(
|
||||
dA_cs_ptrs,
|
||||
dA_cs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["hdim", "dstate", "chunk_size"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
states_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
|
||||
chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
|
||||
b_ptr += (
|
||||
chunk_seqlen_start * stride_b_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
)
|
||||
x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
||||
tl.float32
|
||||
)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(
|
||||
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
||||
).to(tl.float32)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (
|
||||
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
||||
)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
num_stages=3,
|
||||
num_warps=8,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=4,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=5,
|
||||
num_warps=2,
|
||||
),
|
||||
triton.Config(
|
||||
{"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
num_stages=4,
|
||||
num_warps=2,
|
||||
),
|
||||
],
|
||||
key=["hdim", "dstate", "chunk_size"],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_varlen_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
chunk_states_ptr,
|
||||
cu_seqlens_ptr,
|
||||
states_ptr,
|
||||
initstates_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_chunk_states_chunk: tl.int64,
|
||||
stride_chunk_states_head: tl.int64,
|
||||
stride_chunk_states_hdim: tl.int64,
|
||||
stride_chunk_states_dstate: tl.constexpr,
|
||||
stride_states_batch: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
||||
pid_c = (end_idx - 1) // chunk_size
|
||||
b_ptr += (
|
||||
pid_c * chunk_size * stride_b_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
)
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
chunk_states_ptr += (
|
||||
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
||||
)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states provided, we differentiate between states (which
|
||||
# are boundary conditions at a chunk boundary) and initstates (which are boundary
|
||||
# conditions when a new example in a cont batch starts)
|
||||
initstates_ptr += pid_h * stride_init_states_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(
|
||||
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
||||
).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
chunk_size_limit = end_idx - pid_c * chunk_size
|
||||
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
||||
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim)
|
||||
& (offs_k[None, :] < chunk_size_limit - k)
|
||||
& (offs_k[None, :] >= start_idx_cur - k),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k)
|
||||
& (offs_n[None, :] < dstate)
|
||||
& (offs_k[:, None] >= start_idx_cur - k),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(
|
||||
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
||||
).to(tl.float32)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
scale = tl.where(
|
||||
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
|
||||
0.0,
|
||||
)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
||||
# If HAS_INITSTATES==True need to consider two possibilities
|
||||
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
|
||||
# - if state_idx >= pid * chunk_size, then we need to insert initstates
|
||||
if (
|
||||
(start_idx < pid_c * chunk_size) # first chunk
|
||||
or (HAS_INITSTATES)
|
||||
):
|
||||
dA_cs_boundary = 0.0 # default
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim
|
||||
+ offs_n[None, :] * stride_chunk_states_dstate
|
||||
)
|
||||
else:
|
||||
# - this seems repetitive, buts its to help the compiler
|
||||
if start_idx < pid_c * chunk_size:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim
|
||||
+ offs_n[None, :] * stride_chunk_states_dstate
|
||||
)
|
||||
else:
|
||||
past_states_ptrs = initstates_ptr + (
|
||||
pid_b * stride_init_states_batch
|
||||
+ offs_m[:, None] * stride_init_states_hdim
|
||||
+ offs_n[None, :] * stride_init_states_dstate
|
||||
)
|
||||
|
||||
# need to adjust the boundary
|
||||
if start_idx > pid_c * chunk_size:
|
||||
dA_cs_boundary = tl.load(
|
||||
dA_cumsum_ptr
|
||||
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
||||
).to(tl.float32)
|
||||
|
||||
past_states = tl.load(
|
||||
past_states_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
scale = tl.exp(dA_cs_last - dA_cs_boundary)
|
||||
acc += past_states * scale
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (
|
||||
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
||||
)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
def _chunk_cumsum_fwd(
|
||||
dt,
|
||||
A,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
):
|
||||
seqlen, nheads = dt.shape
|
||||
assert A.shape == (nheads,)
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads,)
|
||||
nchunks = cu_chunk_seqlens.shape[0] - 1
|
||||
dt_out = torch.empty(
|
||||
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
||||
)
|
||||
dA_cumsum = torch.empty(
|
||||
nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
||||
)
|
||||
grid_chunk_cs = lambda META: (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))
|
||||
with torch.cuda.device(dt.device.index):
|
||||
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
||||
dt_ptr=dt,
|
||||
A_ptr=A,
|
||||
dt_bias_ptr=dt_bias,
|
||||
dt_out_ptr=dt_out,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
seqlen=seqlen,
|
||||
nheads=nheads,
|
||||
chunk_size=chunk_size,
|
||||
dt_min=dt_limit[0],
|
||||
dt_max=dt_limit[1],
|
||||
stride_dt_seqlen=dt.stride(0),
|
||||
stride_dt_head=dt.stride(1),
|
||||
stride_A_head=A.stride(0),
|
||||
stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0,
|
||||
stride_dt_out_head=dt_out.stride(0),
|
||||
stride_dt_out_chunk=dt_out.stride(1),
|
||||
stride_dt_out_csize=dt_out.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
DT_SOFTPLUS=dt_softplus,
|
||||
HAS_DT_BIAS=dt_bias is not None,
|
||||
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
||||
)
|
||||
return dA_cumsum, dt_out
|
||||
|
||||
|
||||
def _chunk_state_fwd(
|
||||
B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True
|
||||
):
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
|
||||
if states is not None:
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
else:
|
||||
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
||||
states = torch.empty(
|
||||
(nchunks, nheads, headdim, dstate), device=x.device, dtype=states_dtype
|
||||
)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
||||
nchunks,
|
||||
nheads,
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_fwd_kernel[grid](
|
||||
x_ptr=x,
|
||||
b_ptr=B,
|
||||
states_ptr=states,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_b_seqlen=B.stride(0),
|
||||
stride_b_head=B.stride(1),
|
||||
stride_b_dstate=B.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
def chunk_state_varlen(
|
||||
B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
|
||||
):
|
||||
total_seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
batch = cu_seqlens.shape[0] - 1
|
||||
cu_seqlens = cu_seqlens.contiguous()
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (total_seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
|
||||
states = torch.empty(
|
||||
batch,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
dtype=chunk_states.dtype,
|
||||
device=chunk_states.device,
|
||||
)
|
||||
|
||||
initial_states_strides = (
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0, 0)
|
||||
)
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
||||
batch,
|
||||
nheads,
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_varlen_kernel[grid](
|
||||
x_ptr=x,
|
||||
b_ptr=B,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
chunk_states_ptr=chunk_states,
|
||||
cu_seqlens_ptr=cu_seqlens,
|
||||
states_ptr=states,
|
||||
initstates_ptr=initial_states,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_b_seqlen=B.stride(0),
|
||||
stride_b_head=B.stride(1),
|
||||
stride_b_dstate=B.stride(2),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_chunk_states_chunk=chunk_states.stride(0),
|
||||
stride_chunk_states_head=chunk_states.stride(1),
|
||||
stride_chunk_states_hdim=chunk_states.stride(2),
|
||||
stride_chunk_states_dstate=chunk_states.stride(3),
|
||||
stride_states_batch=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return states
|
||||
230
vllm/model_executor/layers/mamba/ops/ssd_combined.py
Normal file
230
vllm/model_executor/layers/mamba/ops/ssd_combined.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
from .ssd_bmm import _bmm_chunk_fwd
|
||||
from .ssd_chunk_scan import _chunk_scan_fwd
|
||||
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd
|
||||
from .ssd_state_passing import _state_passing_fwd
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
def is_int_pow_2(n):
|
||||
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
return_intermediate_states=False,
|
||||
seq_idx=None,
|
||||
cu_seqlens=None,
|
||||
cu_chunk_seqlens=None,
|
||||
last_chunk_indices=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
):
|
||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (seqlen, ngroups, dstate)
|
||||
assert dt.shape == (seqlen, nheads)
|
||||
assert A.shape == (nheads,)
|
||||
assert C.shape == B.shape
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (cu_chunk_seqlens.shape[0] - 1,)
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if (
|
||||
x.stride(-1) != 1 and x.stride(0) != 1
|
||||
): # Either M or K dimension should be contiguous
|
||||
x = x.contiguous()
|
||||
if (
|
||||
z is not None and z.stride(-1) != 1 and z.stride(0) != 1
|
||||
): # Either M or K dimension should be contiguous
|
||||
z = z.contiguous()
|
||||
if D is not None and D.stride(-1) != 1:
|
||||
D = D.contiguous()
|
||||
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim, dstate)
|
||||
|
||||
# This function executes 5 sub-functions for computing mamba
|
||||
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
|
||||
# which has a minimal implementation to understand the below operations
|
||||
# - as explained by the blog, mamba is a special case of causal attention
|
||||
# - the idea is to chunk the attention matrix and compute each
|
||||
# submatrix separately using different optimizations.
|
||||
# - see the blog and paper for a visualization of the submatrices
|
||||
# which we refer to in the comments below
|
||||
|
||||
# 1. Compute chunked cumsum of A * dt
|
||||
# - here dt may go through a softplus activation
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(
|
||||
dt,
|
||||
A,
|
||||
chunk_size,
|
||||
cu_chunk_seqlens,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
states = _chunk_state_fwd(
|
||||
B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True
|
||||
)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
# - for handling chunked prefill, this requires i) initial_states and
|
||||
# ii) seq_idx to be all specified.
|
||||
# - When a new seq_idx is detected, we will stop passing the prev_state
|
||||
# and switch accordingly to the init_state corresponding to the new seq_idx.
|
||||
states = _state_passing_fwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum, # (nheads, nchunks, chunk_size)
|
||||
cu_chunk_seqlens,
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
|
||||
if initial_states is not None
|
||||
else None, # (batch, nheads, headdim*dstate)
|
||||
seq_idx=seq_idx,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||
)
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
||||
|
||||
# 4. Compute batched matrix multiply for C_j^T B_i terms
|
||||
CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32)
|
||||
|
||||
# 5. Scan and compute the diagonal blocks, taking into
|
||||
# account past causal states.
|
||||
# - if initial states are provided, then states information will be
|
||||
# augmented with initial_states.
|
||||
# - to do this properly, we need to account for example changes in
|
||||
# the continuous batch, therefore we introduce pseudo chunks, which is
|
||||
# a chunk that is split up each time an example changes.
|
||||
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
|
||||
# a seq_idx change, in which case we take states information from
|
||||
# init_states.
|
||||
_chunk_scan_fwd(
|
||||
CB,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
cu_chunk_seqlens,
|
||||
out, # in-place update
|
||||
seq_idx,
|
||||
D=D,
|
||||
z=z,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
if return_intermediate_states:
|
||||
return states
|
||||
else:
|
||||
return states[last_chunk_indices]
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined_varlen(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
cu_chunk_seqlens,
|
||||
last_chunk_indices,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
return_intermediate_states=False,
|
||||
state_dtype=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
x: (seqlen, nheads, headdim)
|
||||
dt: (seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (seqlen, ngroups, dstate)
|
||||
C: (seqlen, ngroups, dstate)
|
||||
chunk_size: int
|
||||
cu_seqlens: (batch + 1,)
|
||||
cu_chunk_seqlens: (nchunks + 1,)
|
||||
last_chunk_indices: (batch,)
|
||||
seq_idx: (nchunks,)
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
initial_states: (batch, nheads, headdim, dstate)
|
||||
dt_softplus: Whether to apply softplus to dt
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
state_dtype: The data type of the ssm state
|
||||
Return:
|
||||
varlen_states: (batch, nheads, headdim, dstate)
|
||||
"""
|
||||
|
||||
assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
|
||||
assert seq_idx is not None
|
||||
|
||||
varlen_states = _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
out,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
return_intermediate_states=return_intermediate_states,
|
||||
seq_idx=seq_idx,
|
||||
cu_seqlens=cu_seqlens,
|
||||
cu_chunk_seqlens=cu_chunk_seqlens,
|
||||
last_chunk_indices=last_chunk_indices,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
state_dtype=state_dtype,
|
||||
)
|
||||
|
||||
return varlen_states
|
||||
157
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
157
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({"BLOCK_SIZE": 64}),
|
||||
triton.Config({"BLOCK_SIZE": 128}),
|
||||
triton.Config({"BLOCK_SIZE": 256}),
|
||||
triton.Config({"BLOCK_SIZE": 512}),
|
||||
triton.Config({"BLOCK_SIZE": 1024}),
|
||||
triton.Config({"BLOCK_SIZE": 2048}),
|
||||
],
|
||||
key=["dim"],
|
||||
)
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
states_ptr,
|
||||
out_ptr,
|
||||
dA_cs_ptr,
|
||||
initstates_ptr,
|
||||
seq_idx_ptr,
|
||||
cu_chunk_seqlens_ptr,
|
||||
# Matrix dimensions
|
||||
dim: tl.constexpr,
|
||||
nchunks,
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
# Strides
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_dim: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_dim: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_initstates_batch: tl.int64,
|
||||
stride_initstates_head: tl.int64,
|
||||
stride_initstates_dim: tl.constexpr,
|
||||
stride_seq_idx_chunk: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_h = tl.program_id(axis=1)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
|
||||
states_ptr += pid_h * stride_states_head
|
||||
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize
|
||||
out_ptr += pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = (
|
||||
initstates_ptr
|
||||
+ pid_h * stride_initstates_head
|
||||
+ offs_m * stride_initstates_dim
|
||||
)
|
||||
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
prev_seq_idx = 0
|
||||
for c in range(nchunks):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk)
|
||||
# we have started a new sequence
|
||||
if prev_seq_idx != seq_idx:
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptrs = (
|
||||
initstates_ptr
|
||||
+ seq_idx * stride_initstates_batch
|
||||
+ pid_h * stride_initstates_head
|
||||
+ offs_m * stride_initstates_dim
|
||||
)
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
else:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
|
||||
prev_seq_idx = seq_idx
|
||||
states = tl.exp(dA_cs) * states + new_states
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
|
||||
states_ptrs += stride_states_chunk
|
||||
dA_cs_ptr += stride_dA_cs_chunk
|
||||
out_ptrs += stride_out_chunk
|
||||
|
||||
|
||||
def _state_passing_fwd(
|
||||
states,
|
||||
dA_cumsum,
|
||||
cu_chunk_seqlens,
|
||||
seq_idx,
|
||||
initial_states=None,
|
||||
out_dtype=None,
|
||||
):
|
||||
nchunks, nheads, dim = states.shape
|
||||
chunk_size = dA_cumsum.shape[-1]
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
seqlen = seq_idx.shape[-1]
|
||||
out_dtype = states.dtype if out_dtype is None else out_dtype
|
||||
out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype)
|
||||
|
||||
initial_states_strides = (
|
||||
(initial_states.stride(0), initial_states.stride(1), initial_states.stride(2))
|
||||
if initial_states is not None
|
||||
else (0, 0, 0)
|
||||
)
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads)
|
||||
with torch.cuda.device(states.device.index):
|
||||
_state_passing_fwd_kernel[grid](
|
||||
states_ptr=states,
|
||||
out_ptr=out,
|
||||
dA_cs_ptr=dA_cumsum,
|
||||
initstates_ptr=initial_states,
|
||||
seq_idx_ptr=seq_idx,
|
||||
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
|
||||
dim=dim,
|
||||
nchunks=nchunks,
|
||||
seqlen=seqlen if seq_idx is not None else 0,
|
||||
chunk_size=chunk_size if seq_idx is not None else 0,
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_dim=states.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_dim=out.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_initstates_batch=initial_states_strides[0],
|
||||
stride_initstates_head=initial_states_strides[1],
|
||||
stride_initstates_dim=initial_states_strides[2],
|
||||
stride_seq_idx_chunk=seq_idx.stride(0),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out
|
||||
255
vllm/model_executor/layers/mamba/short_conv.py
Normal file
255
vllm/model_executor/layers/mamba/short_conv.py
Normal file
@@ -0,0 +1,255 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator,
|
||||
MambaStateShapeCalculator,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.short_conv_attn import ShortConvAttentionMetadata
|
||||
|
||||
|
||||
@CustomOp.register("short_conv")
|
||||
class ShortConv(MambaBase, CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
dim: int,
|
||||
layer_idx: int,
|
||||
model_config: ModelConfig | None = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.conv_dim = dim
|
||||
self.L_cache = config.conv_L_cache
|
||||
self.bias = config.conv_bias
|
||||
|
||||
self.conv = ColumnParallelLinear(
|
||||
input_size=self.L_cache,
|
||||
output_size=dim,
|
||||
bias=self.bias,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=dim,
|
||||
output_sizes=[dim] * 3,
|
||||
bias=self.bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=dim,
|
||||
output_size=dim,
|
||||
bias=self.bias,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.kv_cache = (torch.tensor([]),)
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
return
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
torch.ops.vllm.short_conv(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# ShortConvAttentionMetadata contains metadata necessary for the
|
||||
# short_conv triton kernels to operate in continuous batching and in
|
||||
# chunked prefill modes; they are computed at top-level model forward
|
||||
# since they stay the same and reused for all mamba layers in the same
|
||||
# iteration.
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
|
||||
BCx, _ = self.in_proj(hidden_states)
|
||||
|
||||
B, C, x = BCx.chunk(3, dim=-1)
|
||||
|
||||
conv_weights = self.conv.weight.view(
|
||||
self.conv.weight.size(0), self.conv.weight.size(2)
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
Bx = (B * x).contiguous()
|
||||
hidden_states = C * Bx
|
||||
contextualized_states, _ = self.out_proj(hidden_states)
|
||||
return contextualized_states
|
||||
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_decodes + num_prefill_tokens
|
||||
|
||||
# NOTE: V1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
B_d, B_p = torch.split(
|
||||
B[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
C_d, C_p = torch.split(
|
||||
C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
x_d, x_p = torch.split(
|
||||
x[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1 :] - num_decodes
|
||||
if has_prefill
|
||||
else None
|
||||
)
|
||||
|
||||
conv_output_list = []
|
||||
|
||||
if has_prefill:
|
||||
Bx_p = (B_p * x_p).transpose(0, 1)
|
||||
Bx = causal_conv1d_fn(
|
||||
Bx_p,
|
||||
conv_weights,
|
||||
self.conv.bias,
|
||||
activation=None,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p,
|
||||
).transpose(0, 1)[:num_prefill_tokens]
|
||||
|
||||
y = C_p * Bx
|
||||
conv_output_list.append(y)
|
||||
|
||||
if has_decode:
|
||||
Bx_d = (B_d * x_d).contiguous()
|
||||
Bx = causal_conv1d_update(
|
||||
Bx_d,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv.bias,
|
||||
activation=None,
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
)
|
||||
y = C_d * Bx
|
||||
conv_output_list.insert(0, y)
|
||||
|
||||
# Merge prefill and decode outputs before passing to gated MLP
|
||||
hidden_states = torch.vstack(conv_output_list)
|
||||
|
||||
# Final linear projection
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.short_conv_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.short_conv_state_shape(
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
intermediate_size=self.conv_dim,
|
||||
conv_kernel=self.L_cache,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "short_conv"
|
||||
|
||||
|
||||
def short_conv(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def short_conv_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="short_conv",
|
||||
op_func=short_conv,
|
||||
mutates_args=["output"],
|
||||
fake_impl=short_conv_fake,
|
||||
)
|
||||
Reference in New Issue
Block a user