model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
Netanel Haber
2025-10-08 19:37:38 +03:00
committed by GitHub
parent c882b5ae75
commit d6837aea4d
35 changed files with 3280 additions and 854 deletions

View File

@@ -9,6 +9,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
@@ -32,4 +33,5 @@ __all__ = [
"DotsVLMConfig",
"DotsOCRConfig",
"FalconH1Config",
"NemotronHConfig",
]

View File

@@ -15,16 +15,12 @@
"""Falcon-H1 model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
get_tensor_model_parallel_world_size,
@@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig):
self.rope_scaling = None
self.rope_scaling = rope_scaling
self.projectors_bias = projectors_bias
mamba_intermediate = (
self.mamba_intermediate = mamba_intermediate = (
mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
)
@@ -294,18 +290,6 @@ class FalconH1Config(PretrainedConfig):
def layers_block_type(self):
return ["falcon_h1" for i in range(self.num_hidden_layers)]
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
@property
def full_attention_layer_ids(self):
# For Falcon-H1, we do have attention on all layers
@@ -317,44 +301,14 @@ class FalconH1Config(PretrainedConfig):
return range(self.num_hidden_layers)
@property
def hybrid_gdn_params(self):
world_size = get_tensor_model_parallel_world_size()
n_groups = self.mamba_n_groups
if self.mamba_n_groups % world_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
self.mamba_n_groups, world_size
)
n_groups += extra_groups
conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
conv_state_shape = (
divide(conv_dim, world_size),
self.mamba_d_conv - 1,
)
# we TP-ize on the heads dimension
temporal_state_shape = (
self.mamba_d_state,
self.mamba_d_head,
divide(self.mamba_n_heads, world_size),
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
def mamba2_cache_params(self):
shape = Mamba2StateShape.create(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.mamba_intermediate,
n_groups=self.mamba_n_groups,
num_heads=self.mamba_n_heads,
head_dim=self.mamba_d_head,
state_size=self.mamba_d_state,
conv_kernel=self.mamba_d_conv,
)
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)

View File

@@ -0,0 +1,117 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
import os
from dataclasses import dataclass, field
import numpy as np
import torch
from sglang.srt.distributed.utils import divide
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    """Return how many groups must be replicated so that every head shard
    carries the groups it needs.

    Evaluates to zero whenever ``ngroups`` already divides ``tp_size``
    evenly; for ``ngroups == 1`` this is exactly ``tp_size - 1``.
    """
    remainder = ngroups % tp_size
    if remainder == 0:
        # Groups already divide evenly across the shards -- nothing to add.
        return 0
    return tp_size - ngroups
@dataclass(kw_only=True, frozen=True)
class Mamba2StateShape:
conv: tuple[int, int]
temporal: tuple[int, int, int]
intermediate_size: int
conv_dim: int
ssm_state_size: int
num_heads: int
head_dim: int
state_size: int
conv_kernel: int
@staticmethod
def create(
*,
tp_world_size: int,
intermediate_size: int,
n_groups: int,
num_heads: int,
head_dim: int,
state_size: int,
conv_kernel: int,
) -> "Mamba2StateShape":
# if n_groups is not divisible by world_size, need to extend the shards
# to ensure all groups needed by a head is sharded along with it
if n_groups % tp_world_size != 0:
extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)
n_groups += extra_groups
# heads and n_groups are TP-ed
conv_dim = intermediate_size + 2 * n_groups * state_size
# contiguous along 'dim' axis
conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1
# These are not TP-ed as they depend on A, dt_bias, D
# - they are typically small
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
return Mamba2StateShape(
conv=conv_state_shape,
temporal=temporal_state_shape,
intermediate_size=intermediate_size,
conv_dim=conv_dim,
ssm_state_size=state_size,
num_heads=num_heads,
head_dim=head_dim,
state_size=state_size,
conv_kernel=conv_kernel,
)
@dataclass(kw_only=True, frozen=True)
class Mamba2StateDType:
conv: torch.dtype
temporal: torch.dtype
CONV_DTYPE = torch.bfloat16
def mamba2_state_dtype() -> Mamba2StateDType:
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
@dataclass(kw_only=True, frozen=True)
class Mamba2CacheParams:
shape: Mamba2StateShape
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
layers: list[int]
@property
def mamba_cache_per_req(self) -> int:
return (
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
) * len(self.layers)

View File

@@ -0,0 +1,286 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
"""NemotronH model configuration"""
import regex as re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
# Single-character layer-type codes used in `hybrid_override_pattern`:
# M -> Mamba2 mixer, * -> attention, - -> MLP.
MAMBA = "M"
ATTENTION = "*"
MLP = "-"
class NemotronHConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a
[`NemotronHModel`]. It is used to instantiate a NemotronH model according
to the specified arguments, defining the model architecture. Instantiating
a configuration with the defaults will yield a similar configuration to
that of the NemotronH-v0.1 model.
Args:
vocab_size (`int`, *optional*, defaults to 131072):
Vocabulary size of the NemotronH model. Defines the number of
different tokens that can be represented by the `inputs_ids`
passed when calling [`NemotronHModel`]
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be
tied. Note that this is only relevant if the model has an output
word embedding layer.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 21504):
Dimension of the MLP representations.
num_hidden_layers (`int`, *optional*, defaults to 52):
Number of hidden layers in the Transformer encoder.
hybrid_override_pattern (`str`, *optional*, defaults to
`"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
The pattern of the hybrid model. The pattern is a string of
characters where each character represents
M: Mamba2, *: Attention, -: MLP
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the
Transformer encoder.
attention_head_dim (`int`, *optional*, defaults to 128):
Dimension of each attention head.
num_key_value_heads (`int`, *optional*, defaults to 8):
This is the number of key_value heads that should be used to
implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use
Multi Head Attention (MHA), if `num_key_value_heads=1` the model
will use Multi Query Attention (MQA) otherwise GQA is used.
mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
The non-linear activation function in the MLP layers.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in attention layers.
mlp_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in MLP layers.
use_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the model.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for
initializing all weight matrices.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
The epsilon used by the layer normalization layers.
residual_in_fp32 (`bool`, *optional*, defaults to `False`):
Whether or not residuals should be in `float32`. If set to `False`
residuals will keep the same `dtype` as the rest of the model.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values
attentions (not used by all models). Only relevant if
`config.is_decoder=True`.
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
Number of prompt logits to calculate during generation. If `None`,
all logits will be calculated. If an integer value, only last
`num_logits_to_keep` logits will be calculated.
pad_token_id (`int`, *optional*, defaults to 0):
The id of the padding token.
bos_token_id (`int`, *optional*, defaults to 1):
The id of the "beginning-of-sequence" token.
eos_token_id (`int`, *optional*, defaults to 2):
The id of the "end-of-sequence" token.
sliding_window (`int`, *optional*, defaults to None):
Sliding window attention window size.
max_position_embeddings (`int`, *optional*, defaults to 4096):
The maximum sequence length that this model might ever be used
with.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
hidden_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the hidden states.
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
Flag indicating whether or not to use the fast mamba kernels.
These are available only if `mamba-ssm` and `causal-conv1d`
are installed, and the mamba modules are running on a CUDA device.
ssm_state_size (`int`, *optional*, defaults to 128):
The dimension of the mamba state space latents.
mamba_num_heads (`int`, *optional*, defaults to 128):
Number of heads in Mamba layers.
mamba_n_groups (`int`, *optional*, defaults to 8):
Number of groups in Mamba layers.
mamba_head_dim (`int`, *optional*, defaults to 64):
Dimension of each Mamba head.
mamba_d_conv (`int`, *optional*, defaults to 4):
The size of the mamba convolution kernel.
mamba_expand (`int`, *optional*, defaults to 2):
Expanding factor used to determine the mamba intermediate size.
mamba_hidden_act (`str`, *optional*, defaults to "silu"):
The non-linear activation function in the Mamba layers.
mamba_dt_min (`float`, *optional*, defaults to 0.001):
Minimum value for the time step in Mamba.
mamba_dt_max (`float`, *optional*, defaults to 0.1):
Maximum value for the time step in Mamba.
mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
Limits for the time step in Mamba.
mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
Floor value for time step initialization in Mamba.
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the convolution layer of the mamba mixer
block.
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
Whether to use bias in the input and output projections of the
mamba mixer block.
mamba_chunk_size (`int`, *optional*, defaults to 256):
Size of chunks for Mamba processing.
rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
Whether to rescale the pre-normalization residual connections.
"""
model_type = "nemotron_h"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
    self,
    vocab_size=131072,
    tie_word_embeddings=False,
    hidden_size=4096,
    intermediate_size=21504,
    num_hidden_layers=52,
    hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
    num_attention_heads=32,
    head_dim=128,
    num_key_value_heads=8,  # nemo: num_query_groups
    mlp_hidden_act="relu2",
    attention_bias=False,
    mlp_bias=False,
    use_bias=False,
    initializer_range=0.02,  # nemo: init_method_std
    layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
    residual_in_fp32=False,  # Megatron Core default value
    use_cache=True,
    num_logits_to_keep=1,
    pad_token_id=0,
    bos_token_id=1,
    eos_token_id=2,
    sliding_window=None,
    max_position_embeddings=4096,
    attention_dropout=0.0,
    hidden_dropout=0.0,  # * ADDED
    use_mamba_kernels=True,
    ssm_state_size=128,  # mamba_state_size
    mamba_num_heads=128,
    mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
    mamba_head_dim=64,
    mamba_d_conv=4,
    mamba_expand=2,
    mamba_hidden_act="silu",
    mamba_dt_min=0.001,
    mamba_dt_max=0.1,
    mamba_dt_limit=(0.0, float("inf")),
    mamba_dt_init_floor=1e-4,
    mamba_conv_bias=True,
    mamba_proj_bias=False,
    mamba_chunk_size=256,
    rescale_prenorm_residual=True,
    **kwargs,
):
    """Initialize a NemotronH configuration.

    See the class docstring for the meaning of each argument. Remaining
    keyword arguments (e.g. extra keys from a checkpoint's config.json)
    are forwarded to ``PretrainedConfig.__init__`` via ``**kwargs``.
    """
    self.vocab_size = vocab_size
    self.tie_word_embeddings = tie_word_embeddings
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.num_hidden_layers = num_hidden_layers
    self.hybrid_override_pattern = hybrid_override_pattern
    self.num_attention_heads = num_attention_heads
    self.head_dim = head_dim
    self.sliding_window = sliding_window
    self.max_position_embeddings = max_position_embeddings
    self.attention_dropout = attention_dropout
    self.hidden_dropout = hidden_dropout
    # Validate hybrid_override_pattern
    # M: Mamba2, *: Attention, -: MLP
    # NOTE: these are asserts (stripped under `python -O`); kept as-is to
    # preserve the exception type raised on invalid configs.
    assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
        "hybrid_override_pattern must have same length as " "num_hidden_layers"
    )
    # Fixed character class: the previous pattern `[*-M]` defined a range
    # from '*' (0x2A) to 'M' (0x4D), accidentally accepting characters
    # such as digits, '+', '.', and 'A'-'L'. With '-' placed last it is a
    # literal, so only the three layer codes are accepted.
    assert re.match(r"^[M*-]+$", self.hybrid_override_pattern), (
        "hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
    )
    # for backward compatibility
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads
    self.num_key_value_heads = num_key_value_heads
    self.mlp_hidden_act = mlp_hidden_act
    self.attention_bias = attention_bias
    self.mlp_bias = mlp_bias
    self.use_bias = use_bias
    self.initializer_range = initializer_range
    self.layer_norm_epsilon = layer_norm_epsilon
    self.residual_in_fp32 = residual_in_fp32
    self.use_cache = use_cache
    self.num_logits_to_keep = num_logits_to_keep
    self.use_mamba_kernels = use_mamba_kernels
    self.mamba_n_groups = mamba_n_groups
    self.mamba_head_dim = mamba_head_dim
    self.ssm_state_size = ssm_state_size
    self.mamba_num_heads = mamba_num_heads
    # Several mamba_* constructor args are re-exposed under HF Mamba2-style
    # attribute names (conv_kernel, expand, time_step_*, use_conv_bias).
    self.conv_kernel = mamba_d_conv
    self.expand = mamba_expand
    self.mamba_hidden_act = mamba_hidden_act
    self.time_step_min = mamba_dt_min
    self.time_step_max = mamba_dt_max
    self.time_step_limit = mamba_dt_limit
    self.time_step_floor = mamba_dt_init_floor
    self.use_conv_bias = mamba_conv_bias
    self.mamba_proj_bias = mamba_proj_bias
    self.mamba_chunk_size = mamba_chunk_size
    self.rescale_prenorm_residual = rescale_prenorm_residual
    super().__init__(
        pad_token_id=pad_token_id,
        bos_token_id=bos_token_id,
        eos_token_id=eos_token_id,
        tie_word_embeddings=tie_word_embeddings,
        **kwargs,
    )
@property
def mamba_layer_ids(self):
    """Indices (0-based) of the layers coded as Mamba2 ('M') in the pattern."""
    layer_codes = self.hybrid_override_pattern[: self.num_hidden_layers]
    return [idx for idx, code in enumerate(layer_codes) if code == MAMBA]
@property
def full_attention_layer_ids(self):
    """Indices (0-based) of the layers coded as attention ('*') in the pattern."""
    layer_codes = self.hybrid_override_pattern[: self.num_hidden_layers]
    return [idx for idx, code in enumerate(layer_codes) if code == ATTENTION]
@property
def mamba2_cache_params(self) -> Mamba2CacheParams:
    """Mamba2 cache sizing parameters (TP-sharded shapes + mamba layer ids)."""
    # Fix: `self.n_groups` is never assigned in this class (`__init__` stores
    # `self.mamba_n_groups`), so reading it raised AttributeError for
    # programmatically-constructed configs. Checkpoint config.json files may
    # still inject an `n_groups` attribute via PretrainedConfig kwargs --
    # prefer it when present, otherwise fall back to `mamba_n_groups`.
    n_groups = getattr(self, "n_groups", self.mamba_n_groups)
    shape = Mamba2StateShape.create(
        tp_world_size=get_attention_tp_size(),
        intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
        n_groups=n_groups,
        num_heads=self.mamba_num_heads,
        head_dim=self.mamba_head_dim,
        state_size=self.ssm_state_size,
        conv_kernel=self.conv_kernel,
    )
    return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)

View File

@@ -15,14 +15,12 @@
"""Qwen3Hybrid model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
]
@property
def hybrid_gdn_params(self):
world_size = get_attention_tp_size()
conv_dim = (
self.linear_key_head_dim * self.linear_num_key_heads * 2
+ self.linear_value_head_dim * self.linear_num_value_heads
)
conv_state_shape = (
divide(conv_dim, world_size),
self.linear_conv_kernel_dim - 1,
def mamba2_cache_params(self) -> Mamba2CacheParams:
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
n_groups=self.linear_num_key_heads,
num_heads=self.linear_num_value_heads,
head_dim=self.linear_value_head_dim,
state_size=self.linear_key_head_dim,
conv_kernel=self.linear_conv_kernel_dim,
)
temporal_state_shape = (
divide(self.linear_num_value_heads, world_size),
self.linear_key_head_dim,
self.linear_value_head_dim,
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
)
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)

View File

@@ -1,7 +1,14 @@
import logging
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
# evade circular imports
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.model_runner import ModelRunner
ATTENTION_BACKENDS = {}
@@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend):
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
runner.hybrid_gdn_config is not None and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
if cfg := runner.mambaish_config:
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
if runner.hybrid_gdn_config is not None:
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
)
full_attn_layers = cfg.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)

View File

@@ -181,6 +181,45 @@ def _layer_norm_fwd(
return out, mean, rstd
def rms_norm_gated(
    *,
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Gated (layer|rms)norm forward pass.

    If z is not None, computes norm(x) * silu(z) when norm_before_gate,
    else norm(x * silu(z)).
    """
    original_shape = x.shape
    # Flatten everything but the feature axis into a 2D view; the kernel
    # requires the last axis to be contiguous.
    x2d = x.reshape(-1, original_shape[-1])
    if x2d.stride(-1) != 1:
        x2d = x2d.contiguous()
    z2d = None
    if z is not None:
        assert z.shape == original_shape
        z2d = z.reshape(-1, z.shape[-1])
        if z2d.stride(-1) != 1:
            z2d = z2d.contiguous()
    out, _mean, _rstd = _layer_norm_fwd(
        x2d,
        weight.contiguous(),
        bias.contiguous() if bias is not None else None,
        eps,
        z=z2d,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )
    # Restore the caller's original tensor shape.
    return out.reshape(original_shape)
class LayerNormFn(torch.autograd.Function):
@staticmethod
@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = _layer_norm_fwd(
x,
weight,
bias,
eps,
return rms_norm_gated(
x=x,
weight=weight,
bias=bias,
eps=eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
return y.reshape(x_shape_og)
def layernorm_fn(
@@ -238,14 +261,6 @@ def layernorm_fn(
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNorm(torch.nn.Module):
def __init__(
@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
is_rms_norm=False,
)
@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
return layernorm_fn(
x,
self.weight,
self.bias,
@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
is_rms_norm=True,
)

View File

@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
ForwardMetadata,
Mamba2Metadata,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import fused_gdn_gating
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
@@ -47,18 +54,10 @@ elif is_npu():
causal_conv1d_update = causal_conv1d_update_npu
@dataclass
class ForwardMetadata:
query_start_loc: Optional[torch.Tensor]
mamba_cache_indices: torch.Tensor
class MambaAttnBackend(AttentionBackend):
"""Attention backend using Mamba kernel."""
class MambaAttnBackendBase(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.pad_slot_id = -1 # Default pad slot id
self.pad_slot_id = PAD_SLOT_ID
self.device = model_runner.device
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
self.forward_metadata: ForwardMetadata = None
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
def init_forward_metadata(self, forward_batch: ForwardBatch):
def _forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
forward_batch.req_pool_indices
)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
self.forward_metadata = self._forward_metadata(forward_batch)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.forward_metadata = self._capture_metadata(
bs, req_pool_indices, forward_mode
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.forward_metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
assert (
max_num_tokens % max_bs == 0
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
def _capture_metadata(
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
raise ValueError(f"Invalid forward mode: {forward_mode=}")
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
def init_forward_metadata_replay_cuda_graph(
def _replay_metadata(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
return 1 # Mamba attn does not use seq lens to index kv cache
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
dt_bias = kwargs["dt_bias"]
layer_id = kwargs["layer_id"]
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = mamba_cache_params.conv
ssm_states = mamba_cache_params.temporal
if is_target_verify:
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = self.req_to_token_pool.get_mamba_params(layer_id)
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
intermediate_state_cache = mamba_cache_params.intermediate_ssm
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
has_initial_states = torch.ones(
seq_len // forward_batch.spec_info.draft_token_num,
dtype=torch.bool,
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
)
conv_states_to_use = conv_states.clone()
else:
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
has_initial_states = forward_batch.extend_prefix_lens > 0
conv_states_to_use = conv_states
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
return core_attn_out
class Mamba2AttnBackend(MambaAttnBackendBase):
"""Attention backend wrapper for Mamba2Mixer kernels."""
def __init__(self, model_runner: ModelRunner):
super().__init__(model_runner)
config = model_runner.mamba2_config
assert config is not None
self.mamba_chunk_size = config.mamba_chunk_size
def init_forward_metadata(self, forward_batch: ForwardBatch):
metadata = self._forward_metadata(forward_batch)
self.forward_metadata = Mamba2Metadata.prepare_mixed(
metadata.query_start_loc,
metadata.mamba_cache_indices,
self.mamba_chunk_size,
forward_batch,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
self.forward_metadata = Mamba2Metadata.prepare_decode(
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
self.forward_metadata = Mamba2Metadata.prepare_decode(
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
)
def forward(
self,
mixer: MambaMixer2,
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_id: int,
mup_vector: Optional[torch.Tensor] = None,
use_triton_causal_conv: bool = False,
):
assert isinstance(self.forward_metadata, Mamba2Metadata)
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
return mixer.forward(
hidden_states=hidden_states,
output=output,
layer_cache=layer_cache,
metadata=self.forward_metadata,
mup_vector=mup_vector,
use_triton_causal_conv=use_triton_causal_conv,
)
def forward_decode(self, *args, **kwargs):
raise NotImplementedError(
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
)
def forward_extend(self, *args, **kwargs):
raise NotImplementedError(
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
)
class HybridLinearAttnBackend(AttentionBackend):
"""Support different backends for prefill and decode."""
"""Manages a full and linear attention backend"""
def __init__(
self,
full_attn_backend: AttentionBackend,
linear_attn_backend: AttentionBackend,
linear_attn_backend: MambaAttnBackendBase,
full_attn_layers: list[int],
):
self.full_attn_layers = full_attn_layers
self.full_attn_backend = full_attn_backend
self.linear_attn_backend = linear_attn_backend
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
)
def get_cuda_graph_seq_len_fill_value(self):
return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
def forward_decode(
self,
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_decode(
return self.full_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_decode(
return self.linear_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_extend(
return self.full_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_extend(
return self.linear_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
request_number = accepted_length.shape[0]
state_indices_tensor = self.attn_backend_list[
1
].forward_metadata.mamba_cache_indices[:request_number]
state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
)
mamba_caches = self.attn_backend_list[
1
].req_to_token_pool.get_mamba_params_all_layers()
mamba_caches = (
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
)
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = mamba_caches
conv_states = mamba_caches.conv
ssm_states = mamba_caches.temporal
intermediate_state_cache = mamba_caches.intermediate_ssm
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
# SSM state updates (chunked to reduce peak memory)
valid_mask = accepted_length > 0

View File

@@ -10,7 +10,7 @@ import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
PAD_SLOT_ID = -1
from .causal_conv1d_triton import PAD_SLOT_ID
def causal_conv1d_fn(

View File

@@ -6,11 +6,11 @@ from typing import List, Optional, Union
import numpy as np
import torch
PAD_SLOT_ID = -1
import triton
import triton.language as tl
PAD_SLOT_ID = -1
@triton.jit()
def _causal_conv1d_fwd_kernel( # continuous batching
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
+ (conv_state_batch_coord * stride_conv_state_seq)
+ conv_state_token_offset * stride_conv_state_tok
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
:, None
]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
@@ -897,7 +899,10 @@ def causal_conv1d_update(
stride_state_indices = (
conv_state_indices.stride(0) if conv_state_indices is not None else 0
)
state_len = width - 1 + (seqlen - 1) # effective state_len needed
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):

View File

@@ -1,23 +1,30 @@
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.custom_op import CustomOp
from sglang.srt.configs.mamba_utils import (
Mamba2CacheParams,
extra_groups_for_head_shards,
)
from sglang.srt.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
from sglang.srt.layers.attention.mamba.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn as causal_conv1d_fn_triton,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_update as causal_conv1d_update_triton,
)
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
from sglang.srt.layers.attention.mamba.ops import (
mamba_chunk_scan_combined,
selective_state_update,
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.mem_cache.memory_pool import MambaPool
from sglang.srt.model_loader.weight_utils import (
composed_weight_loader,
sharded_weight_loader,
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
return loader
class Mixer2RMSNormGated(CustomOp):
def __init__(
self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6,
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.full_hidden_size = full_hidden_size
self.group_size = full_hidden_size // full_n_groups
self.per_rank_hidden_size = full_hidden_size // self.tp_size
self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps
self.use_rms_norm = use_rms_norm
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
else:
# Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (
self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
def forward_native(
self,
x: torch.Tensor,
gate: torch.Tensor,
):
# Three tensor-parallel cases:
# 1. n_groups is 1
# In this case we parallelize along the reduction dim.
# Each rank computes a local sum of squares followed by AllReduce
# 2. tp_size divides n_groups
# Each rank only reduces within its local group(s).
# No collective ops necessary.
# 3. The general case can be pretty complicated so we AllGather
# the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x.to(input_dtype)
if self.n_groups == 1:
if self.tp_size > 1:
# Compute local sum and then reduce to obtain global sum
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance
count = self.tp_size * x.shape[-1]
variance = global_sums / count
else:
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
else:
redundant_tp: bool = self.n_groups % self.tp_size != 0
if redundant_tp:
# To handle the general case, redundantly apply the variance
x = tensor_model_parallel_all_gather(x, -1)
*prefix_dims, hidden_dim = x.shape
group_count = hidden_dim // self.group_size
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
variance = x_grouped.pow(2).mean(-1, keepdim=True)
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
x = x_grouped.view(*prefix_dims, hidden_dim)
if redundant_tp:
start = self.per_rank_hidden_size * self.tp_rank
end = start + self.per_rank_hidden_size
x = x[..., start:end]
return self.weight * x.to(input_dtype)
def forward_cuda(
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
return self.forward_native(x, gate)
return layernorm_fn(
x,
self.weight.data,
bias=None,
z=gate,
eps=self.variance_epsilon,
norm_before_gate=False,
)
class MambaMixer2(torch.nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
@@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module):
def __init__(
self,
cache_params: Mamba2CacheParams,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
chunk_size: int,
layer_id: int,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
# cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.num_heads = num_heads = cache_params.shape.num_heads
self.head_dim = cache_params.shape.head_dim
assert (
num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
"then num_groups must equal 1."
)
self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
self.ssm_state_size = cache_params.shape.ssm_state_size
self.activation = activation
self.layer_id = layer_id
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
self.chunk_size = chunk_size
conv_kernel_size = cache_params.shape.conv_kernel
self.intermediate_size = intermediate_size = (
cache_params.shape.intermediate_size
)
self.n_groups = n_groups
if n_groups % self.tp_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size
)
groups = extra_groups_for_head_shards(n_groups, self.tp_size)
self.n_groups = n_groups + groups
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
self.conv_dim = cache_params.shape.conv_dim
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
if n_groups % self.tp_size == 0:
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
prefix=f"{prefix}.in_proj",
)
if n_groups % self.tp_size != 0:
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
# The tuple is (conv_state, ssm_state)
self.kv_cache = (torch.tensor([]), torch.tensor([]))
self.model_config = model_config
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: Optional[torch.Tensor] = None,
):
pass
def forward(
self,
*,
hidden_states: torch.Tensor,
output: torch.Tensor,
forward_batch: ForwardBatch,
layer_cache: MambaPool.State,
metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
use_triton_causal_conv: bool = False,
):
# attn_backend_list[-1] gives access to MambaAttnBackend
mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
attn_metadata = mamba_backend.forward_metadata
state_indices_tensor = attn_metadata.mamba_cache_indices
chunk_size = self.chunk_size
# metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
state_indices_tensor = metadata.mamba_cache_indices
conv_state = layer_cache.conv
ssm_state = layer_cache.temporal
conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
self.layer_id
)
assert (
ssm_state.size(1) == self.ssm_state_size
), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
query_start_loc = attn_metadata.query_start_loc
chunk_size = self.chunk_size
# TODO: properly support this
prep_initial_states = False
query_start_loc = metadata.query_start_loc
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
dim=-1,
)
num_prefills = metadata.num_prefills # request count
num_decodes = metadata.num_decodes # token count (=request)
num_prefill_tokens = metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
assert num_actual_tokens == projected_states.shape[0]
# NOTE: V0 put prefill before decode
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
projected_states.shape[0],
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if forward_batch.forward_mode.is_extend():
if has_prefill:
mixed_metadata = metadata.mixed_metadata
assert mixed_metadata is not None
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
num_prefill_tokens = forward_batch.extend_num_tokens or 0
has_initial_states = forward_batch.extend_prefix_lens > 0
cache_indices = attn_metadata.mamba_cache_indices
x = hidden_states_B_C.transpose(
has_initial_states_p = mixed_metadata.has_initial_states
prep_initial_states = mixed_metadata.prep_initial_states
cache_indices = state_indices_tensor_p
x = hidden_states_B_C_p.transpose(
0, 1
) # this is the form that causal-conv see
hidden_states_B_C = causal_conv1d_fn(
ccfn = (
causal_conv1d_fn
if not use_triton_causal_conv
else causal_conv1d_fn_triton
)
hidden_states_B_C_p = ccfn(
x,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_states,
has_initial_state=has_initial_states_p,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
query_start_loc=query_start_loc_p,
seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
).transpose(0, 1)[:num_prefill_tokens]
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None
if has_initial_states is not None and prep_initial_states:
if has_initial_states_p is not None and prep_initial_states:
initial_states = torch.where(
has_initial_states[:, None, None, None],
ssm_state[state_indices_tensor],
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p],
0,
)
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states.view(
hidden_states_p.view(
1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
),
dt.unsqueeze(0),
dt_p.unsqueeze(0),
self.A,
B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=chunk_size,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=mixed_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
cu_seqlens=query_start_loc,
seq_idx=mixed_metadata.seq_idx,
chunk_indices=mixed_metadata.chunk_indices,
chunk_offsets=mixed_metadata.chunk_offsets,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
out=preallocated_ssm_out_p.view(
1, num_prefill_tokens, -1, self.head_dim
),
state_dtype=ssm_state.dtype,
)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
elif forward_batch.forward_mode.is_decode():
num_decodes = len(query_start_loc) - 1
ssm_state[state_indices_tensor_p] = varlen_state
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
ccu = (
causal_conv1d_update
if not use_triton_causal_conv
else causal_conv1d_update_triton
)
hidden_states_B_C_d = ccu(
hidden_states_B_C_d,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor,
conv_state_indices=state_indices_tensor_d,
)
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A = (
A_d = (
self.A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states = hidden_states.view(
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim
)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# - layer_state.ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
selective_state_update(
ssm_state.permute(0, 3, 2, 1),
hidden_states,
dt,
A,
B,
C,
D,
ssm_state,
hidden_states_d,
dt_d,
A_d,
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
elif forward_batch.forward_mode.is_idle():
preallocated_ssm_out = preallocated_ssm_out
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out, gate)
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:], _ = self.out_proj(hidden_states)
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
@property
def mamba_type(self) -> str:

View File

@@ -0,0 +1,211 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
import math
from dataclasses import dataclass
import torch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@dataclass(kw_only=True)
class ForwardMetadata:
query_start_loc: torch.Tensor
mamba_cache_indices: torch.Tensor
@dataclass(kw_only=True)
class Mamba2Metadata(ForwardMetadata):
"""stable metadata across all mamba2 layers in the forward pass"""
num_prefills: int
num_prefill_tokens: int
num_decodes: int
@dataclass(kw_only=True, frozen=True)
class MixedMetadata:
has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
extend_seq_lens_cpu: list[int]
mixed_metadata: MixedMetadata | None = None
"""`mixed_metadata` is used for extend/mixed requests"""
@staticmethod
def _query_start_loc_to_chunk_indices_offsets(
query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Args:
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
lengths, shape (num_seqs + 1,).
The first element should be 0. Each entry represents the starting
index of a sequence in the flattened token array.
chunk_size (int): The size of each physical mamba chunk
(number of tokens per chunk).
total_seqlens (int): The total number of tokens in the batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- chunk_indices (torch.Tensor): 1D tensor of indices
indicating the physical chunk for each logical chunk.
- chunk_offsets (torch.Tensor): 1D tensor of offsets
indicating the starting index of each logical chunk within
its physical chunk.
This function computes the chunk indices and offsets for the given
query_start_loc and chunk_size. Both are tensors of integers with length N,
where N is the number of logical (pseudo) chunks.
A logical chunk is a sequence of tokens that are all part of the same
sequence and are all in the same physical mamba chunk.
In other words, a logical chunk changes every time we cross a sequence
boundary or a physical mamba chunk boundary.
Logical chunks are needed to handle batched requests with initial states
(see _state_passing_fwd and _chunk_scan_fwd).
The chunk_indices tensor contains the index of the physical chunk for each
logical chunk.
The chunk_offsets tensor contains the offset (AKA starting index) of the
logical chunk in the physical chunk.
Example:
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 0, 1]
-> chunk_offsets = [0, 5, 0]
In this example, we have 2 sequences, each with 5 tokens. The physical
chunk size is 8 tokens.
We have three logical chunks:
- the first logical chunk starts at token 0 in the first physical chunk
and contains all 5 tokens from the first sequence
- the second logical chunk starts at token 5 in the first physical chunk
and contains first 3 tokens from the second sequence
- the third logical chunk starts at token 0 in the second physical chunk
and contains the remaining 2 tokens from the second sequence
"""
cu_seqlens = query_start_loc[1:] # remove prepended 0
# outputs will have length expansion of chunks that do not divide
# chunk_size
N = (
math.ceil(total_seqlens / chunk_size)
+ (cu_seqlens[:-1] % chunk_size > 0).sum()
)
chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
chunk_offsets = torch.zeros(
(N,), dtype=torch.int, device=query_start_loc.device
)
p = 0 # num of insertions
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
# if does not divide chunk_size, then there is one chunk insertion
p += s % chunk_size > 0
# get the dimensions
# - the + 1 for _e is to shift the boundary by one chunk
# - this shifting is not needed if chunk_size divides e
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
# adjust indices and offsets
chunk_indices[_s:_e] -= p
chunk_offsets[_s] = s % chunk_size
return chunk_indices, chunk_offsets
@staticmethod
def prepare_decode(
query_start_loc: torch.Tensor,
mamba_cache_indices: torch.Tensor,
seq_lens: torch.Tensor,
) -> "Mamba2Metadata":
"""This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
return Mamba2Metadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
num_decodes=len(seq_lens),
num_prefills=0,
num_prefill_tokens=0,
)
@classmethod
def prepare_mixed(
cls,
query_start_loc: torch.Tensor,
mamba_cache_indices: torch.Tensor,
chunk_size: int,
forward_batch: ForwardBatch,
) -> "Mamba2Metadata":
"""This path cannot run with CUDA graph, as it contains extend requests."""
if forward_batch.extend_num_tokens is None:
return cls.prepare_decode(
query_start_loc, mamba_cache_indices, forward_batch.seq_lens
)
num_prefills = len(forward_batch.extend_seq_lens)
num_prefill_tokens = forward_batch.extend_num_tokens
num_decodes = len(forward_batch.seq_lens) - num_prefills
context_lens_tensor = forward_batch.extend_prefix_lens
assert context_lens_tensor is not None
# precompute flag to avoid device syncs later
has_initial_states = context_lens_tensor > 0
prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
query_start_loc = query_start_loc[: num_prefills + 1]
seq_idx = torch.repeat_interleave(
torch.arange(
num_prefills, dtype=torch.int32, device=query_start_loc.device
),
query_start_loc.diff(),
output_size=num_prefill_tokens,
)
seq_idx.unsqueeze_(0)
# We compute metadata for chunked prefill once at the top level model
# forward and reuse them in mamba layers. If not needed, they will be
# ignored inside mamba kernels.
chunk_offsets, chunk_indices = None, None
if prep_initial_states:
chunk_indices, chunk_offsets = (
cls._query_start_loc_to_chunk_indices_offsets(
query_start_loc, chunk_size, num_prefill_tokens
)
)
return Mamba2Metadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
num_decodes=num_decodes,
mixed_metadata=cls.MixedMetadata(
has_initial_states=has_initial_states,
prep_initial_states=prep_initial_states,
chunk_size=chunk_size,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
),
)

View File

@@ -1,81 +0,0 @@
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
from sglang.srt.distributed.utils import divide
class MambaStateShapeCalculator:
    """Helpers that compute per-tensor-parallel-rank cache shapes for the
    various linear-attention / Mamba state variants."""

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """Single recurrent state of shape (heads_per_rank, head_dim, head_dim)."""
        return ((num_heads // tp_size, head_dim, head_dim),)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Conv state (state_len, dim_per_rank) and SSM state (dim_per_rank, state_size)."""
        per_rank_dim = divide(intermediate_size, tp_world_size)
        return (conv_kernel - 1, per_rank_dim), (per_rank_dim, state_size)

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Conv and SSM temporal state shapes for Mamba-2 layers under TP."""
        # If tp_world_size does not divide n_groups, pad with replicated groups
        # so every head shard is accompanied by the groups it needs.
        padded_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size
        )
        # Heads and groups are TP-sharded; conv state is contiguous along 'dim'.
        conv_dim = intermediate_size + 2 * padded_groups * state_size
        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
        # The temporal state depends on A, dt_bias and D and is typically small,
        # e.g. (num_heads, head_dim, state_size) = (128, 64, 128).
        temporal_state_shape = (
            divide(num_heads, tp_world_size),
            head_dim,
            state_size,
        )
        return conv_state_shape, temporal_state_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        """Conv state shape (state_len, dim_per_rank) for short-conv layers."""
        return ((conv_kernel - 1, divide(intermediate_size, tp_world_size)),)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Increase in group count needed so group sharding lines up with head
        shards; zero when tp_size divides ngroups (for ngroups == 1 this is
        exactly tp_size - ngroups)."""
        return 0 if ngroups % tp_size == 0 else tp_size - ngroups

View File

@@ -0,0 +1,120 @@
from typing import Union
import torch
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
from sglang.srt.model_loader.weight_utils import sharded_weight_loader
from sglang.srt.utils.common import set_weight_attrs
class Mixer2RMSNormGated(CustomOp):
    """Gated RMSNorm used by Mamba-2 style mixers.

    Computes ``RMSNorm(x * silu(gate))`` — group-wise when ``full_n_groups > 1``
    — with the hidden dimension sharded across tensor-parallel ranks.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        # Sizes refer to the *unsharded* hidden dimension.
        self.full_hidden_size = full_hidden_size
        self.group_size = full_hidden_size // full_n_groups
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // self.group_size
        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if self.use_rms_norm:
            # Register norm weight only if we're actually applying RMSNorm
            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
        else:
            # Avoid checkpoint mismatch by skipping unused parameter
            self.register_parameter("weight", None)
        assert (
            self.full_hidden_size % self.tp_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference implementation of the gated (group) RMSNorm.

        Three tensor-parallel cases:
        1. n_groups is 1
           In this case we parallelize along the reduction dim.
           Each rank computes a local sum of squares followed by AllReduce
        2. tp_size divides n_groups
           Each rank only reduces within its local group(s).
           No collective ops necessary.
        3. The general case can be pretty complicated so we AllGather
           the input and then redundantly compute the RMSNorm.
        """
        input_dtype = x.dtype
        # Gate is applied before normalization; silu runs in float32 for stability.
        x = x * torch.nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return x.to(input_dtype)
        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
                global_sums = tensor_model_parallel_all_reduce(local_sums)
                # Calculate the variance
                count = self.tp_size * x.shape[-1]
                variance = global_sums / count
            else:
                variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
        else:
            redundant_tp: bool = self.n_groups % self.tp_size != 0
            if redundant_tp:
                # To handle the general case, redundantly apply the variance
                x = tensor_model_parallel_all_gather(x, -1)
            *prefix_dims, hidden_dim = x.shape
            group_count = hidden_dim // self.group_size
            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
            variance = x_grouped.pow(2).mean(-1, keepdim=True)
            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
            x = x_grouped.view(*prefix_dims, hidden_dim)
            if redundant_tp:
                # After the redundant full-width norm, slice back to this
                # rank's shard of the hidden dimension.
                start = self.per_rank_hidden_size * self.tp_rank
                end = start + self.per_rank_hidden_size
                x = x[..., start:end]
        return self.weight * x.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Fused-kernel path; falls back to ``forward_native`` for the grouped
        or TP-sharded cases the fused kernel does not cover."""
        input_dtype = x.dtype
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
            return self.forward_native(x, gate)
        return rms_norm_gated(
            x=x,
            weight=self.weight.data,
            bias=None,
            z=gate,
            eps=self.variance_epsilon,
            norm_before_gate=False,
            is_rms_norm=True,
        )

View File

@@ -15,56 +15,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "K", "IS_CAUSAL"],
# )
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices

View File

@@ -16,66 +16,6 @@ from packaging import version
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
# )
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices

View File

@@ -17,17 +17,6 @@ import triton.language as tl
from .mamba_ssm import softplus
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE_H": 2}),
# triton.Config({"BLOCK_SIZE_H": 4}),
# triton.Config({"BLOCK_SIZE_H": 8}),
# triton.Config({"BLOCK_SIZE_H": 16}),
# triton.Config({"BLOCK_SIZE_H": 32}),
# triton.Config({"BLOCK_SIZE_H": 64}),
# ],
# key=["chunk_size", "nheads"],
# )
@triton.jit
def _chunk_cumsum_fwd_kernel(
# Pointers to matrices
@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_fwd_kernel(
# Pointers to matrices
@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
tl.store(states_ptrs, states, mask=c_mask)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices

View File

@@ -13,17 +13,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE": 64}),
# triton.Config({"BLOCK_SIZE": 128}),
# triton.Config({"BLOCK_SIZE": 256}),
# triton.Config({"BLOCK_SIZE": 512}),
# triton.Config({"BLOCK_SIZE": 1024}),
# triton.Config({"BLOCK_SIZE": 2048}),
# ],
# key=["dim"],
# )
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices

View File

@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
if model_runner.is_hybrid_gdn:
if model_runner.hybrid_gdn_config is not None:
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:

View File

@@ -1770,7 +1770,7 @@ class Scheduler(
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx
if self.tp_worker.worker.model_runner.is_hybrid_gdn:
if self.tp_worker.worker.model_runner.mambaish_config is not None:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)

View File

@@ -15,6 +15,9 @@ limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -109,17 +112,38 @@ class ReqToTokenPool:
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: torch.Tensor
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self):
return sum(get_tensor_size_bytes(t) for t in vars(self).values())
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
intermediate_ssm: torch.Tensor
intermediate_conv_window: torch.Tensor
def __init__(
self,
*,
size: int,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
num_mamba_layers: int,
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
cache_params: "Mamba2CacheParams",
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
conv_state_shape = cache_params.shape.conv
temporal_state_shape = cache_params.shape.temporal
conv_dtype = cache_params.dtype.conv
ssm_dtype = cache_params.dtype.temporal
num_mamba_layers = len(cache_params.layers)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
@@ -158,11 +182,11 @@ class MambaPool:
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = (
conv_state,
temporal_state,
intermediate_ssm_state_cache,
intermediate_conv_window_cache,
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
@@ -172,7 +196,7 @@ class MambaPool:
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
self.mamba_cache = (conv_state, temporal_state)
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
@@ -180,16 +204,14 @@ class MambaPool:
)
self.size = size
self.free_slots = list(range(size))
self.mem_usage = self.get_mamba_size() / GB
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
def get_mamba_params_all_layers(self):
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
assert isinstance(self.mamba_cache, self.SpeculativeState)
return self.mamba_cache
def get_mamba_params(self, layer_id: int):
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
def get_mamba_size(self):
return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
def mamba2_layer_cache(self, layer_id: int):
return self.mamba_cache.at_layer_idx(layer_id)
def available_size(self):
return len(self.free_slots)
@@ -208,7 +230,9 @@ class MambaPool:
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
def clear(self):
self.free_slots = list(range(self.size))
@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
def __init__(
self,
*,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
mamba_layers: List[int],
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
speculative_num_draft_tokens: int,
cache_params: "Mamba2CacheParams",
speculative_num_draft_tokens: int = None,
):
super().__init__(
size=size,
@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
)
self.mamba_pool = MambaPool(
size,
conv_dtype,
ssm_dtype,
len(mamba_layers),
conv_state_shape,
temporal_state_shape,
device,
speculative_num_draft_tokens,
size=size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
self.device = device
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
return self.req_index_to_mamba_index_mapping[req_indices]
def get_mamba_params(self, layer_id: int):
def mamba2_layer_cache(self, layer_id: int):
assert layer_id in self.mamba_map
return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
def get_mamba_params_all_layers(self):
return self.mamba_pool.get_mamba_params_all_layers()
def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
return self.mamba_pool.get_speculative_mamba2_params_all_layers()
# For chunk prefill, we can not free mamba cache, we need use it in the future
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):

View File

@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
if config := self.mambaish_config:
class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
)
else:
self.server_args.max_mamba_cache_size = 512
if self.hybrid_gdn_config is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
@@ -1267,8 +1270,8 @@ class ModelRunner:
"num_nextn_predict_layers",
self.num_effective_layers,
)
elif self.is_hybrid_gdn:
num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
elif config := self.mambaish_config:
num_layers = len(config.full_attention_layer_ids)
else:
num_layers = self.num_effective_layers
if self.use_mla_backend:
@@ -1288,22 +1291,32 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if self.is_hybrid_gdn:
if config := self.mambaish_config:
rest_memory -= (
self.server_args.max_mamba_cache_size
* self.model_config.hf_config.mamba_cache_per_req
* config.mamba2_cache_params.mamba_cache_per_req
/ (1 << 30)
)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
@property
def is_hybrid_gdn(self):
return self.model_config.hf_config.architectures[0] in [
"Qwen3NextForCausalLM",
"Qwen3NextForCausalLMMTP",
"FalconH1ForCausalLM",
]
def hybrid_gdn_config(self):
config = self.model_config.hf_config
if isinstance(config, Qwen3NextConfig):
return config
return None
@property
def mamba2_config(self):
config = self.model_config.hf_config
if isinstance(config, FalconH1Config | NemotronHConfig):
return config
return None
@property
def mambaish_config(self):
return self.mamba2_config or self.hybrid_gdn_config
def set_num_token_hybrid(self):
if (
@@ -1438,7 +1451,7 @@ class ModelRunner:
),
4096,
)
if self.is_hybrid_gdn:
if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1532,14 @@ class ModelRunner:
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
elif self.is_hybrid_gdn:
config = self.model_config.hf_config
(
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
) = config.hybrid_gdn_params
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
conv_state_shape=conv_state_shape,
temporal_state_shape=temporal_state_shape,
conv_dtype=conv_dtype,
ssm_dtype=ssm_dtype,
mamba_layers=mamba_layers,
cache_params=config.mamba2_cache_params,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)
else:
@@ -1640,7 +1641,7 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
)
elif self.is_hybrid_gdn:
elif config := self.mambaish_config:
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
@@ -1651,9 +1652,7 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
# if draft worker, we only need 1 attention layer's kv pool
full_attention_layer_ids=(
[0]
if self.is_draft_worker
else self.model_config.hf_config.full_attention_layer_ids
[0] if self.is_draft_worker else config.full_attention_layer_ids
),
enable_kvcache_transpose=False,
device=self.device,
@@ -1681,7 +1680,8 @@ class ModelRunner:
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and (
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
self.server_args.attention_backend == "ascend"
or self.hybrid_gdn_config is not None
):
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,

View File

@@ -8,6 +8,10 @@ from torch import nn
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
)
self.mamba = MambaMixer2(
cache_params=config.mamba2_cache_params,
hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=self.d_ssm,
use_conv_bias=config.mamba_conv_bias,
use_bias=config.mamba_proj_bias,
n_groups=config.mamba_n_groups,
num_heads=config.mamba_n_heads,
layer_id=layer_id,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
chunk_size=config.mamba_chunk_size,
activation=config.hidden_act,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
attn_backend = forward_batch.attn_backend
assert isinstance(attn_backend, HybridLinearAttnBackend)
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
# Mamba block
mamba_hidden_states = torch.empty_like(hidden_states)
self.mamba(
attn_backend.linear_attn_backend.forward(
self.mamba,
hidden_states * self.ssm_in_multiplier,
mamba_hidden_states,
forward_batch=forward_batch,
layer_id=self.layer_id,
mup_vector=self.mup_vector,
)
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier

View File

@@ -0,0 +1,514 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
"""Inference-only NemotronH model."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from sglang.srt.configs import NemotronHConfig
from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import ReLU2
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers_non_pp
from sglang.utils import logger
class NemotronHMLP(nn.Module):
    """Feed-forward block for NemotronH: up_proj -> ReLU^2 -> down_proj.

    The intermediate size may be supplied per-MLP-layer as a list in the
    config; this layer's entry is selected by counting "-" (MLP) symbols in
    the hybrid override pattern up to and including ``layer_idx``.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Index of this MLP among all MLP ("-") layers up to layer_idx.
        pattern_prefix = config.hybrid_override_pattern[: layer_idx + 1]
        ffn_index = pattern_prefix.count("-") - 1
        inter = config.intermediate_size
        if isinstance(inter, list):
            # A single-entry list applies to every MLP layer.
            intermediate_size = inter[0] if len(inter) == 1 else inter[ffn_index]
        else:
            intermediate_size = inter
        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=config.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = ReLU2()

    def forward(self, x: torch.Tensor):
        """Apply the squared-ReLU MLP to ``x``."""
        projected, _ = self.up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out
class NemotronHMLPDecoderLayer(nn.Module):
    """Decoder layer consisting of RMSNorm followed by the NemotronH MLP."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.mixer = NemotronHMLP(
            config,
            layer_idx=layer_idx,
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pre-norm residual MLP layer; ``forward_batch`` is accepted for
        interface parity with the attention/Mamba layers but unused here."""
        if residual is None:
            # First layer: incoming activations seed the residual stream.
            normed, residual = self.norm(hidden_states), hidden_states
        else:
            normed, residual = self.norm(hidden_states, residual)
        return self.mixer.forward(normed), residual
class NemotronHMambaDecoderLayer(nn.Module):
    """Decoder layer consisting of RMSNorm followed by a Mamba-2 mixer."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layer_id = layer_idx
        self.mixer = MambaMixer2(
            cache_params=config.mamba2_cache_params,
            hidden_size=config.hidden_size,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.mamba_n_groups,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            quant_config=quant_config,
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pre-norm residual update; the Mamba mixer is driven through the
        hybrid linear-attention backend, which writes into ``output``."""
        if residual is None:
            # First layer: incoming activations seed the residual stream.
            normed, residual = self.norm(hidden_states), hidden_states
        else:
            normed, residual = self.norm(hidden_states, residual)
        output = torch.empty_like(normed)
        backend = forward_batch.attn_backend
        assert isinstance(backend, HybridLinearAttnBackend)
        assert isinstance(backend.linear_attn_backend, Mamba2AttnBackend)
        backend.linear_attn_backend.forward(
            mixer=self.mixer,
            layer_id=self.layer_id,
            hidden_states=normed,
            output=output,
            # TODO: investigate need of `use_triton_causal_conv`
            use_triton_causal_conv=True,
        )
        return output, residual
class NemotronHAttention(nn.Module):
    """Multi-head (grouped-query) attention for NemotronH full-attention layers."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # More KV heads than ranks: partition KV heads across TP GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Fewer KV heads than ranks: replicate KV heads across TP GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # Prefer an explicit head_dim from the config; otherwise derive it.
        head_dim = getattr(config, "head_dim", None)
        self.head_dim = (
            head_dim
            if head_dim is not None
            else config.hidden_size // self.total_num_heads
        )
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_idx,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Fused QKV projection, radix attention, then output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = torch.split(qkv, [self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_out = self.attn.forward(q, k, v, forward_batch)
        out, _ = self.o_proj(attn_out)
        return out
class NemotronHAttentionDecoderLayer(nn.Module):
    """Decoder layer wrapping RMSNorm (fused residual add) around attention."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.mixer = NemotronHAttention(
            config,
            layer_idx,
            quant_config,
            prefix=f"{prefix}.mixer",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            # First layer of the stack: seed the residual stream from the
            # raw input, then normalize.
            residual, hidden_states = hidden_states, self.norm(hidden_states)
        else:
            # Fused add-then-norm: norm returns (normed, updated_residual).
            hidden_states, residual = self.norm(hidden_states, residual)
        out = self.mixer.forward(
            hidden_states=hidden_states, forward_batch=forward_batch
        )
        return out, residual
# Union of every decoder-layer class that may appear in the hybrid stack;
# used by NemotronHModel.forward for a runtime isinstance sanity check.
Layers = (
    NemotronHAttentionDecoderLayer
    | NemotronHMLPDecoderLayer
    | NemotronHMambaDecoderLayer
)
# Maps each character of config.hybrid_override_pattern (the ATTENTION / MLP /
# MAMBA pattern symbols defined elsewhere in this file) to the decoder-layer
# class instantiated at that position.
ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
    ATTENTION: NemotronHAttentionDecoderLayer,
    MLP: NemotronHMLPDecoderLayer,
    MAMBA: NemotronHMambaDecoderLayer,
}
class NemotronHModel(nn.Module):
    """NemotronH backbone: token embeddings plus a hybrid stack of
    attention / MLP / Mamba2 decoder layers (one per character of
    ``config.hybrid_override_pattern``), followed by a final RMSNorm.
    """

    def __init__(
        self,
        *,
        config: NemotronHConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # LoRA is not wired up yet; keep the vocab-size plumbing around a
        # None config so the extra-vocab math is ready when it is.
        lora_config = None
        self.config = config
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def get_layer(idx: int, prefix: str):
            # Choose the layer class from the pattern character at `idx`.
            layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
            return layer_class(config, idx, quant_config=quant_config, prefix=prefix)

        self.layers = make_layers_non_pp(
            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
        )
        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run the decoder stack.

        Returns final hidden states on the last pipeline-parallel rank;
        otherwise returns proxy tensors (hidden_states + residual) to hand
        off to the next pipeline stage.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            # Fix: keep the residual handed over by the previous pipeline
            # stage. The original code reset it to None right after reading
            # it (a dead assignment), which silently dropped the
            # not-yet-added residual stream under pipeline parallelism.
            residual = pp_proxy_tensors["residual"]

        for layer in self.layers:
            if not isinstance(layer, Layers):
                raise ValueError(f"Unknown layer type: {type(layer)}")
            hidden_states, residual = layer.forward(
                hidden_states=hidden_states,
                residual=residual,
                forward_batch=forward_batch,
            )

        if not get_pp_group().is_last_rank:
            return PPProxyTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        # Final fused add-then-norm over the whole stack's residual stream.
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states
class NemotronHForCausalLM(nn.Module):
    """Causal-LM head on top of the NemotronH hybrid Mamba2/attention backbone."""

    # Checkpoint-key remapping applied in load_weights: HF checkpoints use
    # "backbone"/"A_log"/"embeddings" where this module tree uses
    # "model"/"A"/"embed_tokens".
    remap_prefix = {"backbone": "model"}
    remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        *,
        config: NemotronHConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # LoRA is not wired up yet; the branches below keep the plumbing
        # ready for a real lora_config.
        lora_config = None
        self.config = config
        self.model = self._init_model(
            config=config, quant_config=quant_config, prefix=prefix
        )
        if self.config.tie_word_embeddings:
            # Reuse the input embedding matrix as the output projection.
            self.lm_head = self.model.embed_tokens
        else:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)

    def _init_model(
        self,
        config: NemotronHConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # Split out so subclasses can swap in a different backbone.
        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        # Backbone forward, then project hidden states to logits.
        hidden_states = self.model.forward(
            input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    # NOTE(review): self.mamba_cache is never assigned in this class as shown
    # here — presumably it is attached externally by the model runner before
    # CUDA-graph capture; confirm, otherwise these two methods raise
    # AttributeError.
    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        """Load checkpoint weights, remapping HF names and fusing q/k/v shards."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # First pass: rename checkpoint keys to this module tree's names.
        # NOTE(review): str.replace substitutes every occurrence, not just the
        # leading prefix / first substring — fine as long as the remap keys
        # cannot appear elsewhere in a weight name; verify for new checkpoints.
        updated_weights = []
        for name, loaded_weight in weights:
            for prefix, new_key in self.remap_prefix.items():
                if name.startswith(prefix):
                    name = name.replace(prefix, new_key)
            for substr, new_key in self.remap_substr.items():
                if substr in name:
                    name = name.replace(substr, new_key)
            updated_weights.append((name, loaded_weight))
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in updated_weights:
            if "scale" in name:
                # KV-cache scale names may need remapping; None means "skip
                # this weight entirely".
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
            # Try to route q/k/v checkpoint shards into the fused qkv_proj
            # parameter; the for/else falls through to plain loading when no
            # stacked mapping matched.
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                # Fused parameters carry a shard-aware weight_loader.
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")
EntryClass = [NemotronHForCausalLM]

View File

@@ -866,7 +866,7 @@ class EAGLEWorker(TpModelWorker):
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# QQ: can be optimized
if self.target_worker.model_runner.is_hybrid_gdn:
if self.target_worker.model_runner.hybrid_gdn_config is not None:
# res.draft_input.accept_length is on GPU but may be empty for last verify?
accepted_length = (
torch.tensor(

View File

@@ -518,6 +518,24 @@ def make_layers(
return modules, start_layer, end_layer
def make_layers_non_pp(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str = "",
) -> torch.nn.ModuleList:
    """Build all decoder layers on this rank (no pipeline-parallel split),
    wrapping them with the offloader so weights may live off-device."""
    from sglang.srt.offloader import get_offloader

    wrapped = get_offloader().wrap_modules(
        layer_fn(idx=i, prefix=add_prefix(i, prefix))
        for i in range(num_hidden_layers)
    )
    return torch.nn.ModuleList(wrapped)
cmo_stream = None

View File

@@ -45,6 +45,7 @@ from sglang.srt.configs import (
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
NemotronHConfig,
Qwen3NextConfig,
Step3VLConfig,
)
@@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
FalconH1Config.model_type: FalconH1Config,
DotsVLMConfig.model_type: DotsVLMConfig,
DotsOCRConfig.model_type: DotsOCRConfig,
NemotronHConfig.model_type: NemotronHConfig,
}
for name, cls in _CONFIG_REGISTRY.items():