Add mamba kernel (#10234)

This commit is contained in:
Yi Zhang
2025-09-10 03:58:43 +08:00
committed by GitHub
parent 8471e5e616
commit 8cbe1538ef
8 changed files with 1418 additions and 0 deletions

View File

@@ -34,6 +34,7 @@ from sgl_kernel.elementwise import (
rmsnorm,
silu_and_mul,
)
from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
if torch.version.hip is not None:
from sgl_kernel.elementwise import gelu_quick

View File

@@ -0,0 +1,50 @@
from typing import Optional
import torch
# mamba
def causal_conv1d_fwd(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    conv_states: Optional[torch.Tensor],
    query_start_loc: Optional[torch.Tensor],
    cache_indices: Optional[torch.Tensor],
    has_initial_state: Optional[torch.Tensor],
    silu_activation: bool,
    pad_slot_id: int,
):
    """Dispatch the causal 1-D convolution forward pass to the custom kernel.

    Thin wrapper over the registered ``sgl_kernel.causal_conv1d_fwd``
    torch op; all arguments are forwarded positionally, unchanged.
    Nothing is returned — the kernel presumably writes its results into
    the tensor arguments in place (TODO confirm against the C++ op).
    """
    # Bind the registered op once, then invoke it with the forwarded args.
    kernel = torch.ops.sgl_kernel.causal_conv1d_fwd
    kernel(
        x, weight, bias_, conv_states,
        query_start_loc, cache_indices, has_initial_state,
        silu_activation, pad_slot_id,
    )
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias_: Optional[torch.Tensor],
    silu_activation: bool,
    cache_seqlens: Optional[torch.Tensor],
    conv_state_indices: Optional[torch.Tensor],
    pad_slot_id: int,
):
    """Dispatch a causal 1-D convolution decode-step update to the custom kernel.

    Thin wrapper over the registered ``sgl_kernel.causal_conv1d_update``
    torch op; all arguments are forwarded positionally, unchanged.
    Nothing is returned — the kernel presumably updates ``x`` and/or
    ``conv_state`` in place (TODO confirm against the C++ op).
    """
    # Bind the registered op once, then invoke it with the forwarded args.
    kernel = torch.ops.sgl_kernel.causal_conv1d_update
    kernel(
        x, conv_state, weight, bias_,
        silu_activation, cache_seqlens, conv_state_indices,
        pad_slot_id,
    )