first commit
This commit is contained in:
0
vllm/model_executor/layers/mamba/__init__.py
Normal file
0
vllm/model_executor/layers/mamba/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
42
vllm/model_executor/layers/mamba/abstract.py
Normal file
42
vllm/model_executor/layers/mamba/abstract.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
class MambaBase(AttentionLayerBase):
|
||||
"""
|
||||
Base class for Mamba-like layers which support the v1 engine.
|
||||
Inherit from this class if you implement a custom layer.
|
||||
"""
|
||||
|
||||
# Contains the KV cache (mamba state) for the layer
|
||||
# in the shape specified by `self.get_state_shape`.
|
||||
kv_cache: tuple[torch.Tensor, ...]
|
||||
|
||||
@abstractmethod
|
||||
def get_state_shape(self) -> Iterable[tuple[int, ...]]:
|
||||
"""
|
||||
Defines the shape of the state.
|
||||
For mamba layers this is usually a (conv_state, ssm_state) tuple.
|
||||
In this case, returns (conv_state_shape, ssm_state_shape).
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def mamba_type(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
"""Get the attention backend class for this Mamba layer."""
|
||||
pass
|
||||
403
vllm/model_executor/layers/mamba/linear_attn.py
Normal file
403
vllm/model_executor/layers/mamba/linear_attn.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.lightning_attn import (
|
||||
lightning_attention, linear_decode_forward_triton)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
|
||||
class MiniMaxText01RMSNormTP(CustomOp):
|
||||
name = "MiniMaxText01RMSNormTP"
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.tp_world = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.weight = nn.Parameter(torch.ones(int(hidden_size /
|
||||
self.tp_world)))
|
||||
|
||||
self.weight.weight_loader = self.weight_loader
|
||||
self.variance_epsilon = eps
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def weight_loader(
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
) -> None:
|
||||
tp_world = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = loaded_weight.shape[0] // tp_world
|
||||
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
return
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.to(torch.float32)
|
||||
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
|
||||
if self.tp_world > 1:
|
||||
variance = tensor_model_parallel_all_reduce(
|
||||
variance) / self.tp_world
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x.to(orig_dtype) * self.weight
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert residual is None, "RMSNorm does not support residual connection."
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class MiniMaxText01LinearKernel:
|
||||
|
||||
@staticmethod
|
||||
def jit_linear_forward_prefix(q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_caches: torch.Tensor,
|
||||
slope_rate: torch.Tensor,
|
||||
block_size: int,
|
||||
layer_idx: Optional[int] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
|
||||
slope_rate = slope_rate.to(torch.float32)
|
||||
should_pad_dim = q.dim() == 3
|
||||
if should_pad_dim:
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
b, h, n, d = q.shape
|
||||
e = d
|
||||
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
|
||||
output, kv_history = lightning_attention(q,
|
||||
k,
|
||||
v,
|
||||
slope_rate,
|
||||
block_size=block_size,
|
||||
kv_history=kv_history)
|
||||
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
|
||||
assert output.shape[0] == 1, "batch size must be 1"
|
||||
return rearrange(output.squeeze(0), "h n d -> n (h d)")
|
||||
|
||||
|
||||
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "linear_attention"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.linear_attn import (
|
||||
LinearAttentionBackend)
|
||||
return LinearAttentionBackend
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.linear_attention_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
|
||||
return MambaStateShapeCalculator.linear_attention_state_shape(
|
||||
num_heads=self.num_heads,
|
||||
tp_size=self.tp_size,
|
||||
head_dim=self.head_dim)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
hidden_inner_size: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
max_position: int,
|
||||
block_size: int,
|
||||
num_hidden_layer: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
layer_idx: int = 0,
|
||||
linear_layer_idx: int = 0,
|
||||
prefix: str = "linear_attn",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = layer_idx
|
||||
self.BLOCK = block_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.total_num_heads = num_heads
|
||||
self.hidden_inner_size = hidden_inner_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
assert self.total_num_heads % self.tp_size == 0
|
||||
self.tp_heads = self.total_num_heads // self.tp_size
|
||||
self.qkv_size = self.num_heads * self.head_dim
|
||||
self.tp_hidden = self.head_dim * self.tp_heads
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.hidden_inner_size * 3,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.output_gate = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
self.hidden_inner_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.output_gate",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
self.hidden_inner_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
self.norm = MiniMaxText01RMSNormTP(
|
||||
self.hidden_inner_size,
|
||||
eps=1e-5,
|
||||
)
|
||||
|
||||
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
|
||||
self.num_heads)
|
||||
if num_hidden_layer <= 1:
|
||||
self.slope_rate = slope_rate * (1 + 1e-5)
|
||||
else:
|
||||
self.slope_rate = slope_rate * (1 - layer_idx /
|
||||
(num_hidden_layer - 1) + 1e-5)
|
||||
self.tp_slope = self.slope_rate[self.tp_rank *
|
||||
self.tp_heads:(self.tp_rank + 1) *
|
||||
self.tp_heads].contiguous()
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
@staticmethod
|
||||
def weight_direct_load(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
assert param.size() == loaded_weight.size()
|
||||
param.data.copy_(loaded_weight)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def _build_slope_tensor(n_attention_heads: int):
|
||||
|
||||
def get_slopes(n):
|
||||
|
||||
def get_slopes_power_of_2(n):
|
||||
start = 2**(-(2**-(math.log2(n) - 3)))
|
||||
ratio = start
|
||||
return [start * ratio**i for i in range(n)]
|
||||
|
||||
if math.log2(n).is_integer():
|
||||
return get_slopes_power_of_2(n)
|
||||
else:
|
||||
closest_power_of_2 = 2**math.floor(math.log2(n))
|
||||
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
|
||||
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
|
||||
|
||||
slopes = torch.tensor(get_slopes(n_attention_heads),
|
||||
dtype=torch.float32).reshape(
|
||||
n_attention_heads, 1, 1)
|
||||
return slopes
|
||||
|
||||
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||
attn_metadata):
|
||||
hidden = []
|
||||
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
|
||||
if _prefill_idx >= len(attn_metadata.query_start_loc):
|
||||
break
|
||||
if _prefill_idx >= len(state_indices_tensor):
|
||||
break
|
||||
offset = attn_metadata.num_decode_tokens
|
||||
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
|
||||
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
|
||||
slot_id = state_indices_tensor[offset + _prefill_idx]
|
||||
qs = q[_start:_end].transpose(0, 1).contiguous()
|
||||
ks = k[_start:_end].transpose(0, 1).contiguous()
|
||||
vs = v[_start:_end].transpose(0, 1).contiguous()
|
||||
slice_layer_cache = kv_cache[slot_id, ...]
|
||||
|
||||
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
|
||||
qs,
|
||||
ks,
|
||||
vs,
|
||||
slice_layer_cache,
|
||||
self.tp_slope,
|
||||
self.BLOCK,
|
||||
layer_idx=self.layer_idx)
|
||||
hidden.append(out_slice.contiguous())
|
||||
if attn_metadata.num_decode_tokens > 0:
|
||||
hidden_decode = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
hidden.insert(0, hidden_decode)
|
||||
|
||||
if not hidden:
|
||||
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
|
||||
|
||||
hidden = torch.concat(hidden, dim=0).contiguous()
|
||||
return hidden
|
||||
|
||||
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
|
||||
attn_metadata):
|
||||
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
|
||||
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
|
||||
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
|
||||
slot_id, 32)
|
||||
return hidden
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor) -> None:
|
||||
torch.ops.vllm.linear_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
positions,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
num_actual_tokens = attn_metadata.num_prefill_tokens + \
|
||||
attn_metadata.num_decode_tokens
|
||||
else:
|
||||
num_actual_tokens = hidden_states.shape[0]
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
|
||||
qkv32 = qkv.to(torch.float32)
|
||||
qkvact = torch.nn.functional.silu(qkv32)
|
||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||
if attn_metadata is not None:
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
num_prefills = getattr(attn_metadata, "num_prefills", 0)
|
||||
if num_prefills > 0:
|
||||
num_decode_tokens = getattr(attn_metadata, "num_decode_tokens",
|
||||
0)
|
||||
for prefill_idx in range(num_prefills):
|
||||
q_start = attn_metadata.query_start_loc[num_decode_tokens +
|
||||
prefill_idx]
|
||||
q_end = attn_metadata.query_start_loc[num_decode_tokens +
|
||||
prefill_idx + 1]
|
||||
query_len = q_end - q_start
|
||||
context_len = attn_metadata.seq_lens[
|
||||
num_decode_tokens + prefill_idx] - query_len
|
||||
if context_len == 0:
|
||||
block_to_clear = state_indices_tensor[num_decode_tokens
|
||||
+ prefill_idx]
|
||||
kv_cache[block_to_clear, ...] = 0
|
||||
|
||||
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
|
||||
if attn_metadata is None:
|
||||
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
|
||||
device=q.device,
|
||||
dtype=q.dtype)
|
||||
else:
|
||||
if not decode_only:
|
||||
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
else:
|
||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
hidden = self.norm._forward(hidden)
|
||||
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
|
||||
hidden = F.sigmoid(gate) * hidden
|
||||
hidden = hidden.to(hidden_states.dtype)
|
||||
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden)
|
||||
|
||||
|
||||
def linear_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self._forward(hidden_states=hidden_states,
|
||||
output=output,
|
||||
positions=positions)
|
||||
|
||||
|
||||
def linear_attention_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="linear_attention",
|
||||
op_func=linear_attention,
|
||||
mutates_args=["output"],
|
||||
fake_impl=linear_attention_fake,
|
||||
)
|
||||
466
vllm/model_executor/layers/mamba/mamba_mixer.py
Normal file
466
vllm/model_executor/layers/mamba/mamba_mixer.py
Normal file
@@ -0,0 +1,466 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, NamedTuple, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_scan_fn, selective_state_update)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
@CustomOp.register("mamba_mixer")
|
||||
class MambaMixer(MambaBase, CustomOp):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||
for why A isn't selective) ∆, B, C are input-dependent
|
||||
(this is a key difference between Mamba and the linear time
|
||||
invariant S4, and is why Mamba is called
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
time_step_rank: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
use_rms_norm: bool,
|
||||
rms_norm_has_weight: bool = True,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation="silu",
|
||||
is_lora_enabled: bool = False,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.time_step_rank = time_step_rank
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.use_rms_norm = use_rms_norm
|
||||
self.activation = activation
|
||||
self.is_lora_enabled = is_lora_enabled
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.intermediate_size = intermediate_size
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=intermediate_size,
|
||||
bias=use_conv_bias,
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=use_bias)
|
||||
|
||||
# selective projection used to make dt, B and C input dependent
|
||||
self.x_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
time_step_rank + ssm_state_size * 2,
|
||||
bias=False,
|
||||
)
|
||||
# time step projection (discretization) -
|
||||
# In the forward we need to apply dt_proj without the bias,
|
||||
# as the bias is added in the selective scan kernel.
|
||||
self.dt_proj = ColumnParallelLinear(time_step_rank,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
skip_bias_add=True)
|
||||
|
||||
def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
param.data.copy_(
|
||||
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
|
||||
dim=0)[tp_rank])
|
||||
|
||||
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
|
||||
weight_loader(param, -torch.exp(loaded_weight.float()))
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
intermediate_size // tp_size,
|
||||
ssm_state_size,
|
||||
dtype=torch.float32,
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": weight_loader})
|
||||
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
self.dt_layernorm = RMSNorm(
|
||||
time_step_rank,
|
||||
eps=rms_norm_eps,
|
||||
has_weight=rms_norm_has_weight,
|
||||
) if use_rms_norm else None
|
||||
|
||||
self.b_layernorm = RMSNorm(
|
||||
ssm_state_size,
|
||||
eps=rms_norm_eps,
|
||||
has_weight=rms_norm_has_weight,
|
||||
) if use_rms_norm else None
|
||||
|
||||
self.c_layernorm = RMSNorm(
|
||||
ssm_state_size,
|
||||
eps=rms_norm_eps,
|
||||
has_weight=rms_norm_has_weight,
|
||||
) if use_rms_norm else None
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The inner tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
def _ssm_transform(
|
||||
self, x: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if self.is_lora_enabled:
|
||||
# Lora kernel requires contiguous tensor.
|
||||
ssm_params = self.x_proj(x.contiguous())[0]
|
||||
else:
|
||||
ssm_params = self.x_proj(x)[0]
|
||||
time_step, B, C = torch.split(
|
||||
ssm_params,
|
||||
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
|
||||
dim=-1)
|
||||
if self.use_rms_norm:
|
||||
assert self.dt_layernorm is not None
|
||||
assert self.b_layernorm is not None
|
||||
assert self.c_layernorm is not None
|
||||
time_step = self.dt_layernorm(time_step.contiguous())
|
||||
B = self.b_layernorm(B.contiguous())
|
||||
C = self.c_layernorm(C.contiguous())
|
||||
discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
|
||||
return discrete_time_step, B, C
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
torch.ops.vllm.mamba_mixer(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_native(self, hidden_states: torch.Tensor,
|
||||
output: torch.Tensor):
|
||||
pass
|
||||
|
||||
def forward_cuda(self, hidden_states: torch.Tensor, output: torch.Tensor):
|
||||
"""
|
||||
Run the Mamba-1 SSM pipeline.
|
||||
|
||||
Steps
|
||||
-----
|
||||
1. Apply the gated-MLP linear projection to the raw input.
|
||||
2. Pass the projected sequence through the convolutional mixing layer.
|
||||
3. Feed the result into the State-Space Model (SSM) blocks.
|
||||
4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||
to produce contextual representations.
|
||||
5. Project the contextualised sequence back
|
||||
to the output embedding dimension.
|
||||
|
||||
Batch handling
|
||||
--------------
|
||||
Prefill and decode tokens are processed by dedicated CUDA
|
||||
kernels for both the convolutional (conv1d) and SSM stages.
|
||||
In the case of a mixed batch (containing both prefill and
|
||||
decode tokens), both sets of kernels are executed independently
|
||||
and their outputs are concatenated before the final output projection.
|
||||
"""
|
||||
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
mamba1_metadata = attn_metadata
|
||||
assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
|
||||
query_start_loc = mamba1_metadata.query_start_loc
|
||||
state_indices_tensor = mamba1_metadata.state_indices_tensor
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
has_initial_states = mamba1_metadata.has_initial_states
|
||||
num_padded_decodes = mamba1_metadata.num_padded_decodes
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
|
||||
hidden_states_BC, gate = projected_states.chunk(2, dim=-2)
|
||||
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
hidden_states_BC = hidden_states_BC.contiguous()
|
||||
return self.out_proj(hidden_states_BC.transpose(-2, -1))[0]
|
||||
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
num_decode_tokens = attn_metadata.num_decode_tokens
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
has_prefill = num_prefill_tokens > 0
|
||||
has_decode = num_decode_tokens > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decode_tokens
|
||||
|
||||
prefill_decode_split = split_batch_to_prefill_and_decode(
|
||||
hidden_states_BC,
|
||||
gate,
|
||||
state_indices_tensor,
|
||||
query_start_loc,
|
||||
has_initial_states,
|
||||
num_prefill_tokens,
|
||||
num_decode_tokens,
|
||||
num_prefills,
|
||||
num_decodes,
|
||||
num_padded_decodes,
|
||||
)
|
||||
hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
|
||||
hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
|
||||
gate_p = prefill_decode_split.gate_p
|
||||
gate_d = prefill_decode_split.gate_d
|
||||
state_indices_tensor_p = prefill_decode_split.state_indices_tensor_p
|
||||
state_indices_tensor_d = prefill_decode_split.state_indices_tensor_d
|
||||
query_start_loc_p = prefill_decode_split.query_start_loc_p
|
||||
has_initial_states_p = prefill_decode_split.has_initial_states_p
|
||||
|
||||
ssm_outputs = []
|
||||
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
conv_out_p = causal_conv1d_fn(
|
||||
hidden_states_BC_p,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
query_start_loc=query_start_loc_p)
|
||||
# 3. State Space Model sequence transformations.
|
||||
discrete_time_step_p, B_p, C_p = self._ssm_transform(
|
||||
conv_out_p.transpose(-2, -1))
|
||||
time_proj_bias = self._time_proj_bias()
|
||||
|
||||
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||
scan_out_p = selective_scan_fn(
|
||||
conv_out_p,
|
||||
ssm_state,
|
||||
discrete_time_step_p,
|
||||
self.A,
|
||||
B_p.transpose(-2, -1),
|
||||
C_p.transpose(-2, -1),
|
||||
self.D.float(),
|
||||
gate_p,
|
||||
time_proj_bias,
|
||||
delta_softplus=True,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
has_initial_state=has_initial_states_p,
|
||||
query_start_loc=query_start_loc_p)
|
||||
ssm_outputs.append(scan_out_p)
|
||||
|
||||
if has_decode:
|
||||
# 2. Convolution sequence transformation
|
||||
conv_out_d = causal_conv1d_update(
|
||||
hidden_states_BC_d.transpose(0, 1),
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d).transpose(0, 1)
|
||||
|
||||
# 3. State Space Model sequence transformation.
|
||||
discrete_time_step_d, B_d, C_d = self._ssm_transform(
|
||||
conv_out_d.transpose(-2, -1))
|
||||
time_proj_bias = self._time_proj_bias()
|
||||
|
||||
# 4. Perform the recurrence y ← SSM(A, B, C, Δ)(x)
|
||||
scan_outputs_d = torch.empty_like(
|
||||
hidden_states_BC_d.transpose(0, 1))
|
||||
selective_state_update(ssm_state,
|
||||
conv_out_d.transpose(0, 1),
|
||||
discrete_time_step_d.transpose(0, 1),
|
||||
self.A,
|
||||
B_d,
|
||||
C_d,
|
||||
self.D,
|
||||
gate_d.transpose(0, 1),
|
||||
time_proj_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
out=scan_outputs_d)
|
||||
scan_outputs_d = scan_outputs_d.transpose(0, 1)
|
||||
|
||||
ssm_outputs.insert(0, scan_outputs_d)
|
||||
|
||||
scan_outputs_combined = ssm_outputs[0] if len(
|
||||
ssm_outputs) == 1 else torch.cat(ssm_outputs, dim=-1)
|
||||
|
||||
# 5. Final output projection
|
||||
if self.is_lora_enabled: # Lora kernel requires contiguous tensor.
|
||||
scan_outputs_combined = scan_outputs_combined.transpose(
|
||||
-2, -1).contiguous()
|
||||
out = self.out_proj(scan_outputs_combined)[0]
|
||||
else:
|
||||
out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
|
||||
|
||||
output[:num_actual_tokens] = out
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.mamba1_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
self.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.mamba1_state_shape(
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
intermediate_size=self.intermediate_size,
|
||||
state_size=self.ssm_state_size,
|
||||
conv_kernel=self.conv_kernel_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "mamba1"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.mamba1_attn import (
|
||||
Mamba1AttentionBackend)
|
||||
return Mamba1AttentionBackend
|
||||
|
||||
def _time_proj_bias(self) -> Optional[torch.Tensor]:
|
||||
if hasattr(self.dt_proj, "bias") and self.dt_proj.bias is not None:
|
||||
return self.dt_proj.bias.float()
|
||||
return None
|
||||
|
||||
|
||||
class PrefillDecodeSplit(NamedTuple):
|
||||
hidden_states_BC_p: torch.Tensor
|
||||
hidden_states_BC_d: torch.Tensor
|
||||
gate_p: torch.Tensor
|
||||
gate_d: torch.Tensor
|
||||
state_indices_tensor_p: torch.Tensor
|
||||
state_indices_tensor_d: torch.Tensor
|
||||
query_start_loc_p: Optional[torch.Tensor]
|
||||
has_initial_states_p: Optional[torch.Tensor]
|
||||
|
||||
|
||||
def split_batch_to_prefill_and_decode(
|
||||
hidden_states_BC: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
state_indices_tensor: torch.Tensor,
|
||||
query_start_loc: torch.Tensor,
|
||||
has_initial_states: Optional[torch.Tensor],
|
||||
num_prefill_tokens: int,
|
||||
num_decode_tokens: int,
|
||||
num_prefills: int,
|
||||
num_decodes: int,
|
||||
num_padded_decodes: int,
|
||||
) -> PrefillDecodeSplit:
|
||||
|
||||
num_actual_tokens = num_prefill_tokens + num_padded_decodes
|
||||
|
||||
# In v1, decode tokens come first, then prefill tokens.
|
||||
hidden_states_BC_d, hidden_states_BC_p = torch.split(
|
||||
hidden_states_BC[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
|
||||
[num_padded_decodes, num_prefill_tokens],
|
||||
dim=-1)
|
||||
|
||||
# num_padded_decodes accounts for CUDA graph padding when applicable
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_padded_decodes + num_prefills],
|
||||
[num_padded_decodes, num_prefills],
|
||||
dim=0)
|
||||
query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
|
||||
num_padded_decodes if num_prefills > 0 else None)
|
||||
has_initial_states_p = has_initial_states[-num_prefills:] if (
|
||||
has_initial_states is not None and num_prefills > 0) else None
|
||||
|
||||
return PrefillDecodeSplit(
|
||||
hidden_states_BC_p=hidden_states_BC_p,
|
||||
hidden_states_BC_d=hidden_states_BC_d,
|
||||
gate_p=gate_p,
|
||||
gate_d=gate_d,
|
||||
state_indices_tensor_p=state_indices_tensor_p,
|
||||
state_indices_tensor_d=state_indices_tensor_d,
|
||||
query_start_loc_p=query_start_loc_p,
|
||||
has_initial_states_p=has_initial_states_p,
|
||||
)
|
||||
|
||||
|
||||
def mamba_mixer(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def mamba_mixer_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mamba_mixer",
|
||||
op_func=mamba_mixer,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer_fake,
|
||||
)
|
||||
764
vllm/model_executor/layers/mamba/mamba_mixer2.py
Normal file
764
vllm/model_executor/layers/mamba/mamba_mixer2.py
Normal file
@@ -0,0 +1,764 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
|
||||
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
|
||||
selective_state_update)
|
||||
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
|
||||
mamba_chunk_scan_combined_varlen)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
LoaderFunction, composed_weight_loader, sharded_weight_loader)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||
|
||||
# Added by the IBM Team, 2024
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba2.modeling_mamba2.MambaRMSNormGated
|
||||
@CustomOp.register("mixer2_gated_rms_norm")
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
|
||||
def __init__(self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.full_hidden_size = full_hidden_size
|
||||
self.group_size = full_hidden_size // full_n_groups
|
||||
self.per_rank_hidden_size = full_hidden_size // self.tp_size
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert (self.full_hidden_size % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide hidden size."
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
):
|
||||
# Three tensor-parallel cases:
|
||||
# 1. n_groups is 1
|
||||
# In this case we parallelize along the reduction dim.
|
||||
# Each rank computes a local sum of squares followed by AllReduce
|
||||
# 2. tp_size divides n_groups
|
||||
# Each rank only reduces within its local group(s).
|
||||
# No collective ops necessary.
|
||||
# 3. The general case can be pretty complicated so we AllGather
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x.to(input_dtype)
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
# Compute local sum and then reduce to obtain global sum
|
||||
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
else:
|
||||
redundant_tp: bool = self.n_groups % self.tp_size != 0
|
||||
if redundant_tp:
|
||||
# To handle the general case, redundantly apply the variance
|
||||
x = tensor_model_parallel_all_gather(x, -1)
|
||||
|
||||
*prefix_dims, hidden_dim = x.shape
|
||||
group_count = hidden_dim // self.group_size
|
||||
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
|
||||
variance = x_grouped.pow(2).mean(-1, keepdim=True)
|
||||
x_grouped = x_grouped * torch.rsqrt(variance +
|
||||
self.variance_epsilon)
|
||||
x = x_grouped.view(*prefix_dims, hidden_dim)
|
||||
|
||||
if redundant_tp:
|
||||
start = self.per_rank_hidden_size * self.tp_rank
|
||||
end = start + self.per_rank_hidden_size
|
||||
x = x[..., start:end]
|
||||
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
return x * nn.functional.silu(gate.to(
|
||||
torch.float32)).to(input_dtype)
|
||||
|
||||
if (((self.n_groups % self.tp_size) != 0) or self.n_groups != 1):
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
return rms_norm_gated(x,
|
||||
self.weight.data,
|
||||
bias=None,
|
||||
z=gate,
|
||||
eps=self.variance_epsilon,
|
||||
norm_before_gate=False)
|
||||
|
||||
|
||||
def mamba_v2_sharded_weight_loader(
|
||||
shard_spec: list[tuple[int, int, float]],
|
||||
tp_size: int,
|
||||
tp_rank: int,
|
||||
) -> LoaderFunction:
|
||||
"""Create a weight loader for mamba v2. This ensures that the projections
|
||||
are correctly sharded so that they can be split into x, B, C. It also
|
||||
ensures that all the groups corresponding to a head shard is placed
|
||||
together with it.
|
||||
"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
|
||||
# - track boundary of (sharded) param, and loaded_weight, respectively
|
||||
boundary, loaded_boundary = 0, 0
|
||||
|
||||
# - iterate over the shard specs
|
||||
for full_dim, extra, duplicate_groups in shard_spec:
|
||||
# - full dim is the model dim (before TP).
|
||||
# - extra > 0, means there is expected overall increase
|
||||
# of dimensions. This is so because of replication.
|
||||
# - ratio is used map the tp_rank to the actual shard
|
||||
# rank. This is useful when there is replication of
|
||||
# groups to accompany head shards.
|
||||
|
||||
# - size of the loaded shard
|
||||
shard_size = full_dim // tp_size
|
||||
|
||||
# - compute the rank into the loaded shard.
|
||||
# - if there is replication, different TP shards will
|
||||
# take from the same rank.
|
||||
# NOTE: currently we only support duplication
|
||||
# in the case where num_groups == 1
|
||||
rank = 0 if duplicate_groups else tp_rank
|
||||
|
||||
# - leftmost boundary index into loaded weight.
|
||||
loaded_skip = rank * shard_size
|
||||
loaded_start_idx = loaded_boundary + loaded_skip
|
||||
|
||||
# - take these many dims from the loaded weight.
|
||||
take = min(shard_size, full_dim - extra - loaded_skip)
|
||||
|
||||
# - always shard on dim 0
|
||||
# - the ignore is for a mundane mypy error as it does not
|
||||
# seem to handle slices well.
|
||||
# https://github.com/python/mypy/issues/2410
|
||||
param.data[
|
||||
boundary:(boundary + take),
|
||||
... # type: ignore[misc]
|
||||
] = loaded_weight[loaded_start_idx:(loaded_start_idx +
|
||||
take) # type: ignore[misc]
|
||||
] # type: ignore[misc]
|
||||
|
||||
# move indexing boundaries
|
||||
boundary += shard_size
|
||||
loaded_boundary += full_dim - extra
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
|
||||
@CustomOp.register("mamba_mixer2")
|
||||
class MambaMixer2(MambaBase, CustomOp):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||
for why A isn't selective) ∆, B, C are input-dependent
|
||||
(this is a key difference between Mamba and the linear time
|
||||
invariant S4, and is why Mamba is called
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
|
||||
# For TP, the sharding plan is as follows:
|
||||
# - for the conv modules, since
|
||||
# conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
|
||||
# we shard intermediate_size and n_groups
|
||||
# - since intermediate_size = n_heads * head_dim, sharding on
|
||||
# intermediate_size is achieved by sharding on n_heads.
|
||||
# - IF, world_size divides groups, then sharding
|
||||
# (n_groups / world_size, n_heads / world_size)
|
||||
# also maintains the invariant n_heads % n_groups == 0
|
||||
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
|
||||
# to allocate extra space in the shard, such that groups
|
||||
# may be replicated to follow the head shard.
|
||||
# - NOTE: currently for the world size DOES NOT divide groups
|
||||
# case, we only support the case when n_groups == 1
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
assert (num_heads % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide num heads."
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
|
||||
"If tensor parallel world size does not divide num_groups, "
|
||||
"then num_groups must equal 1.")
|
||||
|
||||
assert (n_groups % self.tp_size == 0) or self.tp_size == 1 or \
|
||||
quant_config is None, (
|
||||
"Tensor parallel currently supported for quantized models only "
|
||||
"if tensor parallel world size divides num groups."
|
||||
)
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.activation = activation
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.n_groups = n_groups
|
||||
if n_groups % self.tp_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size)
|
||||
self.n_groups = n_groups + groups
|
||||
|
||||
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
|
||||
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
|
||||
|
||||
if n_groups % self.tp_size == 0:
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
else:
|
||||
# This is the n_groups == 1 case,
|
||||
# where we need to duplicate groups if TP>1.
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
# - use the custom weight loader mamba_v2_sharded_weight_loader
|
||||
# for conv1d.bias, covn1d.weight and in_proj.weight
|
||||
# - need to set these settings, to assign the groups
|
||||
# to the head shards
|
||||
group_shard_settings = (
|
||||
self.groups_ssm_state_size, # expected model size
|
||||
(self.n_groups - n_groups) *
|
||||
self.ssm_state_size, # extra dims assigned
|
||||
n_groups == 1, # if there was only one group
|
||||
)
|
||||
intermediate_settings = (intermediate_size, 0, False)
|
||||
head_settings = (self.num_heads, 0, False)
|
||||
|
||||
# - the weight already has a "weight_loader" attribute
|
||||
# which set_weight_attrs will raise if we do not
|
||||
# delete before trying to override it
|
||||
# - ditto for the other two weights below
|
||||
delattr(self.conv1d.bias, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.bias,
|
||||
{
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight,
|
||||
{
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
# - quant layers do not have a weight loader
|
||||
delattr(self.in_proj.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.in_proj.weight,
|
||||
{
|
||||
"weight_loader":
|
||||
mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings, # for gate
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
head_settings, # for dt
|
||||
],
|
||||
self.tp_size,
|
||||
tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
|
||||
# and `set_weight_attrs` doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
# - these are TPed by heads to reduce the size of the
|
||||
# temporal shape
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
divide(num_heads, self.tp_size),
|
||||
dtype=torch.float32,
|
||||
))
|
||||
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.use_rms_norm = use_rms_norm
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||
a_weight_loader = composed_weight_loader(
|
||||
sharded_weight_loader(0), lambda x: -torch.exp(x.float()))
|
||||
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
||||
set_weight_attrs(self.dt_bias,
|
||||
{"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
self.norm = Mixer2RMSNormGated(intermediate_size,
|
||||
n_groups,
|
||||
self.use_rms_norm,
|
||||
eps=rms_norm_eps)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
torch.ops.vllm.mamba_mixer2(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
mup_vector,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# attn_metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, Mamba2AttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
# conv_state = (..., dim, width-1) yet contiguous along 'dim'
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
ssm_state = self_kv_cache[1]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states_p
|
||||
prep_initial_states = attn_metadata.prep_initial_states
|
||||
chunk_size = attn_metadata.chunk_size
|
||||
seq_idx_p = attn_metadata.seq_idx_p
|
||||
chunk_indices_p = attn_metadata.chunk_indices_p
|
||||
chunk_offsets_p = attn_metadata.chunk_offsets_p
|
||||
query_start_loc_p = attn_metadata.query_start_loc_p
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
|
||||
if mup_vector is not None:
|
||||
projected_states = projected_states * mup_vector
|
||||
|
||||
gate, hidden_states_B_C, dt = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.conv_dim // self.tp_size,
|
||||
self.num_heads // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
|
||||
self.conv1d.weight.size(2))
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
if attn_metadata is None:
|
||||
# profile run
|
||||
hidden_states_B_C = (hidden_states_B_C.transpose(
|
||||
0, 1).clone().transpose(0, 1)).contiguous()
|
||||
hidden_states, _B, _C = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C)
|
||||
hidden_states = self.norm(hidden_states, gate)
|
||||
out, _ = self.out_proj(hidden_states)
|
||||
return out
|
||||
|
||||
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decodes
|
||||
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
|
||||
hidden_states_B_C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
dt_d, dt_p = torch.split(
|
||||
dt[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor[:num_actual_tokens],
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
preallocated_ssm_out = torch.empty(
|
||||
[
|
||||
num_prefill_tokens + num_decodes,
|
||||
(self.num_heads // self.tp_size) * self.head_dim
|
||||
],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
preallocated_ssm_out_d, preallocated_ssm_out_p = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if has_prefill:
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1) # this is the form that causal-conv see
|
||||
hidden_states_B_C_p = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_p)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
if (has_initial_states_p is not None and prep_initial_states):
|
||||
initial_states = torch.where(
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p], 0)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_states = mamba_chunk_scan_combined_varlen(
|
||||
hidden_states_p.view(num_prefill_tokens,
|
||||
self.num_heads // self.tp_size,
|
||||
self.head_dim),
|
||||
dt_p,
|
||||
self.A,
|
||||
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
|
||||
-1),
|
||||
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
|
||||
-1),
|
||||
chunk_size=chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
seq_idx=seq_idx_p,
|
||||
chunk_indices=chunk_indices_p,
|
||||
chunk_offsets=chunk_offsets_p,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
initial_states=initial_states,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
|
||||
self.head_dim),
|
||||
state_dtype=ssm_state.dtype)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
ssm_state[state_indices_tensor_p] = varlen_states
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C_d = causal_conv1d_update(
|
||||
hidden_states_B_C_d,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor_d)
|
||||
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(
|
||||
hidden_states_B_C_d)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
A_d = self.A[:, None, ...][:, :, None].expand(
|
||||
-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
|
||||
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
|
||||
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
|
||||
hidden_states_d = hidden_states_d.view(
|
||||
-1, self.num_heads // self.tp_size, self.head_dim)
|
||||
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# using state_indices_tensor_d
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
selective_state_update(
|
||||
ssm_state,
|
||||
hidden_states_d,
|
||||
dt_d,
|
||||
A_d,
|
||||
B_d,
|
||||
C_d,
|
||||
D_d,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1,
|
||||
self.head_dim),
|
||||
)
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(preallocated_ssm_out,
|
||||
gate[:num_actual_tokens])
|
||||
|
||||
# 5. Final linear projection
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.mamba2_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
self.cache_config.mamba_ssm_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.mamba2_state_shape(
|
||||
intermediate_size=self.intermediate_size,
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
n_groups=self.n_groups,
|
||||
num_heads=self.num_heads,
|
||||
head_dim=self.head_dim,
|
||||
state_size=self.ssm_state_size,
|
||||
conv_kernel=self.conv_kernel_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "mamba2"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.mamba2_attn import (
|
||||
Mamba2AttentionBackend)
|
||||
return Mamba2AttentionBackend
|
||||
|
||||
|
||||
def mamba_mixer2(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states,
|
||||
output=output,
|
||||
mup_vector=mup_vector)
|
||||
|
||||
|
||||
def mamba_mixer2_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="mamba_mixer2",
|
||||
op_func=mamba_mixer2,
|
||||
mutates_args=["output"],
|
||||
fake_impl=mamba_mixer2_fake,
|
||||
)
|
||||
186
vllm/model_executor/layers/mamba/mamba_utils.py
Normal file
186
vllm/model_executor/layers/mamba/mamba_utils.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import MambaDType, ModelDType
|
||||
from vllm.distributed import divide
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_kv_cache_torch_dtype
|
||||
|
||||
|
||||
class MambaStateDtypeCalculator:
|
||||
|
||||
@classmethod
|
||||
def linear_attention_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
# TODO (tdoublep) requires testing
|
||||
if mamba_cache_dtype == "float32":
|
||||
raise ValueError("fp32 state for minimax is not yet supported")
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, )
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
|
||||
mamba_ssm_cache_dtype)
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
return cls._mamba_state_dtype(model_dtype, mamba_cache_dtype,
|
||||
mamba_ssm_cache_dtype)
|
||||
|
||||
@classmethod
|
||||
def _mamba_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
mamba_ssm_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
|
||||
model_dtype)
|
||||
if mamba_ssm_cache_dtype == "auto":
|
||||
temporal_state_dtype = conv_state_dtype
|
||||
else:
|
||||
temporal_state_dtype = (
|
||||
STR_DTYPE_TO_TORCH_DTYPE[mamba_ssm_cache_dtype])
|
||||
|
||||
return (conv_state_dtype, temporal_state_dtype)
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, ...]:
|
||||
conv_state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype,
|
||||
model_dtype)
|
||||
return (conv_state_dtype, )
|
||||
|
||||
@classmethod
|
||||
def gated_delta_net_state_dtype(
|
||||
cls,
|
||||
model_dtype: Union[ModelDType, torch.dtype],
|
||||
mamba_cache_dtype: MambaDType,
|
||||
) -> tuple[torch.dtype, torch.dtype]:
|
||||
state_dtype = get_kv_cache_torch_dtype(mamba_cache_dtype, model_dtype)
|
||||
return (state_dtype, state_dtype)
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
|
||||
@classmethod
|
||||
def linear_attention_state_shape(
|
||||
cls,
|
||||
num_heads: int,
|
||||
tp_size: int,
|
||||
head_dim: int,
|
||||
) -> tuple[tuple[int, int, int], ...]:
|
||||
|
||||
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||
return (state_shape, )
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
conv_state_shape = (divide(intermediate_size,
|
||||
tp_world_size), conv_kernel - 1)
|
||||
|
||||
temporal_state_shape = (divide(intermediate_size,
|
||||
tp_world_size), state_size)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
n_groups: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = n_groups + cls.extra_groups_for_head_shards(
|
||||
n_groups, tp_world_size)
|
||||
# heads and n_groups are TP-ed
|
||||
conv_dim = intermediate_size + 2 * n_groups * state_size
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
||||
temporal_state_shape = (divide(num_heads,
|
||||
tp_world_size), head_dim, state_size)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
conv_dim = divide(intermediate_size, tp_world_size)
|
||||
conv_state_shape = (conv_kernel - 1, conv_dim)
|
||||
return (conv_state_shape, )
|
||||
|
||||
@classmethod
|
||||
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngoups % tp_size == 0, this will be zero
|
||||
if ngroups % tp_size == 0:
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
|
||||
@classmethod
|
||||
def gated_delta_net_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
num_k_heads: int,
|
||||
num_v_heads: int,
|
||||
head_k_dim: int,
|
||||
head_v_dim: int,
|
||||
conv_kernel_size: int,
|
||||
num_spec: int = 0,
|
||||
):
|
||||
conv_dim = (head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads)
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, tp_world_size),
|
||||
conv_kernel_size - 1 + num_spec,
|
||||
)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
temporal_state_shape = (divide(num_v_heads,
|
||||
tp_world_size), head_k_dim, head_v_dim)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
0
vllm/model_executor/layers/mamba/ops/__init__.py
Normal file
0
vllm/model_executor/layers/mamba/ops/__init__.py
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
1092
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
1092
vllm/model_executor/layers/mamba/ops/causal_conv1d.py
Normal file
File diff suppressed because it is too large
Load Diff
168
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
168
vllm/model_executor/layers/mamba/ops/layernorm_gated.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row: tl.int64,
|
||||
stride_y_row: tl.int64,
|
||||
stride_z_row: tl.int64,
|
||||
M: tl.int64, # number of rows in X
|
||||
N: tl.int64, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N, )
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N, )
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = torch.empty((ngroups * M, ), dtype=torch.float32,
|
||||
device=x.device) if not is_rms_norm else None
|
||||
rstd = torch.empty((ngroups * M, ), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError(
|
||||
"This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, _, _ = _layer_norm_fwd(x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=True)
|
||||
|
||||
return y.reshape(x_shape_og)
|
||||
414
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
414
vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Normal file
@@ -0,0 +1,414 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.utils import PAD_SLOT_ID
|
||||
from vllm.triton_utils import HAS_TRITON, tl, triton
|
||||
|
||||
TRITON3 = HAS_TRITON and (version.parse(triton.__version__)
|
||||
>= version.parse("3.0.0"))
|
||||
|
||||
if TRITON3:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics({
|
||||
"HAS_STATE_BATCH_INDICES":
|
||||
lambda args: args["state_batch_indices_ptr"] is not None
|
||||
})
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||
@triton.jit
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
A_ptr,
|
||||
B_ptr,
|
||||
C_ptr,
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch,
|
||||
stride_state_head,
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_bias_dim,
|
||||
stride_A_head,
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
|
||||
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += (state_batch_idx * stride_state_batch +
|
||||
pid_h * stride_state_head)
|
||||
else:
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim +
|
||||
offs_n[None, :] * stride_state_dstate)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim +
|
||||
offs_n[None, :] * stride_A_dstate)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0)
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptrs,
|
||||
mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= (state_batch_idx != pad_slot_id)
|
||||
tl.store(state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
def selective_state_update(state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
"""
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch, )
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
(0, 0, 0))
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else
|
||||
((16, 4) if dstate <= 32 else
|
||||
((8, 4) if dstate <= 64 else
|
||||
((4, 4) if dstate <= 128 else ((4, 8))))))
|
||||
tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride(
|
||||
-1) == 0 and dt_bias.stride(-1) == 0
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
dt_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
pad_slot_id,
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads // ngroups,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(2),
|
||||
state.stride(3),
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
*(dt_bias.stride(0),
|
||||
dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
def selective_scan_fn(u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
applies changes in place.
|
||||
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
applies changes in place.
|
||||
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
dt_bias: (dim,) or (dim)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended with 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
A tensor with each cell is a correspondent
|
||||
input and output ssm_state index
|
||||
has_initial_state: (batch) bool
|
||||
A tensor populated with ones and zeros,
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
returns
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
supports inplace replacement
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3 and query_start_loc is None:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 2 and query_start_loc is not None:
|
||||
B = B.unsqueeze(0)
|
||||
if C.dim() == 3 and query_start_loc is None:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 2 and query_start_loc is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus,
|
||||
query_start_loc, cache_indices, has_initial_state,
|
||||
ssm_states, pad_slot_id)
|
||||
|
||||
if z is None:
|
||||
return delta # output written inplace to delta
|
||||
else:
|
||||
return z # output written inplace to z
|
||||
242
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
242
vllm/model_executor/layers/mamba/ops/ssd_bmm.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['chunk_size', 'K', 'IS_CAUSAL'],
|
||||
)
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
K: tl.constexpr,
|
||||
ngroups: tl.constexpr,
|
||||
stride_a_seqlen: tl.int64,
|
||||
stride_a_head: tl.int64,
|
||||
stride_ak: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_bk: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_outm: tl.int64,
|
||||
stride_outn: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_ch = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
a_ptr += pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head
|
||||
|
||||
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen +
|
||||
offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk +
|
||||
offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# compute a * b.T
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0).to(dot_dtype)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) &
|
||||
(offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
# Zero out the results that are not from the same request
|
||||
# in the varlen batch
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1)
|
||||
seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
|
||||
mask=offs_n < chunk_size_limit,
|
||||
other=-2)
|
||||
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
||||
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] +
|
||||
offs_n[None, :] * stride_outn)
|
||||
tl.store(out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) &
|
||||
(offs_n[None, :] < chunk_size))
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (seqlen, ngroups, k)
|
||||
b: (seqlen, ngroups, k)
|
||||
seq_idx: (seqlen,). out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
||||
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
||||
guaranteed to be correct.
|
||||
Return:
|
||||
out: (nchunks, ngroups, chunk_size, chunk_size)
|
||||
"""
|
||||
seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
assert seq_idx is not None
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
if a.stride(-1) != 1 and a.stride(0) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(0) != 1:
|
||||
b = b.contiguous()
|
||||
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty((nchunks, ngroups, chunk_size, chunk_size),
|
||||
device=a.device,
|
||||
dtype=out_dtype)
|
||||
dot_dtype = (tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else
|
||||
(tl.float16 if a.dtype == torch.float16
|
||||
or b.dtype == torch.float16 else tl.float32))
|
||||
grid = lambda META: (triton.cdiv(
|
||||
chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
chunk_size, META['BLOCK_SIZE_N']), nchunks * ngroups)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a_ptr=a,
|
||||
b_ptr=b,
|
||||
out_ptr=out,
|
||||
seq_idx_ptr=seq_idx,
|
||||
seqlen=seqlen,
|
||||
chunk_size=chunk_size,
|
||||
K=k,
|
||||
ngroups=ngroups,
|
||||
stride_a_seqlen=a.stride(0),
|
||||
stride_a_head=a.stride(1),
|
||||
stride_ak=a.stride(2),
|
||||
stride_b_seqlen=b.stride(0),
|
||||
stride_b_head=b.stride(1),
|
||||
stride_bk=b.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_outm=out.stride(-2),
|
||||
stride_outn=out.stride(-1),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
IS_CAUSAL=causal,
|
||||
dot_dtype=dot_dtype,
|
||||
)
|
||||
return out
|
||||
527
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
527
vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
Normal file
@@ -0,0 +1,527 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
cb_ptr,
|
||||
x_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
C_ptr,
|
||||
states_ptr,
|
||||
D_ptr,
|
||||
initstates_ptr,
|
||||
chunk_indices_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
# Matrix dimensions
|
||||
chunk_size: tl.constexpr,
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_cb_chunk: tl.int64,
|
||||
stride_cb_head: tl.int64,
|
||||
stride_cb_csize_m: tl.int64,
|
||||
stride_cb_csize_k: tl.constexpr,
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_z_seqlen: tl.int64,
|
||||
stride_z_head: tl.int64,
|
||||
stride_z_hdim: tl.constexpr,
|
||||
stride_out_seqlen: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_hdim: tl.constexpr,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
stride_C_seqlen: tl.int64,
|
||||
stride_C_head: tl.int64,
|
||||
stride_C_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
stride_D_head: tl.constexpr,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
D_HAS_HDIM: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
IS_TRITON_22: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
if not HAS_INITSTATES:
|
||||
c_idx = pid_c
|
||||
c_off = 0
|
||||
else:
|
||||
c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
cb_ptr += c_idx * stride_cb_chunk + (pid_h //
|
||||
nheads_ngroups_ratio) * stride_cb_head
|
||||
x_ptr += c_idx * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += c_idx * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += c_idx * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
C_ptr += c_idx * chunk_size * stride_C_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_C_head
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
prev_states_ptr = states_ptr + c_idx * stride_states_chunk + pid_h * stride_states_head
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
|
||||
|
||||
seq_idx_ptr += c_idx * chunk_size * stride_seq_idx_seqlen
|
||||
# - we only need seq_idx_prev to be aligned to chunk boundary
|
||||
seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen,
|
||||
mask=c_idx >= 1,
|
||||
other=0)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states, we only need seq_idx_m to point
|
||||
# what is the current seq_idx
|
||||
|
||||
# get current seq idx
|
||||
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
|
||||
seq_idx_m = tl.load(
|
||||
seq_idx_ptr +
|
||||
(pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, )
|
||||
|
||||
# - recall that in ssd_state_passing, for the case c_off == 0
|
||||
# i.e., the very first sequence, we made states_ptr hold its initial state
|
||||
# so this edge case is taken care of
|
||||
if ((c_off == 0) and (seq_idx_prev != seq_idx_m
|
||||
) # if a seq is changed exactly on boundary
|
||||
or (c_off > 0) # implies a new example (pseudo chunk)
|
||||
):
|
||||
|
||||
# - replace prev_states_ptr with init_states
|
||||
prev_states_ptr = initstates_ptr + seq_idx_m * stride_init_states_batch + pid_h * stride_init_states_head
|
||||
prev_states_hdim = stride_init_states_hdim # override strides
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize,
|
||||
mask=offs_m < chunk_size,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# - handle chunk state limit
|
||||
if HAS_INITSTATES:
|
||||
# have to split this if otherwise compilation will have problems
|
||||
dA_cs_m_boundary = 0.0
|
||||
|
||||
# get the c_idx for the next (logica) chunk
|
||||
c_idx_n = tl.load(
|
||||
chunk_indices_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=-1 # to trigger different chunk
|
||||
)
|
||||
|
||||
# - there are things to consider
|
||||
# A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
|
||||
# contribution of past states
|
||||
# B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
|
||||
# encroach into the next sequence, where c_off_n is the offset of the next
|
||||
# (logical) chunk.
|
||||
# An equivalent check for B is c_idx == c_idx_n, where there is repetition in
|
||||
# (logical) chunk indices.
|
||||
|
||||
if (c_idx == c_idx_n) or c_off > 0:
|
||||
|
||||
# get the next offset
|
||||
c_off_n = tl.load(chunk_offsets_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=chunk_size)
|
||||
|
||||
# in this case, adjust down the chunk_size_limit
|
||||
if c_idx == c_idx_n:
|
||||
chunk_size_limit = min(c_off_n, chunk_size_limit)
|
||||
|
||||
# get the cs at the offset boundary
|
||||
# - c_off == 0 is a passthrough
|
||||
# - We need dA_cs at the boundary, defined by c_off - no need
|
||||
# to increase pointer by pid_m (it is a constant offset,
|
||||
# i.e. the same for all blocks)
|
||||
dA_cs_m_boundary = tl.load(
|
||||
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
# - handle seq idx when HAS_INITSTATES==False
|
||||
seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# Without the if (pid_c > -1), with Triton 2.1.0, I get
|
||||
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
|
||||
# With Triton 2.2.0, this works
|
||||
if IS_TRITON_22 or c_idx > -1:
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
|
||||
C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen +
|
||||
offs_k_dstate[None, :] * stride_C_dstate)
|
||||
|
||||
prev_states_ptrs = prev_states_ptr + (
|
||||
offs_n[None, :] * prev_states_hdim +
|
||||
offs_k_dstate[:, None] * prev_states_dstate)
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
# - this is for continuous batching where there is no init states
|
||||
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
|
||||
else:
|
||||
# - if there is initstates, we will rely on prev_states, no zeroing
|
||||
# required.
|
||||
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
|
||||
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate),
|
||||
other=0.0)
|
||||
|
||||
prev_states = tl.load(prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
else:
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0)
|
||||
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate - k) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
|
||||
cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m +
|
||||
offs_k[None, :] * stride_cb_csize_k)
|
||||
x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen +
|
||||
offs_n[None, :] * stride_x_hdim)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
K_MAX = chunk_size_limit if not IS_CAUSAL else min(
|
||||
(pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
|
||||
for k in range(0, K_MAX, BLOCK_SIZE_K):
|
||||
cb = tl.load(cb_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size) &
|
||||
(offs_k[None, :] < chunk_size - k),
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs,
|
||||
mask=offs_k < chunk_size - k,
|
||||
other=0.0).to(tl.float32)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k,
|
||||
other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
mask = offs_m[:, None] >= k + offs_k[None, :]
|
||||
cb = tl.where(mask, cb, 0.0)
|
||||
cb = cb.to(x_ptr.dtype.element_ty)
|
||||
x = tl.load(x_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0)
|
||||
acc += tl.dot(cb, x)
|
||||
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
if HAS_D:
|
||||
if D_HAS_HDIM:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head + offs_n,
|
||||
mask=offs_n < hdim,
|
||||
other=0.0).to(tl.float32)
|
||||
else:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
||||
x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen +
|
||||
offs_n[None, :] * stride_x_hdim),
|
||||
mask=(offs_m[:, None] < chunk_size_limit) &
|
||||
(offs_n[None, :] < hdim),
|
||||
other=0.0).to(tl.float32)
|
||||
acc += x_residual * D
|
||||
|
||||
if HAS_Z:
|
||||
z_ptr += c_idx * chunk_size * stride_z_seqlen + pid_h * stride_z_head
|
||||
z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] +
|
||||
stride_z_hdim * offs_out_n[None, :])
|
||||
z = tl.load(z_ptrs,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) &
|
||||
(offs_out_n[None, :] < hdim),
|
||||
other=0.0).to(tl.float32)
|
||||
acc *= z * tl.sigmoid(z)
|
||||
|
||||
out_ptr += c_idx * chunk_size * stride_out_seqlen + pid_h * stride_out_head
|
||||
out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] +
|
||||
offs_out_n[None, :] * stride_out_hdim)
|
||||
tl.store(out_ptrs,
|
||||
acc,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) &
|
||||
(offs_out_n[None, :] < hdim))
|
||||
|
||||
|
||||
def _chunk_scan_fwd(
|
||||
cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
out,
|
||||
seq_idx,
|
||||
D=None,
|
||||
z=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
initial_states=None,
|
||||
):
|
||||
assert seq_idx is not None, "this implementation requires seq_idx"
|
||||
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = C.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert C.shape == (seqlen, ngroups, dstate)
|
||||
assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
|
||||
if initial_states is not None:
|
||||
# with initial states, we need to take care of how
|
||||
# seq_idx crosses the boundaries
|
||||
assert chunk_indices is not None and chunk_offsets is not None, \
|
||||
"chunk_indices and chunk_offsets should have been set"
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(
|
||||
headdim, META['BLOCK_SIZE_N']), nchunks
|
||||
if chunk_offsets is None else len(chunk_offsets), nheads)
|
||||
|
||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||
(0, 0, 0))
|
||||
initial_states_strides = ((initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3))
|
||||
if initial_states is not None else (0, 0, 0, 0))
|
||||
|
||||
_chunk_scan_fwd_kernel[grid](
|
||||
cb_ptr=cb,
|
||||
x_ptr=x,
|
||||
z_ptr=z,
|
||||
out_ptr=out,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seq_idx_ptr=seq_idx,
|
||||
C_ptr=C,
|
||||
states_ptr=states,
|
||||
D_ptr=D,
|
||||
initstates_ptr=initial_states,
|
||||
chunk_indices_ptr=chunk_indices,
|
||||
chunk_offsets_ptr=chunk_offsets,
|
||||
chunk_meta_num=len(chunk_indices) if chunk_indices is not None else 0,
|
||||
chunk_size=chunk_size,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_cb_chunk=cb.stride(0),
|
||||
stride_cb_head=cb.stride(1),
|
||||
stride_cb_csize_m=cb.stride(2),
|
||||
stride_cb_csize_k=cb.stride(3),
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_z_seqlen=z_strides[0],
|
||||
stride_z_head=z_strides[1],
|
||||
stride_z_hdim=z_strides[2],
|
||||
stride_out_seqlen=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_hdim=out.stride(2),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
stride_C_seqlen=C.stride(0),
|
||||
stride_C_head=C.stride(1),
|
||||
stride_C_dstate=C.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
stride_D_head=D.stride(0) if D is not None else 0,
|
||||
IS_CAUSAL=True,
|
||||
HAS_D=D is not None,
|
||||
D_HAS_HDIM=D.dim() == 2 if D is not None else True,
|
||||
HAS_Z=z is not None,
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
IS_TRITON_22=TRITON_22,
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return
|
||||
724
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
724
vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
Normal file
@@ -0,0 +1,724 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE_H': 2}),
|
||||
triton.Config({'BLOCK_SIZE_H': 4}),
|
||||
triton.Config({'BLOCK_SIZE_H': 8}),
|
||||
triton.Config({'BLOCK_SIZE_H': 16}),
|
||||
triton.Config({'BLOCK_SIZE_H': 32}),
|
||||
triton.Config({'BLOCK_SIZE_H': 64}),
|
||||
],
|
||||
key=['chunk_size', 'nheads'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
dt_ptr,
|
||||
A_ptr,
|
||||
dt_bias_ptr,
|
||||
dt_out_ptr,
|
||||
dA_cumsum_ptr,
|
||||
# Matrix dimension
|
||||
seqlen,
|
||||
nheads: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
dt_min: tl.constexpr,
|
||||
dt_max: tl.constexpr,
|
||||
# Strides
|
||||
stride_dt_seqlen: tl.int64,
|
||||
stride_dt_head: tl.constexpr,
|
||||
stride_A_head: tl.constexpr,
|
||||
stride_dt_bias_head: tl.constexpr,
|
||||
stride_dt_out_head: tl.int64,
|
||||
stride_dt_out_chunk: tl.int64,
|
||||
stride_dt_out_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr,
|
||||
BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
):
|
||||
# if dt is long, may cause problems, so use 64 bit
|
||||
# https://github.com/triton-lang/triton/issues/1058
|
||||
pid_c = tl.program_id(axis=0).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=1)
|
||||
dt_ptr += pid_c * chunk_size * stride_dt_seqlen
|
||||
dt_out_ptr += pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head +
|
||||
offs_c[None, :] * stride_dt_seqlen)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head +
|
||||
offs_c[None, :] * stride_dt_out_csize)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head +
|
||||
offs_c[None, :] * stride_dA_cs_csize)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
dt = tl.load(dt_ptrs,
|
||||
mask=(offs_h[:, None] < nheads) &
|
||||
(offs_c[None, :] < chunk_size_limit),
|
||||
other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head,
|
||||
mask=offs_h < nheads,
|
||||
other=0.0).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
|
||||
dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.where(
|
||||
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt,
|
||||
0.0)
|
||||
tl.store(dt_out_ptrs,
|
||||
dt,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dA = dt * A[:, None]
|
||||
dA_cs = tl.cumsum(dA, axis=1)
|
||||
tl.store(dA_cs_ptrs,
|
||||
dA_cs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['hdim', 'dstate', 'chunk_size'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
states_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
):
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
|
||||
seq_idx_ptr += pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
|
||||
offs_k[None, :] * stride_x_seqlen)
|
||||
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
|
||||
offs_k[:, None] * stride_b_seqlen)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr +
|
||||
(chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
seq_idx_last = tl.load(seq_idx_ptr +
|
||||
(chunk_size_limit - 1) * stride_seq_idx_seqlen)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) &
|
||||
(offs_k[None, :] < chunk_size_limit - k),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) &
|
||||
(offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
seq_idx_k = tl.load(seq_idx_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=-1)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
scale = tl.where(seq_idx_k == seq_idx_last,
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
|
||||
offs_n[None, :] * stride_states_dstate)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 64
|
||||
},
|
||||
num_stages=3,
|
||||
num_warps=8),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 256,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 128,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 128,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=4),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 32,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 32,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=5,
|
||||
num_warps=2),
|
||||
triton.Config(
|
||||
{
|
||||
'BLOCK_SIZE_M': 64,
|
||||
'BLOCK_SIZE_N': 64,
|
||||
'BLOCK_SIZE_K': 32
|
||||
},
|
||||
num_stages=4,
|
||||
num_warps=2),
|
||||
],
|
||||
key=['hdim', 'dstate', 'chunk_size'],
|
||||
)
|
||||
@triton.jit
|
||||
def _chunk_state_varlen_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
chunk_states_ptr,
|
||||
cu_seqlens_ptr,
|
||||
states_ptr,
|
||||
initstates_ptr,
|
||||
# Matrix dimensions
|
||||
hdim: tl.constexpr,
|
||||
dstate: tl.constexpr,
|
||||
chunk_size: tl.constexpr,
|
||||
nheads_ngroups_ratio: tl.constexpr,
|
||||
# Strides
|
||||
stride_x_seqlen: tl.int64,
|
||||
stride_x_head: tl.int64,
|
||||
stride_x_hdim: tl.constexpr,
|
||||
stride_b_seqlen: tl.int64,
|
||||
stride_b_head: tl.int64,
|
||||
stride_b_dstate: tl.constexpr,
|
||||
stride_dt_head: tl.int64,
|
||||
stride_dt_chunk: tl.int64,
|
||||
stride_dt_csize: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_chunk_states_chunk: tl.int64,
|
||||
stride_chunk_states_head: tl.int64,
|
||||
stride_chunk_states_hdim: tl.int64,
|
||||
stride_chunk_states_dstate: tl.constexpr,
|
||||
stride_states_batch: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_hdim: tl.int64,
|
||||
stride_states_dstate: tl.constexpr,
|
||||
stride_init_states_batch: tl.int64,
|
||||
stride_init_states_head: tl.int64,
|
||||
stride_init_states_hdim: tl.int64,
|
||||
stride_init_states_dstate: tl.constexpr,
|
||||
# Meta-parameters
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
BLOCK_SIZE_N: tl.constexpr,
|
||||
BLOCK_SIZE_K: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
||||
pid_c = (end_idx - 1) // chunk_size
|
||||
b_ptr += pid_c * chunk_size * stride_b_seqlen + (
|
||||
pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states provided, we differentiate between states (which
|
||||
# are boundary conditions at a chunk boundary) and initstates (which are boundary
|
||||
# conditions when a new example in a cont batch starts)
|
||||
initstates_ptr += pid_h * stride_init_states_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim +
|
||||
offs_k[None, :] * stride_x_seqlen)
|
||||
b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate +
|
||||
offs_k[:, None] * stride_b_seqlen)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) *
|
||||
stride_dA_cs_csize).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
chunk_size_limit = end_idx - pid_c * chunk_size
|
||||
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
||||
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) &
|
||||
(offs_k[None, :] < chunk_size_limit - k) &
|
||||
(offs_k[None, :] >= start_idx_cur - k),
|
||||
other=0.0)
|
||||
b = tl.load(b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) &
|
||||
(offs_n[None, :] < dstate) &
|
||||
(offs_k[:, None] >= start_idx_cur - k),
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs,
|
||||
mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k,
|
||||
other=0.0).to(tl.float32)
|
||||
scale = tl.where(
|
||||
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
||||
# If HAS_INITSTATES==True need to consider two possibilities
|
||||
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
|
||||
# - if state_idx >= pid * chunk_size, then we need to insert initstates
|
||||
if ((start_idx < pid_c * chunk_size) # first chunk
|
||||
or (HAS_INITSTATES)):
|
||||
|
||||
dA_cs_boundary = 0.0 # default
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim +
|
||||
offs_n[None, :] * stride_chunk_states_dstate)
|
||||
else:
|
||||
|
||||
# - this seems repetitive, buts its to help the compiler
|
||||
if start_idx < pid_c * chunk_size:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim +
|
||||
offs_n[None, :] * stride_chunk_states_dstate)
|
||||
else:
|
||||
past_states_ptrs = initstates_ptr + (
|
||||
pid_b * stride_init_states_batch +
|
||||
offs_m[:, None] * stride_init_states_hdim +
|
||||
offs_n[None, :] * stride_init_states_dstate)
|
||||
|
||||
# need to adjust the boundary
|
||||
if start_idx > pid_c * chunk_size:
|
||||
dA_cs_boundary = tl.load(dA_cumsum_ptr +
|
||||
(start_idx - pid_c * chunk_size -
|
||||
1) * stride_dA_cs_csize).to(
|
||||
tl.float32)
|
||||
|
||||
past_states = tl.load(past_states_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) &
|
||||
(offs_n[None, :] < dstate),
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
scale = tl.exp(dA_cs_last - dA_cs_boundary)
|
||||
acc += past_states * scale
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (offs_m[:, None] * stride_states_hdim +
|
||||
offs_n[None, :] * stride_states_dstate)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
def _chunk_cumsum_fwd(dt,
|
||||
A,
|
||||
chunk_size,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf"))):
|
||||
seqlen, nheads = dt.shape
|
||||
assert A.shape == (nheads, )
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, )
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
dt_out = torch.empty(nheads,
|
||||
nchunks,
|
||||
chunk_size,
|
||||
device=dt.device,
|
||||
dtype=torch.float32)
|
||||
dA_cumsum = torch.empty(nheads,
|
||||
nchunks,
|
||||
chunk_size,
|
||||
device=dt.device,
|
||||
dtype=torch.float32)
|
||||
grid_chunk_cs = lambda META: (nchunks,
|
||||
triton.cdiv(nheads, META['BLOCK_SIZE_H']))
|
||||
with torch.cuda.device(dt.device.index):
|
||||
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
||||
dt_ptr=dt,
|
||||
A_ptr=A,
|
||||
dt_bias_ptr=dt_bias,
|
||||
dt_out_ptr=dt_out,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seqlen=seqlen,
|
||||
nheads=nheads,
|
||||
chunk_size=chunk_size,
|
||||
dt_min=dt_limit[0],
|
||||
dt_max=dt_limit[1],
|
||||
stride_dt_seqlen=dt.stride(0),
|
||||
stride_dt_head=dt.stride(1),
|
||||
stride_A_head=A.stride(0),
|
||||
stride_dt_bias_head=dt_bias.stride(0)
|
||||
if dt_bias is not None else 0,
|
||||
stride_dt_out_head=dt_out.stride(0),
|
||||
stride_dt_out_chunk=dt_out.stride(1),
|
||||
stride_dt_out_csize=dt_out.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
DT_SOFTPLUS=dt_softplus,
|
||||
HAS_DT_BIAS=dt_bias is not None,
|
||||
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
||||
)
|
||||
return dA_cumsum, dt_out
|
||||
|
||||
|
||||
def _chunk_state_fwd(B,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
seq_idx=None,
|
||||
states=None,
|
||||
states_in_fp32=True):
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
|
||||
assert seq_idx is not None
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
|
||||
if states is not None:
|
||||
assert states.shape == (nchunks, nheads, headdim, dstate)
|
||||
else:
|
||||
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
||||
states = torch.empty((nchunks, nheads, headdim, dstate),
|
||||
device=x.device,
|
||||
dtype=states_dtype)
|
||||
|
||||
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
|
||||
cdiv(dstate, META['BLOCK_SIZE_N']), nchunks, nheads)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_fwd_kernel[grid](
|
||||
x_ptr=x,
|
||||
b_ptr=B,
|
||||
states_ptr=states,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
seq_idx_ptr=seq_idx,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
seqlen=seqlen,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_b_seqlen=B.stride(0),
|
||||
stride_b_head=B.stride(1),
|
||||
stride_b_dstate=B.stride(2),
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
def chunk_state_varlen(B,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
cu_seqlens,
|
||||
chunk_states,
|
||||
initial_states=None):
|
||||
total_seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
batch = cu_seqlens.shape[0] - 1
|
||||
cu_seqlens = cu_seqlens.contiguous()
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (total_seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
|
||||
states = torch.empty(batch,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
dtype=chunk_states.dtype,
|
||||
device=chunk_states.device)
|
||||
|
||||
initial_states_strides = ((initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3))
|
||||
if initial_states is not None else (0, 0, 0, 0))
|
||||
|
||||
grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.
|
||||
cdiv(dstate, META['BLOCK_SIZE_N']), batch, nheads)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_varlen_kernel[grid](
|
||||
x_ptr=x,
|
||||
b_ptr=B,
|
||||
dt_ptr=dt,
|
||||
dA_cumsum_ptr=dA_cumsum,
|
||||
chunk_states_ptr=chunk_states,
|
||||
cu_seqlens_ptr=cu_seqlens,
|
||||
states_ptr=states,
|
||||
initstates_ptr=initial_states,
|
||||
hdim=headdim,
|
||||
dstate=dstate,
|
||||
chunk_size=chunk_size,
|
||||
nheads_ngroups_ratio=nheads // ngroups,
|
||||
stride_x_seqlen=x.stride(0),
|
||||
stride_x_head=x.stride(1),
|
||||
stride_x_hdim=x.stride(2),
|
||||
stride_b_seqlen=B.stride(0),
|
||||
stride_b_head=B.stride(1),
|
||||
stride_b_dstate=B.stride(2),
|
||||
stride_dt_head=dt.stride(0),
|
||||
stride_dt_chunk=dt.stride(1),
|
||||
stride_dt_csize=dt.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_chunk_states_chunk=chunk_states.stride(0),
|
||||
stride_chunk_states_head=chunk_states.stride(1),
|
||||
stride_chunk_states_hdim=chunk_states.stride(2),
|
||||
stride_chunk_states_dstate=chunk_states.stride(3),
|
||||
stride_states_batch=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_hdim=states.stride(2),
|
||||
stride_states_dstate=states.stride(3),
|
||||
stride_init_states_batch=initial_states_strides[0],
|
||||
stride_init_states_head=initial_states_strides[1],
|
||||
stride_init_states_hdim=initial_states_strides[2],
|
||||
stride_init_states_dstate=initial_states_strides[3],
|
||||
HAS_INITSTATES=initial_states is not None)
|
||||
return states
|
||||
238
vllm/model_executor/layers/mamba/ops/ssd_combined.py
Normal file
238
vllm/model_executor/layers/mamba/ops/ssd_combined.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
from .ssd_bmm import _bmm_chunk_fwd
|
||||
from .ssd_chunk_scan import _chunk_scan_fwd
|
||||
from .ssd_chunk_state import (_chunk_cumsum_fwd, _chunk_state_fwd,
|
||||
chunk_state_varlen)
|
||||
from .ssd_state_passing import _state_passing_fwd
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0')
|
||||
|
||||
|
||||
def is_int_pow_2(n):
|
||||
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
cu_seqlens=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None):
|
||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||
seqlen, nheads, headdim = x.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (seqlen, ngroups, dstate)
|
||||
assert dt.shape == (seqlen, nheads)
|
||||
assert A.shape == (nheads, )
|
||||
assert C.shape == B.shape
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads, )
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (seqlen, )
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if x.stride(-1) != 1 and x.stride(
|
||||
0) != 1: # Either M or K dimension should be contiguous
|
||||
x = x.contiguous()
|
||||
if z is not None and z.stride(-1) != 1 and z.stride(
|
||||
0) != 1: # Either M or K dimension should be contiguous
|
||||
z = z.contiguous()
|
||||
if D is not None and D.stride(-1) != 1:
|
||||
D = D.contiguous()
|
||||
assert cu_seqlens is not None, "Assuming varlen input - must supply cu_seqlens"
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (len(cu_seqlens) - 1, nheads, headdim,
|
||||
dstate)
|
||||
|
||||
# This function executes 5 sub-functions for computing mamba
|
||||
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
|
||||
# which has a minimal implementation to understand the below operations
|
||||
# - as explained by the blog, mamba is a special case of causal attention
|
||||
# - the idea is to chunk the attention matrix and compute each
|
||||
# submatrix separately using different optimizations.
|
||||
# - see the blog and paper for a visualization of the submatrices
|
||||
# which we refer to in the comments below
|
||||
|
||||
# 1. Compute chunked cumsum of A * dt
|
||||
# - here dt may go through a softplus activation
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
|
||||
A,
|
||||
chunk_size,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
states = _chunk_state_fwd(B,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
seq_idx=seq_idx,
|
||||
states_in_fp32=True)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
# - for handling chunked prefill, this requires i) initial_states
|
||||
# ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
|
||||
# - When a new seq_idx is detected, we will stop passing the prev_state
|
||||
# and switch accordingly to the init_state corresponding to the new seq_idx.
|
||||
# - We will also make sure that the dA_cumsum is taken only from the start of the
|
||||
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
|
||||
# - this will ensure that states will be updated with the rightmost flushed seq_idx
|
||||
# of the previous chunk. This implies that the first chunk of states is either 0
|
||||
# or equal to init_states of the first example.
|
||||
states = _state_passing_fwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum, # (nheads, nchunks, chunk_size)
|
||||
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
|
||||
if initial_states is not None else
|
||||
None, # (batch, nheads, headdim*dstate)
|
||||
seq_idx=seq_idx,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||
chunk_offsets=chunk_offsets)
|
||||
states = rearrange(states, "... (p n) -> ... p n", n=dstate)
|
||||
|
||||
# 4. Compute batched matrix multiply for C_j^T B_i terms
|
||||
CB = _bmm_chunk_fwd(C,
|
||||
B,
|
||||
chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
output_dtype=torch.float32)
|
||||
|
||||
# 5. Scan and compute the diagonal blocks, taking into
|
||||
# account past causal states.
|
||||
# - if initial states are provided, then states information will be
|
||||
# augmented with initial_states.
|
||||
# - to do this properly, we need to account for example changes in
|
||||
# the continuous batch, therefore we introduce pseudo chunks, which is
|
||||
# a chunk that is split up each time an example changes.
|
||||
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
|
||||
# a seq_idx change, in which case we take states information from
|
||||
# init_states.
|
||||
_chunk_scan_fwd(
|
||||
CB,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
out, # in-place update
|
||||
seq_idx,
|
||||
D=D,
|
||||
z=z,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
varlen_states = chunk_state_varlen(
|
||||
B,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
cu_seqlens,
|
||||
states,
|
||||
initial_states=initial_states,
|
||||
)
|
||||
|
||||
return varlen_states
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined_varlen(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
cu_seqlens,
|
||||
seq_idx,
|
||||
out,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
x: (seqlen, nheads, headdim)
|
||||
dt: (seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (seqlen, ngroups, dstate)
|
||||
C: (seqlen, ngroups, dstate)
|
||||
chunk_size: int
|
||||
seq_idx: (seqlen)
|
||||
cu_seqlens: (batch + 1)
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
initial_states: (batch, nheads, headdim, dstate)
|
||||
dt_softplus: Whether to apply softplus to dt
|
||||
out: (seqlen, nheads, headdim) preallocated output tensor
|
||||
state_dtype: The data type of the ssm state
|
||||
Return:
|
||||
varlen_states: (batch, nheads, headdim, dstate)
|
||||
"""
|
||||
|
||||
assert cu_seqlens is not None, "cu_seqlens must be provided assuming varlen input"
|
||||
assert seq_idx is not None
|
||||
|
||||
varlen_states = _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
out,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
state_dtype=state_dtype)
|
||||
|
||||
return varlen_states
|
||||
200
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
200
vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
Normal file
@@ -0,0 +1,200 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_SIZE': 64}),
|
||||
triton.Config({'BLOCK_SIZE': 128}),
|
||||
triton.Config({'BLOCK_SIZE': 256}),
|
||||
triton.Config({'BLOCK_SIZE': 512}),
|
||||
triton.Config({'BLOCK_SIZE': 1024}),
|
||||
triton.Config({'BLOCK_SIZE': 2048}),
|
||||
],
|
||||
key=['dim'],
|
||||
)
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
states_ptr,
|
||||
out_ptr,
|
||||
dA_cs_ptr,
|
||||
initstates_ptr,
|
||||
seq_idx_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
# Matrix dimensions
|
||||
dim: tl.constexpr,
|
||||
nchunks,
|
||||
seqlen,
|
||||
chunk_size: tl.constexpr,
|
||||
# Strides
|
||||
stride_states_chunk: tl.int64,
|
||||
stride_states_head: tl.int64,
|
||||
stride_states_dim: tl.constexpr,
|
||||
stride_out_chunk: tl.int64,
|
||||
stride_out_head: tl.int64,
|
||||
stride_out_dim: tl.constexpr,
|
||||
stride_dA_cs_head: tl.int64,
|
||||
stride_dA_cs_chunk: tl.int64,
|
||||
stride_dA_cs_csize: tl.constexpr,
|
||||
stride_initstates_batch: tl.int64,
|
||||
stride_initstates_head: tl.int64,
|
||||
stride_initstates_dim: tl.constexpr,
|
||||
stride_seq_idx_seqlen: tl.constexpr,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid_h = tl.program_id(axis=1)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
states_ptr += pid_h * stride_states_head
|
||||
dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size -
|
||||
1) * stride_dA_cs_csize
|
||||
out_ptr += pid_h * stride_out_head
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptr += pid_h * stride_initstates_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
|
||||
# - states will be the past state of the sequence that continues on the current check
|
||||
if not HAS_INITSTATES:
|
||||
states = tl.zeros((BLOCK_SIZE, ), dtype=tl.float32)
|
||||
else:
|
||||
initstates_ptr += offs_m * stride_initstates_dim
|
||||
initstates_ptrs = initstates_ptr
|
||||
# - for cont batches, for the first chunk mean it will be the first batch's
|
||||
# init state
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
out_ptrs += stride_out_chunk
|
||||
prev_seq_idx_chunk_end = 0
|
||||
logical_chunk_idx = 0
|
||||
for c in range(nchunks - 1):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
scale_mask = True
|
||||
# - the seq to pass forward is the one that is flushed to the right
|
||||
# boundary.
|
||||
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
|
||||
seq_idx_chunk_end = tl.load(seq_idx_ptr +
|
||||
(min((c + 1) * chunk_size, seqlen) - 1) *
|
||||
stride_seq_idx_seqlen)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
if prev_seq_idx_chunk_end != seq_idx_chunk_end:
|
||||
# this means in the current chunk the rightmost flushed seq
|
||||
# has changed.
|
||||
# - so we do not propagate the state from previous chunk
|
||||
# - but rather we load that sequence's init state
|
||||
initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
|
||||
|
||||
# - update state with seq_idx_new's init state
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim,
|
||||
other=0.0).to(tl.float32)
|
||||
|
||||
# - we need to consider the cumsum only of the last sequence in the chunk
|
||||
# - find its starting position (given by c_off of the logical chunk index)
|
||||
# - and subtract the cumsum just before that position from the total cumsum
|
||||
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
|
||||
# sequence index at the start of the current chunk
|
||||
seq_idx_chunk_start = tl.load(seq_idx_ptr +
|
||||
min(c * chunk_size, seqlen) *
|
||||
stride_seq_idx_seqlen)
|
||||
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
|
||||
# - load the chunk offset:
|
||||
c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx,
|
||||
mask=logical_chunk_idx < chunk_meta_num,
|
||||
other=0)
|
||||
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
|
||||
if c_off > 0:
|
||||
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
|
||||
dA_cs_boundary = tl.load(
|
||||
dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize +
|
||||
(c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(c_off - 1) > -1 and c_off < chunk_size,
|
||||
other=0.0)
|
||||
dA_cs -= dA_cs_boundary
|
||||
|
||||
# - increment logical chunk index for every physical chunk
|
||||
logical_chunk_idx += 1
|
||||
else:
|
||||
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
|
||||
prev_seq_idx_chunk_end = seq_idx_chunk_end
|
||||
|
||||
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
|
||||
states = scale * states + new_states
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
|
||||
states_ptrs += stride_states_chunk
|
||||
dA_cs_ptr += stride_dA_cs_chunk
|
||||
out_ptrs += stride_out_chunk
|
||||
|
||||
|
||||
def _state_passing_fwd(
|
||||
states,
|
||||
dA_cumsum,
|
||||
seq_idx,
|
||||
chunk_offsets,
|
||||
initial_states=None,
|
||||
out_dtype=None,
|
||||
):
|
||||
nchunks, nheads, dim = states.shape
|
||||
chunk_size = dA_cumsum.shape[-1]
|
||||
assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
|
||||
seqlen = seq_idx.shape[-1]
|
||||
out_dtype = states.dtype if out_dtype is None else out_dtype
|
||||
out = torch.empty((nchunks, nheads, dim),
|
||||
device=states.device,
|
||||
dtype=out_dtype)
|
||||
|
||||
initial_states_strides = ((initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2))
|
||||
if initial_states is not None else (0, 0, 0))
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE']), nheads)
|
||||
with torch.cuda.device(states.device.index):
|
||||
_state_passing_fwd_kernel[grid](
|
||||
states_ptr=states,
|
||||
out_ptr=out,
|
||||
dA_cs_ptr=dA_cumsum,
|
||||
initstates_ptr=initial_states,
|
||||
seq_idx_ptr=seq_idx,
|
||||
chunk_offsets_ptr=chunk_offsets,
|
||||
chunk_meta_num=len(chunk_offsets)
|
||||
if chunk_offsets is not None else 0,
|
||||
dim=dim,
|
||||
nchunks=nchunks,
|
||||
seqlen=seqlen if seq_idx is not None else 0,
|
||||
chunk_size=chunk_size if seq_idx is not None else 0,
|
||||
stride_states_chunk=states.stride(0),
|
||||
stride_states_head=states.stride(1),
|
||||
stride_states_dim=states.stride(2),
|
||||
stride_out_chunk=out.stride(0),
|
||||
stride_out_head=out.stride(1),
|
||||
stride_out_dim=out.stride(2),
|
||||
stride_dA_cs_head=dA_cumsum.stride(0),
|
||||
stride_dA_cs_chunk=dA_cumsum.stride(1),
|
||||
stride_dA_cs_csize=dA_cumsum.stride(2),
|
||||
stride_initstates_batch=initial_states_strides[0],
|
||||
stride_initstates_head=initial_states_strides[1],
|
||||
stride_initstates_dim=initial_states_strides[2],
|
||||
stride_seq_idx_seqlen=seq_idx.stride(0),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out
|
||||
253
vllm/model_executor/layers/mamba/short_conv.py
Normal file
253
vllm/model_executor/layers/mamba/short_conv.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
|
||||
causal_conv1d_fn, causal_conv1d_update)
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.short_conv_attn import (
|
||||
ShortConvAttentionMetadata)
|
||||
|
||||
|
||||
@CustomOp.register("short_conv")
|
||||
class ShortConv(MambaBase, CustomOp):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
dim: int,
|
||||
layer_idx: int,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
prefix: str = ""):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.conv_dim = dim
|
||||
self.L_cache = config.conv_L_cache
|
||||
self.bias = config.conv_bias
|
||||
|
||||
self.conv = ColumnParallelLinear(
|
||||
input_size=self.L_cache,
|
||||
output_size=dim,
|
||||
bias=self.bias,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `set_weight_attrs`
|
||||
# doesn't allow to override it
|
||||
self.conv.weight.data = self.conv.weight.data.unsqueeze(1)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=dim,
|
||||
output_sizes=[dim] * 3,
|
||||
bias=self.bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
input_size=dim,
|
||||
output_size=dim,
|
||||
bias=self.bias,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.kv_cache = (torch.tensor([]), )
|
||||
|
||||
self.model_config = model_config
|
||||
self.cache_config = cache_config
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
return
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
torch.ops.vllm.short_conv(
|
||||
hidden_states,
|
||||
output,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
forward_context = get_forward_context()
|
||||
# ShortConvAttentionMetadata contains metadata necessary for the
|
||||
# short_conv triton kernels to operate in continuous batching and in
|
||||
# chunked prefill modes; they are computed at top-level model forward
|
||||
# since they stay the same and reused for all mamba layers in the same
|
||||
# iteration.
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, ShortConvAttentionMetadata)
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
conv_state = self_kv_cache[0].transpose(-1, -2)
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
has_initial_states_p = attn_metadata.has_initial_states
|
||||
|
||||
BCx, _ = self.in_proj(hidden_states)
|
||||
|
||||
B, C, x = BCx.chunk(3, dim=-1)
|
||||
|
||||
conv_weights = self.conv.weight.view(self.conv.weight.size(0),
|
||||
self.conv.weight.size(2))
|
||||
|
||||
if attn_metadata is None:
|
||||
# V1 profile run
|
||||
Bx = (B * x).contiguous()
|
||||
hidden_states = C * Bx
|
||||
contextualized_states, _ = self.out_proj(hidden_states)
|
||||
return contextualized_states
|
||||
|
||||
num_prefills = attn_metadata.num_prefills # request count
|
||||
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
|
||||
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_decodes + num_prefill_tokens
|
||||
|
||||
# NOTE: V1 puts decode before prefill
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
B_d, B_p = torch.split(
|
||||
B[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
C_d, C_p = torch.split(
|
||||
C[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
x_d, x_p = torch.split(
|
||||
x[:num_actual_tokens],
|
||||
[num_decodes, num_prefill_tokens],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_d, state_indices_tensor_p = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_decodes, num_prefills],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = (
|
||||
attn_metadata.query_start_loc[-num_prefills - 1:] -
|
||||
num_decodes if has_prefill else None)
|
||||
|
||||
conv_output_list = []
|
||||
|
||||
if has_prefill:
|
||||
Bx_p = (B_p * x_p).transpose(0, 1)
|
||||
Bx = causal_conv1d_fn(Bx_p,
|
||||
conv_weights,
|
||||
self.conv.bias,
|
||||
activation=None,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=state_indices_tensor_p,
|
||||
metadata=attn_metadata,
|
||||
query_start_loc=query_start_loc_p).transpose(
|
||||
0, 1)[:num_prefill_tokens]
|
||||
|
||||
y = C_p * Bx
|
||||
conv_output_list.append(y)
|
||||
|
||||
if has_decode:
|
||||
Bx_d = (B_d * x_d).contiguous()
|
||||
Bx = causal_conv1d_update(
|
||||
Bx_d,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv.bias,
|
||||
activation=None,
|
||||
conv_state_indices=state_indices_tensor_d)
|
||||
y = C_d * Bx
|
||||
conv_output_list.insert(0, y)
|
||||
|
||||
# Merge prefill and decode outputs before passing to gated MLP
|
||||
hidden_states = torch.vstack(conv_output_list)
|
||||
|
||||
# Final linear projection
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
|
||||
assert self.model_config is not None
|
||||
assert self.cache_config is not None
|
||||
return MambaStateDtypeCalculator.short_conv_state_dtype(
|
||||
self.model_config.dtype,
|
||||
self.cache_config.mamba_cache_dtype,
|
||||
)
|
||||
|
||||
def get_state_shape(self) -> tuple[tuple[int, ...]]:
|
||||
return MambaStateShapeCalculator.short_conv_state_shape(
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
intermediate_size=self.conv_dim,
|
||||
conv_kernel=self.L_cache,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "short_conv"
|
||||
|
||||
def get_attn_backend(self) -> type["AttentionBackend"]:
|
||||
from vllm.v1.attention.backends.short_conv_attn import (
|
||||
ShortConvAttentionBackend)
|
||||
return ShortConvAttentionBackend
|
||||
|
||||
|
||||
def short_conv(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self.forward_cuda(hidden_states=hidden_states, output=output)
|
||||
|
||||
|
||||
def short_conv_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="short_conv",
|
||||
op_func=short_conv,
|
||||
mutates_args=["output"],
|
||||
fake_impl=short_conv_fake,
|
||||
)
|
||||
Reference in New Issue
Block a user