Add mamba kernel (#10234)
@@ -34,6 +34,7 @@ from sgl_kernel.elementwise import (
     rmsnorm,
     silu_and_mul,
 )
+from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update
 
 if torch.version.hip is not None:
     from sgl_kernel.elementwise import gelu_quick
sgl-kernel/python/sgl_kernel/mamba.py (new file, +50 lines)
@@ -0,0 +1,50 @@
+from typing import Optional
+
+import torch
+
+
+# mamba
+def causal_conv1d_fwd(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    bias_: Optional[torch.Tensor],
+    conv_states: Optional[torch.Tensor],
+    query_start_loc: Optional[torch.Tensor],
+    cache_indices: Optional[torch.Tensor],
+    has_initial_state: Optional[torch.Tensor],
+    silu_activation: bool,
+    pad_slot_id: int,
+):
+    torch.ops.sgl_kernel.causal_conv1d_fwd(
+        x,
+        weight,
+        bias_,
+        conv_states,
+        query_start_loc,
+        cache_indices,
+        has_initial_state,
+        silu_activation,
+        pad_slot_id,
+    )
+
+
+def causal_conv1d_update(
+    x: torch.Tensor,
+    conv_state: torch.Tensor,
+    weight: torch.Tensor,
+    bias_: Optional[torch.Tensor],
+    silu_activation: bool,
+    cache_seqlens: Optional[torch.Tensor],
+    conv_state_indices: Optional[torch.Tensor],
+    pad_slot_id: int,
+):
+    torch.ops.sgl_kernel.causal_conv1d_update(
+        x,
+        conv_state,
+        weight,
+        bias_,
+        silu_activation,
+        cache_seqlens,
+        conv_state_indices,
+        pad_slot_id,
+    )
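For reference, a minimal usage sketch of the two wrappers, not part of the commit. It assumes the compiled sgl_kernel extension has both ops registered, that shapes follow the usual causal-conv1d convention (x: [batch, dim, seqlen] for prefill and [batch, dim] per decode step; weight: [dim, width]; conv_states: [batch, dim, width - 1]), that the ops write their output into x in place (the wrappers return nothing), and that -1 works as a pad_slot_id sentinel; all of these are assumptions, not guarantees from this diff.

# Usage sketch only; shapes, in-place semantics, and the pad_slot_id
# sentinel below are assumptions, not taken from this diff.
import torch

from sgl_kernel.mamba import causal_conv1d_fwd, causal_conv1d_update

batch, dim, seqlen, width = 2, 64, 16, 4
device, dtype = "cuda", torch.float16

x = torch.randn(batch, dim, seqlen, device=device, dtype=dtype)
weight = torch.randn(dim, width, device=device, dtype=dtype)
conv_states = torch.zeros(batch, dim, width - 1, device=device, dtype=dtype)

# Prefill: causal conv over whole prompts; conv_states is filled with each
# sequence's trailing context so decode can continue from it.
causal_conv1d_fwd(
    x,
    weight,
    bias_=None,
    conv_states=conv_states,
    query_start_loc=None,      # only needed for packed variable-length batches
    cache_indices=None,
    has_initial_state=None,
    silu_activation=True,
    pad_slot_id=-1,            # assumed sentinel for padded cache slots
)

# Decode: advance the cached conv state by one token per sequence.
x_step = torch.randn(batch, dim, device=device, dtype=dtype)
causal_conv1d_update(
    x_step,
    conv_states,
    weight,
    bias_=None,
    silu_activation=True,
    cache_seqlens=None,
    conv_state_indices=None,
    pad_slot_id=-1,
)

Both wrappers are thin passthroughs: the actual kernels live in the compiled extension, and dispatching through torch.ops.sgl_kernel means arguments are validated against the registered op schema.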