model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
Netanel Haber
2025-10-08 19:37:38 +03:00
committed by GitHub
parent c882b5ae75
commit d6837aea4d
35 changed files with 3280 additions and 854 deletions

View File

@@ -1,7 +1,14 @@
import logging
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
# evade circular imports
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.model_runner import ModelRunner
ATTENTION_BACKENDS = {}
@@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend):
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
runner.hybrid_gdn_config is not None and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
if cfg := runner.mambaish_config:
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
if runner.hybrid_gdn_config is not None:
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
)
full_attn_layers = cfg.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)

View File

@@ -181,6 +181,45 @@ def _layer_norm_fwd(
return out, mean, rstd
def rms_norm_gated(
    *,
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Gated layer/RMS norm forward pass.

    If ``z`` is not None, computes ``norm(x) * silu(z)`` when
    ``norm_before_gate`` is True, else ``norm(x * silu(z))``.  Inputs are
    flattened to 2D row-contiguous tensors before being handed to the fused
    ``_layer_norm_fwd`` kernel, and the result is reshaped back to the
    caller's original shape.
    """
    orig_shape = x.shape
    # The fused kernel expects a 2D, last-dim-contiguous view.
    x2d = x.reshape(-1, orig_shape[-1])
    if x2d.stride(-1) != 1:
        x2d = x2d.contiguous()
    z2d = None
    if z is not None:
        assert z.shape == orig_shape
        z2d = z.reshape(-1, z.shape[-1])
        if z2d.stride(-1) != 1:
            z2d = z2d.contiguous()
    w = weight.contiguous()
    b = bias.contiguous() if bias is not None else None
    # mean/rstd are only needed for a backward pass; discard them here.
    out, _mean, _rstd = _layer_norm_fwd(
        x2d,
        w,
        b,
        eps,
        z=z2d,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )
    return out.reshape(orig_shape)
class LayerNormFn(torch.autograd.Function):
@staticmethod
@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = _layer_norm_fwd(
x,
weight,
bias,
eps,
return rms_norm_gated(
x=x,
weight=weight,
bias=bias,
eps=eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
return y.reshape(x_shape_og)
def layernorm_fn(
@@ -238,14 +261,6 @@ def layernorm_fn(
)
def rmsnorm_fn(
    x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
    """Gated RMSNorm entry point; delegates to ``LayerNormFn``.

    ``autograd.Function.apply`` accepts positional arguments only, so the
    argument order below must match ``LayerNormFn.forward`` exactly; the
    trailing flag selects the RMS-norm variant.
    """
    is_rms_norm = True
    return LayerNormFn.apply(
        x, weight, bias, z, eps, group_size, norm_before_gate, is_rms_norm
    )
class LayerNorm(torch.nn.Module):
def __init__(
@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
is_rms_norm=False,
)
@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
return layernorm_fn(
x,
self.weight,
self.bias,
@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
is_rms_norm=True,
)

View File

@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
ForwardMetadata,
Mamba2Metadata,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import fused_gdn_gating
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
@@ -47,18 +54,10 @@ elif is_npu():
causal_conv1d_update = causal_conv1d_update_npu
@dataclass
class ForwardMetadata:
    """Per-batch metadata consumed by the Mamba attention backends."""

    # Cumulative start offsets of each request's tokens in the flattened
    # (varlen) query layout; None when the forward mode does not need it.
    query_start_loc: Optional[torch.Tensor]
    # Per-request indices into the Mamba conv/SSM state cache pool
    # (populated from req_to_token_pool.get_mamba_indices).
    mamba_cache_indices: torch.Tensor
class MambaAttnBackend(AttentionBackend):
"""Attention backend using Mamba kernel."""
class MambaAttnBackendBase(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.pad_slot_id = -1 # Default pad slot id
self.pad_slot_id = PAD_SLOT_ID
self.device = model_runner.device
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
self.forward_metadata: ForwardMetadata = None
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
def init_forward_metadata(self, forward_batch: ForwardBatch):
def _forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
forward_batch.req_pool_indices
)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
self.forward_metadata = self._forward_metadata(forward_batch)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.forward_metadata = self._capture_metadata(
bs, req_pool_indices, forward_mode
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.forward_metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
assert (
max_num_tokens % max_bs == 0
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
def _capture_metadata(
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
raise ValueError(f"Invalid forward mode: {forward_mode=}")
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
def init_forward_metadata_replay_cuda_graph(
def _replay_metadata(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1 # Mamba attn does not use seq lens to index kv cache
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
dt_bias = kwargs["dt_bias"]
layer_id = kwargs["layer_id"]
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = mamba_cache_params.conv
ssm_states = mamba_cache_params.temporal
if is_target_verify:
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = self.req_to_token_pool.get_mamba_params(layer_id)
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
intermediate_state_cache = mamba_cache_params.intermediate_ssm
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
has_initial_states = torch.ones(
seq_len // forward_batch.spec_info.draft_token_num,
dtype=torch.bool,
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
)
conv_states_to_use = conv_states.clone()
else:
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
has_initial_states = forward_batch.extend_prefix_lens > 0
conv_states_to_use = conv_states
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
return core_attn_out
class Mamba2AttnBackend(MambaAttnBackendBase):
    """Attention backend wrapper for Mamba2Mixer kernels."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__(model_runner)
        cfg = model_runner.mamba2_config
        assert cfg is not None
        self.mamba_chunk_size = cfg.mamba_chunk_size

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        # Eager path: the batch may mix prefill and decode, so build the full
        # chunked Mamba2 metadata on top of the base varlen metadata.
        base = self._forward_metadata(forward_batch)
        self.forward_metadata = Mamba2Metadata.prepare_mixed(
            base.query_start_loc,
            base.mamba_cache_indices,
            self.mamba_chunk_size,
            forward_batch,
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        # CUDA-graph capture covers decode only, so decode metadata suffices.
        base = self._capture_metadata(bs, req_pool_indices, forward_mode)
        self.forward_metadata = Mamba2Metadata.prepare_decode(
            base.query_start_loc, base.mamba_cache_indices, seq_lens
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        base = self._replay_metadata(
            bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
        )
        self.forward_metadata = Mamba2Metadata.prepare_decode(
            base.query_start_loc, base.mamba_cache_indices, seq_lens
        )

    def forward(
        self,
        mixer: MambaMixer2,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        layer_id: int,
        mup_vector: Optional[torch.Tensor] = None,
        use_triton_causal_conv: bool = False,
    ):
        """Run the Mamba2 mixer for one layer, writing results into ``output``.

        Called directly by the model (not via HybridLinearAttnBackend's
        forward_decode/forward_extend), since the mixer itself handles mixed
        prefill/decode batches using the prepared metadata.
        """
        assert isinstance(self.forward_metadata, Mamba2Metadata)
        cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
        return mixer.forward(
            hidden_states=hidden_states,
            output=output,
            layer_cache=cache,
            metadata=self.forward_metadata,
            mup_vector=mup_vector,
            use_triton_causal_conv=use_triton_causal_conv,
        )

    def forward_decode(self, *args, **kwargs):
        raise NotImplementedError(
            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
        )

    def forward_extend(self, *args, **kwargs):
        raise NotImplementedError(
            "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
        )
class HybridLinearAttnBackend(AttentionBackend):
"""Support different backends for prefill and decode."""
"""Manages a full and linear attention backend"""
def __init__(
self,
full_attn_backend: AttentionBackend,
linear_attn_backend: AttentionBackend,
linear_attn_backend: MambaAttnBackendBase,
full_attn_layers: list[int],
):
self.full_attn_layers = full_attn_layers
self.full_attn_backend = full_attn_backend
self.linear_attn_backend = linear_attn_backend
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
)
def get_cuda_graph_seq_len_fill_value(self):
return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
def forward_decode(
self,
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_decode(
return self.full_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_decode(
return self.linear_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_extend(
return self.full_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_extend(
return self.linear_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
request_number = accepted_length.shape[0]
state_indices_tensor = self.attn_backend_list[
1
].forward_metadata.mamba_cache_indices[:request_number]
state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
)
mamba_caches = self.attn_backend_list[
1
].req_to_token_pool.get_mamba_params_all_layers()
mamba_caches = (
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
)
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = mamba_caches
conv_states = mamba_caches.conv
ssm_states = mamba_caches.temporal
intermediate_state_cache = mamba_caches.intermediate_ssm
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
# SSM state updates (chunked to reduce peak memory)
valid_mask = accepted_length > 0

View File

@@ -10,7 +10,7 @@ import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
PAD_SLOT_ID = -1
from .causal_conv1d_triton import PAD_SLOT_ID
def causal_conv1d_fn(

View File

@@ -6,11 +6,11 @@ from typing import List, Optional, Union
import numpy as np
import torch
PAD_SLOT_ID = -1
import triton
import triton.language as tl
PAD_SLOT_ID = -1
@triton.jit()
def _causal_conv1d_fwd_kernel( # continuous batching
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
+ (conv_state_batch_coord * stride_conv_state_seq)
+ conv_state_token_offset * stride_conv_state_tok
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
:, None
]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
@@ -897,7 +899,10 @@ def causal_conv1d_update(
stride_state_indices = (
conv_state_indices.stride(0) if conv_state_indices is not None else 0
)
state_len = width - 1 + (seqlen - 1) # effective state_len needed
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):

View File

@@ -1,23 +1,30 @@
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.custom_op import CustomOp
from sglang.srt.configs.mamba_utils import (
Mamba2CacheParams,
extra_groups_for_head_shards,
)
from sglang.srt.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
from sglang.srt.layers.attention.mamba.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn as causal_conv1d_fn_triton,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_update as causal_conv1d_update_triton,
)
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
from sglang.srt.layers.attention.mamba.ops import (
mamba_chunk_scan_combined,
selective_state_update,
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.mem_cache.memory_pool import MambaPool
from sglang.srt.model_loader.weight_utils import (
composed_weight_loader,
sharded_weight_loader,
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
return loader
class Mixer2RMSNormGated(CustomOp):
    """Gated, grouped RMSNorm for Mamba2 mixers, aware of tensor parallelism.

    Computes ``norm(x * silu(gate))`` — gating is applied before the norm
    (``norm_before_gate=False`` on the fused path) — where the norm is a
    grouped RMSNorm over ``group_size``-wide slices of the hidden dimension.
    When ``use_rms_norm`` is False, only the gating is applied and no weight
    parameter is registered.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.full_hidden_size = full_hidden_size
        # Width of one norm group over the *unsharded* hidden dimension.
        self.group_size = full_hidden_size // full_n_groups
        # Local (per-TP-rank) slice of the hidden dimension.
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // self.group_size
        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if self.use_rms_norm:
            # Register norm weight only if we're actually applying RMSNorm
            self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
        else:
            # Avoid checkpoint mismatch by skipping unused parameter
            self.register_parameter("weight", None)
        assert (
            self.full_hidden_size % self.tp_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference (pure-PyTorch) gated grouped RMSNorm.

        Handles all tensor-parallel layouts, falling back to collectives
        where the local shard cannot compute its variance independently.
        """
        # Three tensor-parallel cases:
        # 1. n_groups is 1
        #    In this case we parallelize along the reduction dim.
        #    Each rank computes a local sum of squares followed by AllReduce
        # 2. tp_size divides n_groups
        #    Each rank only reduces within its local group(s).
        #    No collective ops necessary.
        # 3. The general case can be pretty complicated so we AllGather
        #    the input and then redundantly compute the RMSNorm.
        input_dtype = x.dtype
        # Gate in float32 for numerical stability; cast back at the end.
        x = x * nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return x.to(input_dtype)

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
                global_sums = tensor_model_parallel_all_reduce(local_sums)
                # Calculate the variance
                count = self.tp_size * x.shape[-1]
                variance = global_sums / count
            else:
                variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
        else:
            redundant_tp: bool = self.n_groups % self.tp_size != 0
            if redundant_tp:
                # To handle the general case, redundantly apply the variance
                x = tensor_model_parallel_all_gather(x, -1)
            *prefix_dims, hidden_dim = x.shape
            group_count = hidden_dim // self.group_size
            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
            variance = x_grouped.pow(2).mean(-1, keepdim=True)
            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
            x = x_grouped.view(*prefix_dims, hidden_dim)
            if redundant_tp:
                # Slice this rank's shard back out of the redundant result.
                start = self.per_rank_hidden_size * self.tp_rank
                end = start + self.per_rank_hidden_size
                x = x[..., start:end]
        return self.weight * x.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        input_dtype = x.dtype
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)

        # NOTE(review): as written, the fused layernorm_fn path is taken only
        # when n_groups == 1 — the first clause is redundant given the second
        # (`self.n_groups != 1` alone forces the native fallback for all
        # multi-group layouts). Confirm whether the tp-divisible multi-group
        # case was meant to use the fused kernel too.
        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
            return self.forward_native(x, gate)
        return layernorm_fn(
            x,
            self.weight.data,
            bias=None,
            z=gate,
            eps=self.variance_epsilon,
            norm_before_gate=False,
        )
class MambaMixer2(torch.nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
@@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module):
def __init__(
self,
cache_params: Mamba2CacheParams,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
chunk_size: int,
layer_id: int,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
# cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.num_heads = num_heads = cache_params.shape.num_heads
self.head_dim = cache_params.shape.head_dim
assert (
num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
"then num_groups must equal 1."
)
self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
self.ssm_state_size = cache_params.shape.ssm_state_size
self.activation = activation
self.layer_id = layer_id
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
self.chunk_size = chunk_size
conv_kernel_size = cache_params.shape.conv_kernel
self.intermediate_size = intermediate_size = (
cache_params.shape.intermediate_size
)
self.n_groups = n_groups
if n_groups % self.tp_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size
)
groups = extra_groups_for_head_shards(n_groups, self.tp_size)
self.n_groups = n_groups + groups
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
self.conv_dim = cache_params.shape.conv_dim
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
if n_groups % self.tp_size == 0:
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
prefix=f"{prefix}.in_proj",
)
if n_groups % self.tp_size != 0:
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
# The tuple is (conv_state, ssm_state)
self.kv_cache = (torch.tensor([]), torch.tensor([]))
self.model_config = model_config
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: Optional[torch.Tensor] = None,
):
pass
def forward(
self,
*,
hidden_states: torch.Tensor,
output: torch.Tensor,
forward_batch: ForwardBatch,
layer_cache: MambaPool.State,
metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
use_triton_causal_conv: bool = False,
):
# attn_backend_list[-1] gives access to MambaAttnBackend
mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
attn_metadata = mamba_backend.forward_metadata
state_indices_tensor = attn_metadata.mamba_cache_indices
chunk_size = self.chunk_size
# metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
state_indices_tensor = metadata.mamba_cache_indices
conv_state = layer_cache.conv
ssm_state = layer_cache.temporal
conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
self.layer_id
)
assert (
ssm_state.size(1) == self.ssm_state_size
), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
query_start_loc = attn_metadata.query_start_loc
chunk_size = self.chunk_size
# TODO: properly support this
prep_initial_states = False
query_start_loc = metadata.query_start_loc
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
dim=-1,
)
num_prefills = metadata.num_prefills # request count
num_decodes = metadata.num_decodes # token count (=request)
num_prefill_tokens = metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
assert num_actual_tokens == projected_states.shape[0]
# NOTE: V0 put prefill before decode
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
projected_states.shape[0],
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if forward_batch.forward_mode.is_extend():
if has_prefill:
mixed_metadata = metadata.mixed_metadata
assert mixed_metadata is not None
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
num_prefill_tokens = forward_batch.extend_num_tokens or 0
has_initial_states = forward_batch.extend_prefix_lens > 0
cache_indices = attn_metadata.mamba_cache_indices
x = hidden_states_B_C.transpose(
has_initial_states_p = mixed_metadata.has_initial_states
prep_initial_states = mixed_metadata.prep_initial_states
cache_indices = state_indices_tensor_p
x = hidden_states_B_C_p.transpose(
0, 1
) # this is the form that causal-conv see
hidden_states_B_C = causal_conv1d_fn(
ccfn = (
causal_conv1d_fn
if not use_triton_causal_conv
else causal_conv1d_fn_triton
)
hidden_states_B_C_p = ccfn(
x,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_states,
has_initial_state=has_initial_states_p,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
query_start_loc=query_start_loc_p,
seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
).transpose(0, 1)[:num_prefill_tokens]
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None
if has_initial_states is not None and prep_initial_states:
if has_initial_states_p is not None and prep_initial_states:
initial_states = torch.where(
has_initial_states[:, None, None, None],
ssm_state[state_indices_tensor],
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p],
0,
)
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states.view(
hidden_states_p.view(
1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
),
dt.unsqueeze(0),
dt_p.unsqueeze(0),
self.A,
B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=chunk_size,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=mixed_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
cu_seqlens=query_start_loc,
seq_idx=mixed_metadata.seq_idx,
chunk_indices=mixed_metadata.chunk_indices,
chunk_offsets=mixed_metadata.chunk_offsets,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
out=preallocated_ssm_out_p.view(
1, num_prefill_tokens, -1, self.head_dim
),
state_dtype=ssm_state.dtype,
)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
elif forward_batch.forward_mode.is_decode():
num_decodes = len(query_start_loc) - 1
ssm_state[state_indices_tensor_p] = varlen_state
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
ccu = (
causal_conv1d_update
if not use_triton_causal_conv
else causal_conv1d_update_triton
)
hidden_states_B_C_d = ccu(
hidden_states_B_C_d,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor,
conv_state_indices=state_indices_tensor_d,
)
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A = (
A_d = (
self.A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states = hidden_states.view(
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim
)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# - layer_state.ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
selective_state_update(
ssm_state.permute(0, 3, 2, 1),
hidden_states,
dt,
A,
B,
C,
D,
ssm_state,
hidden_states_d,
dt_d,
A_d,
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
elif forward_batch.forward_mode.is_idle():
preallocated_ssm_out = preallocated_ssm_out
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out, gate)
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:], _ = self.out_proj(hidden_states)
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
@property
def mamba_type(self) -> str:

View File

@@ -0,0 +1,211 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
import math
from dataclasses import dataclass
import torch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@dataclass(kw_only=True)
class ForwardMetadata:
    """Per-forward-pass metadata shared by all mamba layers.

    Computed once at the top-level model forward and reused by every mamba
    layer in the same iteration.
    """

    # Cumulative sequence-start offsets, shape (num_seqs + 1,); first entry is 0.
    query_start_loc: torch.Tensor
    # Per-request slot indices into the mamba conv/SSM state cache pools.
    mamba_cache_indices: torch.Tensor
@dataclass(kw_only=True)
class Mamba2Metadata(ForwardMetadata):
    """stable metadata across all mamba2 layers in the forward pass"""

    # Number of extend (prefill) requests in the batch.
    num_prefills: int
    # Total token count across all prefill requests.
    num_prefill_tokens: int
    # Number of decode requests; for decode, token count == request count.
    num_decodes: int

    @dataclass(kw_only=True, frozen=True)
    class MixedMetadata:
        """Extra metadata needed only when the batch contains prefill requests."""

        # Boolean per prefill request: True when it has a cached prefix
        # (chunked prefill continuation).
        has_initial_states: torch.Tensor
        # True when any prefill request carries an initial SSM state;
        # precomputed on the host to avoid device syncs in the layers.
        prep_initial_states: bool
        # Physical mamba chunk size (tokens per chunk) used by the scan kernels.
        chunk_size: int
        # Sequence index per prefill token, shape (1, num_prefill_tokens).
        seq_idx: torch.Tensor
        # Logical->physical chunk mapping; only computed when
        # prep_initial_states is True, otherwise None (ignored by kernels).
        chunk_indices: torch.Tensor | None
        chunk_offsets: torch.Tensor | None
        # Host-side extend lengths, forwarded to causal_conv1d_fn.
        extend_seq_lens_cpu: list[int]

    mixed_metadata: MixedMetadata | None = None
    """`mixed_metadata` is used for extend/mixed requests"""

    @staticmethod
    def _query_start_loc_to_chunk_indices_offsets(
        query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
                lengths, shape (num_seqs + 1,).
                The first element should be 0. Each entry represents the starting
                index of a sequence in the flattened token array.
            chunk_size (int): The size of each physical mamba chunk
                (number of tokens per chunk).
            total_seqlens (int): The total number of tokens in the batch.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - chunk_indices (torch.Tensor): 1D tensor of indices
                    indicating the physical chunk for each logical chunk.
                - chunk_offsets (torch.Tensor): 1D tensor of offsets
                    indicating the starting index of each logical chunk within
                    its physical chunk.

        This function computes the chunk indices and offsets for the given
        query_start_loc and chunk_size. Both are tensors of integers with length N,
        where N is the number of logical (pseudo) chunks.
        A logical chunk is a sequence of tokens that are all part of the same
        sequence and are all in the same physical mamba chunk.
        In other words, a logical chunk changes every time we cross a sequence
        boundary or a physical mamba chunk boundary.
        Logical chunks are needed to handle batched requests with initial states
        (see _state_passing_fwd and _chunk_scan_fwd).
        The chunk_indices tensor contains the index of the physical chunk for each
        logical chunk.
        The chunk_offsets tensor contains the offset (AKA starting index) of the
        logical chunk in the physical chunk.

        Example:
            query_start_loc = [0, 5, 10]
            chunk_size = 8
            total_seqlens = 10
            -> chunk_indices = [0, 0, 1]
            -> chunk_offsets = [0, 5, 0]

        In this example, we have 2 sequences, each with 5 tokens. The physical
        chunk size is 8 tokens.
        We have three logical chunks:
        - the first logical chunk starts at token 0 in the first physical chunk
            and contains all 5 tokens from the first sequence
        - the second logical chunk starts at token 5 in the first physical chunk
            and contains first 3 tokens from the second sequence
        - the third logical chunk starts at token 0 in the second physical chunk
            and contains the remaining 2 tokens from the second sequence
        """
        cu_seqlens = query_start_loc[1:]  # remove prepended 0

        # outputs will have length expansion of chunks that do not divide
        # chunk_size
        N = (
            math.ceil(total_seqlens / chunk_size)
            + (cu_seqlens[:-1] % chunk_size > 0).sum()
        )
        chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
        chunk_offsets = torch.zeros(
            (N,), dtype=torch.int, device=query_start_loc.device
        )

        p = 0  # num of insertions
        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            # if does not divide chunk_size, then there is one chunk insertion
            p += s % chunk_size > 0

            # get the dimensions
            # - the + 1 for _e is to shift the boundary by one chunk
            # - this shifting is not needed if chunk_size divides e
            _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)

            # adjust indices and offsets
            chunk_indices[_s:_e] -= p
            chunk_offsets[_s] = s % chunk_size

        return chunk_indices, chunk_offsets

    @staticmethod
    def prepare_decode(
        query_start_loc: torch.Tensor,
        mamba_cache_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> "Mamba2Metadata":
        """This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
        return Mamba2Metadata(
            query_start_loc=query_start_loc,
            mamba_cache_indices=mamba_cache_indices,
            num_decodes=len(seq_lens),
            num_prefills=0,
            num_prefill_tokens=0,
        )

    @classmethod
    def prepare_mixed(
        cls,
        query_start_loc: torch.Tensor,
        mamba_cache_indices: torch.Tensor,
        chunk_size: int,
        forward_batch: ForwardBatch,
    ) -> "Mamba2Metadata":
        """This path cannot run with CUDA graph, as it contains extend requests."""
        # A batch with no extend tokens degenerates to the decode-only path.
        if forward_batch.extend_num_tokens is None:
            return cls.prepare_decode(
                query_start_loc, mamba_cache_indices, forward_batch.seq_lens
            )
        num_prefills = len(forward_batch.extend_seq_lens)
        num_prefill_tokens = forward_batch.extend_num_tokens
        num_decodes = len(forward_batch.seq_lens) - num_prefills
        context_lens_tensor = forward_batch.extend_prefix_lens
        assert context_lens_tensor is not None
        # precompute flag to avoid device syncs later
        has_initial_states = context_lens_tensor > 0
        prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()

        # NOTE: prefill requests come first, so only the first
        # num_prefills + 1 boundaries are relevant for the scan kernels.
        query_start_loc = query_start_loc[: num_prefills + 1]
        # Map each prefill token to its request index, shape (1, num_prefill_tokens).
        seq_idx = torch.repeat_interleave(
            torch.arange(
                num_prefills, dtype=torch.int32, device=query_start_loc.device
            ),
            query_start_loc.diff(),
            output_size=num_prefill_tokens,
        )
        seq_idx.unsqueeze_(0)

        # We compute metadata for chunked prefill once at the top level model
        # forward and reuse them in mamba layers. If not needed, they will be
        # ignored inside mamba kernels.
        chunk_offsets, chunk_indices = None, None
        if prep_initial_states:
            chunk_indices, chunk_offsets = (
                cls._query_start_loc_to_chunk_indices_offsets(
                    query_start_loc, chunk_size, num_prefill_tokens
                )
            )
        return Mamba2Metadata(
            query_start_loc=query_start_loc,
            mamba_cache_indices=mamba_cache_indices,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            mixed_metadata=cls.MixedMetadata(
                has_initial_states=has_initial_states,
                prep_initial_states=prep_initial_states,
                chunk_size=chunk_size,
                seq_idx=seq_idx,
                chunk_indices=chunk_indices,
                chunk_offsets=chunk_offsets,
                extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
            ),
        )

View File

@@ -1,81 +0,0 @@
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
from sglang.srt.distributed.utils import divide
class MambaStateShapeCalculator:
    """Per-rank state-shape helpers for mamba/linear-attention style layers.

    All methods return the shapes of the cached states (convolution and/or
    temporal SSM state) as seen by a single tensor-parallel rank.
    """

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """One (heads_per_rank, head_dim, head_dim) recurrent state."""
        heads_per_rank = num_heads // tp_size
        return ((heads_per_rank, head_dim, head_dim),)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Conv state is (conv_kernel - 1, inner); temporal state is (inner, state_size)."""
        inner_per_rank = divide(intermediate_size, tp_world_size)
        conv_state_shape = (conv_kernel - 1, inner_per_rank)
        temporal_state_shape = (inner_per_rank, state_size)
        return conv_state_shape, temporal_state_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Conv/SSM state shapes for mamba2, accounting for group replication under TP."""
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        padded_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size
        )
        # heads and n_groups are TP-ed; conv operates over x plus B and C streams
        conv_dim = intermediate_size + 2 * padded_groups * state_size
        # contiguous along 'dim' axis
        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
        temporal_state_shape = (
            divide(num_heads, tp_world_size),
            head_dim,
            state_size,
        )
        return conv_state_shape, temporal_state_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        """Single conv state of shape (conv_kernel - 1, inner_per_rank)."""
        return ((conv_kernel - 1, divide(intermediate_size, tp_world_size)),)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards."""
        # No padding needed when the groups already shard evenly;
        # for n_groups == 1 this yields exactly tp_size - n_groups.
        return 0 if ngroups % tp_size == 0 else tp_size - ngroups

View File

@@ -0,0 +1,120 @@
from typing import Union
import torch
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
from sglang.srt.model_loader.weight_utils import sharded_weight_loader
from sglang.srt.utils.common import set_weight_attrs
class Mixer2RMSNormGated(CustomOp):
    """Gated RMSNorm used by Mamba2 mixers: ``silu(gate) * x`` followed by
    (optional) group-wise RMSNorm, with tensor-parallel sharding of the
    hidden dimension handled internally.

    Args:
        full_hidden_size: un-sharded hidden size of the normalized tensor.
        full_n_groups: un-sharded number of RMSNorm groups.
        use_rms_norm: when False, only the silu gating is applied and no
            norm weight is registered.
        eps: variance epsilon for the RMSNorm.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.full_hidden_size = full_hidden_size
        # Number of hidden elements normalized together as one group.
        self.group_size = full_hidden_size // full_n_groups
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        # Recomputed from group_size; equals full_n_groups when it divides
        # full_hidden_size evenly.
        self.n_groups = full_hidden_size // self.group_size

        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if self.use_rms_norm:
            # Register norm weight only if we're actually applying RMSNorm
            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
        else:
            # Avoid checkpoint mismatch by skipping unused parameter
            self.register_parameter("weight", None)
        assert (
            self.full_hidden_size % self.tp_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference implementation; handles all tensor-parallel layouts."""
        # Three tensor-parallel cases:
        #   1. n_groups is 1
        #      In this case we parallelize along the reduction dim.
        #      Each rank computes a local sum of squares followed by AllReduce
        #   2. tp_size divides n_groups
        #      Each rank only reduces within its local group(s).
        #      No collective ops necessary.
        #   3. The general case can be pretty complicated so we AllGather
        #      the input and then redundantly compute the RMSNorm.
        input_dtype = x.dtype
        # Gate in float32 for numerical stability of silu.
        x = x * torch.nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return x.to(input_dtype)

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
                global_sums = tensor_model_parallel_all_reduce(local_sums)
                # Calculate the variance
                count = self.tp_size * x.shape[-1]
                variance = global_sums / count
            else:
                variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
        else:
            redundant_tp: bool = self.n_groups % self.tp_size != 0
            if redundant_tp:
                # To handle the general case, redundantly apply the variance
                x = tensor_model_parallel_all_gather(x, -1)

            *prefix_dims, hidden_dim = x.shape
            group_count = hidden_dim // self.group_size
            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
            variance = x_grouped.pow(2).mean(-1, keepdim=True)
            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
            x = x_grouped.view(*prefix_dims, hidden_dim)

            if redundant_tp:
                # Slice this rank's shard back out of the gathered tensor.
                start = self.per_rank_hidden_size * self.tp_rank
                end = start + self.per_rank_hidden_size
                x = x[..., start:end]

        return self.weight * x.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Fused kernel path; falls back to forward_native for sharded/grouped cases."""
        input_dtype = x.dtype
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)

        # NOTE(review): this falls through to the fused kernel only when
        # n_groups == 1 and tp_size == 1 (since 1 % tp_size != 0 for tp_size > 1);
        # all tensor-parallel cases use forward_native.
        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
            return self.forward_native(x, gate)

        return rms_norm_gated(
            x=x,
            weight=self.weight.data,
            bias=None,
            z=gate,
            eps=self.variance_epsilon,
            norm_before_gate=False,
            is_rms_norm=True,
        )

View File

@@ -15,56 +15,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "K", "IS_CAUSAL"],
# )
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices

View File

@@ -16,66 +16,6 @@ from packaging import version
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
# )
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices

View File

@@ -17,17 +17,6 @@ import triton.language as tl
from .mamba_ssm import softplus
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE_H": 2}),
# triton.Config({"BLOCK_SIZE_H": 4}),
# triton.Config({"BLOCK_SIZE_H": 8}),
# triton.Config({"BLOCK_SIZE_H": 16}),
# triton.Config({"BLOCK_SIZE_H": 32}),
# triton.Config({"BLOCK_SIZE_H": 64}),
# ],
# key=["chunk_size", "nheads"],
# )
@triton.jit
def _chunk_cumsum_fwd_kernel(
# Pointers to matrices
@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_fwd_kernel(
# Pointers to matrices
@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
tl.store(states_ptrs, states, mask=c_mask)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices

View File

@@ -13,17 +13,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE": 64}),
# triton.Config({"BLOCK_SIZE": 128}),
# triton.Config({"BLOCK_SIZE": 256}),
# triton.Config({"BLOCK_SIZE": 512}),
# triton.Config({"BLOCK_SIZE": 1024}),
# triton.Config({"BLOCK_SIZE": 2048}),
# ],
# key=["dim"],
# )
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices

View File

@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
if model_runner.is_hybrid_gdn:
if model_runner.hybrid_gdn_config is not None:
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else: