model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
Netanel Haber
2025-10-08 19:37:38 +03:00
committed by GitHub
parent c882b5ae75
commit d6837aea4d
35 changed files with 3280 additions and 854 deletions

View File

@@ -53,6 +53,7 @@ in the GitHub search bar.
| **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI's open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. |
| **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. |
| **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM's Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. |
| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. |
| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. |
| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. |
| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. `Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. |
| **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). |

View File

@@ -9,6 +9,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
Step3TextConfig,
@@ -32,4 +33,5 @@ __all__ = [
"DotsVLMConfig",
"DotsOCRConfig",
"FalconH1Config",
"NemotronHConfig",
]

View File

@@ -15,16 +15,12 @@
"""Falcon-H1 model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
get_tensor_model_parallel_world_size,
@@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig):
self.rope_scaling = None
self.rope_scaling = rope_scaling
self.projectors_bias = projectors_bias
mamba_intermediate = (
self.mamba_intermediate = mamba_intermediate = (
mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
)
@@ -294,18 +290,6 @@ class FalconH1Config(PretrainedConfig):
def layers_block_type(self):
return ["falcon_h1" for i in range(self.num_hidden_layers)]
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
@property
def full_attention_layer_ids(self):
# For Falcon-H1, we do have attention on all layers
@@ -317,44 +301,14 @@ class FalconH1Config(PretrainedConfig):
return range(self.num_hidden_layers)
@property
def hybrid_gdn_params(self):
world_size = get_tensor_model_parallel_world_size()
n_groups = self.mamba_n_groups
if self.mamba_n_groups % world_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
self.mamba_n_groups, world_size
)
n_groups += extra_groups
conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
conv_state_shape = (
divide(conv_dim, world_size),
self.mamba_d_conv - 1,
)
# we TP-ize on the heads dimension
temporal_state_shape = (
self.mamba_d_state,
self.mamba_d_head,
divide(self.mamba_n_heads, world_size),
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
def mamba2_cache_params(self):
shape = Mamba2StateShape.create(
tp_world_size=get_tensor_model_parallel_world_size(),
intermediate_size=self.mamba_intermediate,
n_groups=self.mamba_n_groups,
num_heads=self.mamba_n_heads,
head_dim=self.mamba_d_head,
state_size=self.mamba_d_state,
conv_kernel=self.mamba_d_conv,
)
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)

View File

@@ -0,0 +1,117 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
import os
from dataclasses import dataclass, field
import numpy as np
import torch
from sglang.srt.distributed.utils import divide
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    """Return how many groups must be added so groups can follow the head shards.

    Args:
        ngroups: Number of groups before any replication.
        tp_size: Tensor-parallel world size the heads are sharded over.

    Returns:
        ``0`` when ``ngroups`` is a multiple of ``tp_size``; otherwise
        ``tp_size - ngroups``.
    """
    evenly_sharded = ngroups % tp_size == 0
    # When ngroups == 1 the non-zero branch is exactly tp_size - 1, i.e. the
    # single group is replicated once per shard.
    return 0 if evenly_sharded else tp_size - ngroups
@dataclass(kw_only=True, frozen=True)
class Mamba2StateShape:
    """Immutable description of the per-layer Mamba2 state shapes.

    ``conv`` and ``temporal`` are the per-rank shapes computed by
    :meth:`create`; the remaining fields record the sizes they were
    derived from.
    """

    # (conv_dim // tp_world_size, conv_kernel - 1)
    conv: tuple[int, int]
    # (num_heads // tp_world_size, head_dim, state_size)
    temporal: tuple[int, int, int]

    intermediate_size: int
    # Un-sharded conv width, computed after any group padding.
    conv_dim: int
    # Filled with the same value as ``state_size`` by ``create``.
    ssm_state_size: int
    num_heads: int
    head_dim: int
    state_size: int
    conv_kernel: int

    @staticmethod
    def create(
        *,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> "Mamba2StateShape":
        """Build the state shapes for one tensor-parallel rank.

        Args:
            tp_world_size: Number of tensor-parallel ranks.
            intermediate_size: Width of the Mamba intermediate projection.
            n_groups: Number of groups before padding.
            num_heads: Total number of Mamba heads (un-sharded).
            head_dim: Dimension of each head.
            state_size: SSM state dimension.
            conv_kernel: Convolution kernel width.

        Returns:
            A ``Mamba2StateShape`` whose ``conv`` and ``temporal`` entries are
            already divided by ``tp_world_size`` on their first axis.
        """
        # The helper yields 0 when n_groups already divides evenly, so the
        # padding can be applied unconditionally: groups are extended so that
        # every group needed by a head is sharded along with it.
        padded_groups = n_groups + extra_groups_for_head_shards(
            n_groups, tp_world_size
        )

        full_conv_dim = intermediate_size + 2 * padded_groups * state_size

        # Conv state: the channel axis is split across ranks.
        per_rank_conv = (divide(full_conv_dim, tp_world_size), conv_kernel - 1)

        # Temporal state: the head axis is split across ranks,
        # e.g. (n_heads, head_dim, state_size) = (128, 64, 128) before sharding.
        per_rank_temporal = (
            divide(num_heads, tp_world_size),
            head_dim,
            state_size,
        )

        return Mamba2StateShape(
            conv=per_rank_conv,
            temporal=per_rank_temporal,
            intermediate_size=intermediate_size,
            conv_dim=full_conv_dim,
            ssm_state_size=state_size,
            num_heads=num_heads,
            head_dim=head_dim,
            state_size=state_size,
            conv_kernel=conv_kernel,
        )
@dataclass(kw_only=True, frozen=True)
class Mamba2StateDType:
conv: torch.dtype
temporal: torch.dtype
CONV_DTYPE = torch.bfloat16
def mamba2_state_dtype() -> Mamba2StateDType:
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
@dataclass(kw_only=True, frozen=True)
class Mamba2CacheParams:
    """Everything needed to size a Mamba2 cache: shapes, dtypes and layer ids."""

    shape: Mamba2StateShape
    # Resolved lazily via ``mamba2_state_dtype`` when not supplied.
    dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
    # Indices of the layers that hold a Mamba2 state.
    layers: list[int]

    @property
    def mamba_cache_per_req(self) -> int:
        """Bytes of Mamba2 state one request occupies across all ``layers``."""
        conv_bytes = int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
        temporal_bytes = (
            int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
        )
        bytes_per_layer = conv_bytes + temporal_bytes
        return bytes_per_layer * len(self.layers)

View File

@@ -0,0 +1,286 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
"""NemotronH model configuration"""
import regex as re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size
logger = logging.get_logger(__name__)
MAMBA = "M"
ATTENTION = "*"
MLP = "-"
class NemotronHConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`NemotronHModel`]. It is used to instantiate a NemotronH model according
    to the specified arguments, defining the model architecture. Instantiating
    a configuration with the defaults will yield a similar configuration to
    that of the NemotronH-v0.1 model.

    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the NemotronH model. Defines the number of
            different tokens that can be represented by the `inputs_ids`
            passed when calling [`NemotronHModel`]
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be
            tied. Note that this is only relevant if the model has an output
            word embedding layer.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 21504):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 52):
            Number of hidden layers in the Transformer encoder.
        hybrid_override_pattern (`str`, *optional*, defaults to
            `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
            The pattern of the hybrid model. The pattern is a string of
            characters where each character represents
            M: Mamba2, *: Attention, -: MLP
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the
            Transformer encoder.
        head_dim (`int`, *optional*, defaults to 128):
            Dimension of each attention head.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to
            implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use
            Multi Head Attention (MHA), if `num_key_value_heads=1` the model
            will use Multi Query Attention (MQA) otherwise GQA is used.
        mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
            The non-linear activation function in the MLP layers.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in attention layers.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in MLP layers.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        residual_in_fp32 (`bool`, *optional*, defaults to `False`):
            Whether or not residuals should be in `float32`. If set to `False`
            residuals will keep the same `dtype` as the rest of the model.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models). Only relevant if
            `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`,
            all logits will be calculated. If an integer value, only last
            `num_logits_to_keep` logits will be calculated.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        sliding_window (`int`, *optional*, defaults to None):
            Sliding window attention window size.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used
            with.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the hidden states.
        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use the fast mamba kernels.
            These are available only if `mamba-ssm` and `causal-conv1d`
            are installed, and the mamba modules are running on a CUDA device.
        ssm_state_size (`int`, *optional*, defaults to 128):
            The dimension of the mamba state space latents.
        mamba_num_heads (`int`, *optional*, defaults to 128):
            Number of heads in Mamba layers.
        mamba_n_groups (`int`, *optional*, defaults to 8):
            Number of groups in Mamba layers. Stored as both
            `mamba_n_groups` and `n_groups`.
        mamba_head_dim (`int`, *optional*, defaults to 64):
            Dimension of each Mamba head.
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel.
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor used to determine the mamba intermediate size.
        mamba_hidden_act (`str`, *optional*, defaults to "silu"):
            The non-linear activation function in the Mamba layers.
        mamba_dt_min (`float`, *optional*, defaults to 0.001):
            Minimum value for the time step in Mamba.
        mamba_dt_max (`float`, *optional*, defaults to 0.1):
            Maximum value for the time step in Mamba.
        mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
            Limits for the time step in Mamba.
        mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
            Floor value for time step initialization in Mamba.
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the convolution layer of the mamba mixer
            block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the input and output projections of the
            mamba mixer block.
        mamba_chunk_size (`int`, *optional*, defaults to 256):
            Size of chunks for Mamba processing.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
            Whether to rescale the pre-normalization residual connections.
    """

    model_type = "nemotron_h"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=21504,
        num_hidden_layers=52,
        hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
        num_attention_heads=32,
        head_dim=128,
        num_key_value_heads=8,  # nemo: num_query_groups
        mlp_hidden_act="relu2",
        attention_bias=False,
        mlp_bias=False,
        use_bias=False,
        initializer_range=0.02,  # nemo: init_method_std
        layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
        residual_in_fp32=False,  # Megatron Core default value
        use_cache=True,
        num_logits_to_keep=1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        max_position_embeddings=4096,
        attention_dropout=0.0,
        hidden_dropout=0.0,  # * ADDED
        use_mamba_kernels=True,
        ssm_state_size=128,  # mamba_state_size
        mamba_num_heads=128,
        mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
        mamba_head_dim=64,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_hidden_act="silu",
        mamba_dt_min=0.001,
        mamba_dt_max=0.1,
        mamba_dt_limit=(0.0, float("inf")),
        mamba_dt_init_floor=1e-4,
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_chunk_size=256,
        rescale_prenorm_residual=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.hybrid_override_pattern = hybrid_override_pattern
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout

        # Validate hybrid_override_pattern
        # M: Mamba2, *: Attention, -: MLP
        assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
            "hybrid_override_pattern must have same length as " "num_hidden_layers"
        )
        # The '-' is placed last inside the character class so it is a literal
        # dash. Written as "[*-M]" it would be a range from '*' to 'M' and
        # would wrongly accept digits, '+', '/', 'A'..'L', etc.
        assert re.match(r"^[M*-]+$", self.hybrid_override_pattern), (
            "hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
        )

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_hidden_act = mlp_hidden_act
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias
        self.use_bias = use_bias
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.residual_in_fp32 = residual_in_fp32

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep

        self.use_mamba_kernels = use_mamba_kernels
        self.mamba_n_groups = mamba_n_groups
        # `mamba2_cache_params` reads `self.n_groups`, so define it here;
        # otherwise a config built without an `n_groups` kwarg raises
        # AttributeError. An explicit `n_groups` passed through **kwargs is
        # expected to still take precedence, since PretrainedConfig assigns
        # leftover kwargs as attributes in super().__init__() below.
        self.n_groups = mamba_n_groups
        self.mamba_head_dim = mamba_head_dim
        self.ssm_state_size = ssm_state_size
        self.mamba_num_heads = mamba_num_heads
        self.conv_kernel = mamba_d_conv
        self.expand = mamba_expand
        self.mamba_hidden_act = mamba_hidden_act
        self.time_step_min = mamba_dt_min
        self.time_step_max = mamba_dt_max
        self.time_step_limit = mamba_dt_limit
        self.time_step_floor = mamba_dt_init_floor
        self.use_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.mamba_chunk_size = mamba_chunk_size
        self.rescale_prenorm_residual = rescale_prenorm_residual

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _layer_ids(self, kind):
        """Return the indices whose `hybrid_override_pattern` character is `kind`."""
        return [
            i
            for i in range(self.num_hidden_layers)
            if self.hybrid_override_pattern[i] == kind
        ]

    @property
    def mamba_layer_ids(self):
        """Indices of the layers marked as Mamba2 ('M') in the pattern."""
        return self._layer_ids(MAMBA)

    @property
    def full_attention_layer_ids(self):
        """Indices of the layers marked as attention ('*') in the pattern."""
        return self._layer_ids(ATTENTION)

    @property
    def mamba2_cache_params(self) -> Mamba2CacheParams:
        """Mamba2 cache shapes for this rank plus the list of Mamba layer ids."""
        shape = Mamba2StateShape.create(
            tp_world_size=get_attention_tp_size(),
            intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
            n_groups=self.n_groups,
            num_heads=self.mamba_num_heads,
            head_dim=self.mamba_head_dim,
            state_size=self.ssm_state_size,
            conv_kernel=self.conv_kernel,
        )
        return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)

View File

@@ -15,14 +15,12 @@
"""Qwen3Hybrid model configuration"""
import enum
import os
import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size
@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
]
@property
def hybrid_gdn_params(self):
world_size = get_attention_tp_size()
conv_dim = (
self.linear_key_head_dim * self.linear_num_key_heads * 2
+ self.linear_value_head_dim * self.linear_num_value_heads
)
conv_state_shape = (
divide(conv_dim, world_size),
self.linear_conv_kernel_dim - 1,
def mamba2_cache_params(self) -> Mamba2CacheParams:
shape = Mamba2StateShape.create(
tp_world_size=get_attention_tp_size(),
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
n_groups=self.linear_num_key_heads,
num_heads=self.linear_num_value_heads,
head_dim=self.linear_value_head_dim,
state_size=self.linear_key_head_dim,
conv_kernel=self.linear_conv_kernel_dim,
)
temporal_state_shape = (
divide(self.linear_num_value_heads, world_size),
self.linear_key_head_dim,
self.linear_value_head_dim,
)
conv_dtype = torch.bfloat16
dtype_map = {
"float32": torch.float32,
"bfloat16": torch.bfloat16,
}
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
mamba_layers = self.linear_layer_ids
return (
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
)
@property
def mamba_cache_per_req(self):
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
self.hybrid_gdn_params
)
mamba_layers_len = len(mamba_layers)
return (
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
) * mamba_layers_len
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)

View File

@@ -1,7 +1,14 @@
import logging
from typing import TYPE_CHECKING
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
# evade circular imports
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.model_executor.model_runner import ModelRunner
ATTENTION_BACKENDS = {}
@@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
return DualChunkFlashAttentionBackend(runner)
def attn_backend_wrapper(runner, full_attn_backend):
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
"""
Wrapper for special models like hybrid GDN, so we don't
need to change the code of the original attention backend.
"""
assert not (
runner.is_hybrid_gdn and runner.use_mla_backend
runner.hybrid_gdn_config is not None and runner.use_mla_backend
), "hybrid_gdn can only be used with non-MLA models."
# wrap for hybrid GDN models
if runner.is_hybrid_gdn:
if cfg := runner.mambaish_config:
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
GDNAttnBackend,
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.utils import is_blackwell, is_npu
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
or runner.server_args.attention_backend == "trtllm_mha"
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
MambaAttnBackend,
)
linear_attn_backend = MambaAttnBackend(runner)
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
if runner.hybrid_gdn_config is not None:
if is_blackwell():
assert (
runner.server_args.attention_backend == "triton"
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
if is_npu():
assert (
runner.server_args.attention_backend == "ascend"
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
linear_attn_backend = GDNAttnBackend(runner)
elif runner.mamba2_config is not None:
linear_attn_backend = Mamba2AttnBackend(runner)
else:
raise ValueError(
"Expected hybrid GDN or NemotronH models, but got unknown model."
)
full_attn_layers = cfg.full_attention_layer_ids
return HybridLinearAttnBackend(
full_attn_backend, linear_attn_backend, full_attn_layers
)

View File

@@ -181,6 +181,45 @@ def _layer_norm_fwd(
return out, mean, rstd
def rms_norm_gated(
    *,
    x,
    weight,
    bias,
    z=None,
    eps=1e-6,
    group_size=None,
    norm_before_gate=True,
    is_rms_norm=False,
):
    """Run the fused (optionally gated) norm forward and restore the input shape.

    If z is not None, we do norm(x) * silu(z) if norm_before_gate, else
    norm(x * silu(z)).

    All arguments are keyword-only. ``x`` (and ``z`` when given) are collapsed
    to 2D over their last dimension before the kernel call, and the output is
    reshaped back to the original shape of ``x``.
    """
    original_shape = x.shape

    def _flatten_rows(tensor):
        # Collapse the leading dims; copy only if the last dim is not unit-stride.
        rows = tensor.reshape(-1, tensor.shape[-1])
        return rows if rows.stride(-1) == 1 else rows.contiguous()

    x = _flatten_rows(x)
    if z is not None:
        assert z.shape == original_shape
        z = _flatten_rows(z)

    weight = weight.contiguous()
    bias = None if bias is None else bias.contiguous()

    # Only the normalized output is needed here; the statistics are discarded.
    out, _mean, _rstd = _layer_norm_fwd(
        x,
        weight,
        bias,
        eps,
        z=z,
        group_size=group_size,
        norm_before_gate=norm_before_gate,
        is_rms_norm=is_rms_norm,
    )
    return out.reshape(original_shape)
class LayerNormFn(torch.autograd.Function):
@staticmethod
@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
norm_before_gate=True,
is_rms_norm=False,
):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
x_shape_og = x.shape
# reshape input data into 2D tensor
x = x.reshape(-1, x.shape[-1])
if x.stride(-1) != 1:
x = x.contiguous()
if z is not None:
assert z.shape == x_shape_og
z = z.reshape(-1, z.shape[-1])
if z.stride(-1) != 1:
z = z.contiguous()
weight = weight.contiguous()
if bias is not None:
bias = bias.contiguous()
y, mean, rstd = _layer_norm_fwd(
x,
weight,
bias,
eps,
return rms_norm_gated(
x=x,
weight=weight,
bias=bias,
eps=eps,
z=z,
group_size=group_size,
norm_before_gate=norm_before_gate,
is_rms_norm=is_rms_norm,
)
return y.reshape(x_shape_og)
def layernorm_fn(
@@ -238,14 +261,6 @@ def layernorm_fn(
)
def rmsnorm_fn(
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
):
return LayerNormFn.apply(
x, weight, bias, z, eps, group_size, norm_before_gate, True
)
class LayerNorm(torch.nn.Module):
def __init__(
@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
group_size=self.group_size,
eps=self.eps,
norm_before_gate=self.norm_before_gate,
is_rms_norm=False,
)
@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):
def forward(self, x, z=None):
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
return rmsnorm_fn(
return layernorm_fn(
x,
self.weight,
self.bias,
@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
eps=self.eps,
group_size=self.group_size,
norm_before_gate=self.norm_before_gate,
is_rms_norm=True,
)

View File

@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
ForwardMetadata,
Mamba2Metadata,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import fused_gdn_gating
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
@@ -47,18 +54,10 @@ elif is_npu():
causal_conv1d_update = causal_conv1d_update_npu
@dataclass
class ForwardMetadata:
query_start_loc: Optional[torch.Tensor]
mamba_cache_indices: torch.Tensor
class MambaAttnBackend(AttentionBackend):
"""Attention backend using Mamba kernel."""
class MambaAttnBackendBase(AttentionBackend):
def __init__(self, model_runner: ModelRunner):
super().__init__()
self.pad_slot_id = -1 # Default pad slot id
self.pad_slot_id = PAD_SLOT_ID
self.device = model_runner.device
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
self.forward_metadata: ForwardMetadata = None
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
def init_forward_metadata(self, forward_batch: ForwardBatch):
def _forward_metadata(self, forward_batch: ForwardBatch):
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
forward_batch.req_pool_indices
)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=query_start_loc,
mamba_cache_indices=mamba_cache_indices,
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
self.forward_metadata = self._forward_metadata(forward_batch)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
self.forward_metadata = self._capture_metadata(
bs, req_pool_indices, forward_mode
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
seq_lens_cpu: Optional[torch.Tensor],
):
self.forward_metadata = self._replay_metadata(
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
assert (
max_num_tokens % max_bs == 0
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
def _capture_metadata(
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
raise ValueError(f"Invalid forward mode: {forward_mode=}")
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
def init_forward_metadata_replay_cuda_graph(
def _replay_metadata(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
else:
raise ValueError(f"Invalid forward mode: {forward_mode=}")
self.forward_metadata = ForwardMetadata(
return ForwardMetadata(
query_start_loc=self.query_start_loc_list[bs - 1],
mamba_cache_indices=self.state_indices_list[bs - 1],
)
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
def get_cuda_graph_seq_len_fill_value(self):
    """Return the value used to fill padded sequence-length slots (always 1)."""
    # Mamba attn does not use seq lens to index kv cache
    fill_value = 1
    return fill_value
class GDNAttnBackend(MambaAttnBackendBase):
"""Attention backend using Mamba kernel."""
def forward_decode(
self,
q: torch.Tensor,
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
dt_bias = kwargs["dt_bias"]
layer_id = kwargs["layer_id"]
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = layer_cache.conv
ssm_states = layer_cache.temporal
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
conv_states = mamba_cache_params.conv
ssm_states = mamba_cache_params.temporal
if is_target_verify:
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = self.req_to_token_pool.get_mamba_params(layer_id)
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
intermediate_state_cache = mamba_cache_params.intermediate_ssm
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
has_initial_states = torch.ones(
seq_len // forward_batch.spec_info.draft_token_num,
dtype=torch.bool,
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
)
conv_states_to_use = conv_states.clone()
else:
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
layer_id
)
has_initial_states = forward_batch.extend_prefix_lens > 0
conv_states_to_use = conv_states
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
return core_attn_out
class Mamba2AttnBackend(MambaAttnBackendBase):
    """Attention backend wrapper for Mamba2Mixer kernels.

    Reuses the base class helpers to build the plain metadata
    (``query_start_loc`` / ``mamba_cache_indices``) and wraps it into a
    ``Mamba2Metadata`` before storing it in ``self.forward_metadata``.
    """

    def __init__(self, model_runner: ModelRunner):
        super().__init__(model_runner)
        mamba2_config = model_runner.mamba2_config
        assert mamba2_config is not None
        self.mamba_chunk_size = mamba2_config.mamba_chunk_size

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        """Build the (possibly mixed prefill/decode) metadata for a batch."""
        base_metadata = self._forward_metadata(forward_batch)
        self.forward_metadata = Mamba2Metadata.prepare_mixed(
            query_start_loc=base_metadata.query_start_loc,
            mamba_cache_indices=base_metadata.mamba_cache_indices,
            chunk_size=self.mamba_chunk_size,
            forward_batch=forward_batch,
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        """Build decode-only metadata for CUDA graph capture.

        ``num_tokens``, ``encoder_lens`` and ``spec_info`` are not used here.
        """
        base_metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
        self.forward_metadata = Mamba2Metadata.prepare_decode(
            query_start_loc=base_metadata.query_start_loc,
            mamba_cache_indices=base_metadata.mamba_cache_indices,
            seq_lens=seq_lens,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Build decode-only metadata for CUDA graph replay.

        ``seq_lens_sum`` and ``encoder_lens`` are not used here.
        """
        base_metadata = self._replay_metadata(
            bs,
            req_pool_indices,
            forward_mode,
            spec_info,
            seq_lens_cpu,
        )
        self.forward_metadata = Mamba2Metadata.prepare_decode(
            query_start_loc=base_metadata.query_start_loc,
            mamba_cache_indices=base_metadata.mamba_cache_indices,
            seq_lens=seq_lens,
        )

    def forward(
        self,
        mixer: MambaMixer2,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        layer_id: int,
        mup_vector: Optional[torch.Tensor] = None,
        use_triton_causal_conv: bool = False,
    ):
        """Run ``mixer`` for layer ``layer_id`` with the stored metadata.

        The per-layer cache is looked up from the request-to-token pool and
        handed to ``mixer.forward`` together with ``self.forward_metadata``,
        which must already be a ``Mamba2Metadata``.
        """
        current_metadata = self.forward_metadata
        assert isinstance(current_metadata, Mamba2Metadata)
        cache_for_layer = self.req_to_token_pool.mamba2_layer_cache(layer_id)
        return mixer.forward(
            hidden_states=hidden_states,
            output=output,
            layer_cache=cache_for_layer,
            metadata=current_metadata,
            mup_vector=mup_vector,
            use_triton_causal_conv=use_triton_causal_conv,
        )

    def forward_decode(self, *args, **kwargs):
        """Unsupported entry point; callers must use ``forward`` instead."""
        message = "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
        raise NotImplementedError(message)

    def forward_extend(self, *args, **kwargs):
        """Unsupported entry point; callers must use ``forward`` instead."""
        message = "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
        raise NotImplementedError(message)
class HybridLinearAttnBackend(AttentionBackend):
"""Support different backends for prefill and decode."""
"""Manages a full and linear attention backend"""
def __init__(
self,
full_attn_backend: AttentionBackend,
linear_attn_backend: AttentionBackend,
linear_attn_backend: MambaAttnBackendBase,
full_attn_layers: list[int],
):
self.full_attn_layers = full_attn_layers
self.full_attn_backend = full_attn_backend
self.linear_attn_backend = linear_attn_backend
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
def init_forward_metadata(self, forward_batch: ForwardBatch):
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
)
def get_cuda_graph_seq_len_fill_value(self):
return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
def forward_decode(
self,
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_decode(
return self.full_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_decode(
return self.linear_attn_backend.forward_decode(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
):
layer_id = layer.layer_id if layer else kwargs["layer_id"]
if layer_id in self.full_attn_layers:
return self.attn_backend_list[0].forward_extend(
return self.full_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
return self.attn_backend_list[1].forward_extend(
return self.linear_attn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
)
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
request_number = accepted_length.shape[0]
state_indices_tensor = self.attn_backend_list[
1
].forward_metadata.mamba_cache_indices[:request_number]
state_indices_tensor = (
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
:request_number
]
)
mamba_caches = self.attn_backend_list[
1
].req_to_token_pool.get_mamba_params_all_layers()
mamba_caches = (
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
)
(
conv_states,
ssm_states,
intermediate_state_cache,
intermediate_conv_window_cache,
) = mamba_caches
conv_states = mamba_caches.conv
ssm_states = mamba_caches.temporal
intermediate_state_cache = mamba_caches.intermediate_ssm
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
# SSM state updates (chunked to reduce peak memory)
valid_mask = accepted_length > 0

View File

@@ -10,7 +10,7 @@ import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
PAD_SLOT_ID = -1
from .causal_conv1d_triton import PAD_SLOT_ID
def causal_conv1d_fn(

View File

@@ -6,11 +6,11 @@ from typing import List, Optional, Union
import numpy as np
import torch
PAD_SLOT_ID = -1
import triton
import triton.language as tl
PAD_SLOT_ID = -1
@triton.jit()
def _causal_conv1d_fwd_kernel( # continuous batching
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
+ (conv_state_batch_coord * stride_conv_state_seq)
+ conv_state_token_offset * stride_conv_state_tok
+ (idx_feats * stride_conv_state_dim)[None, :]
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
:, None
]
) # [BLOCK_M, BLOCK_N]
mask = (
(conv_state_batch_coord < num_cache_lines)
@@ -897,7 +899,10 @@ def causal_conv1d_update(
stride_state_indices = (
conv_state_indices.stride(0) if conv_state_indices is not None else 0
)
state_len = width - 1 + (seqlen - 1) # effective state_len needed
if num_accepted_tokens is not None:
state_len = width - 1 + (seqlen - 1) # effective state_len needed
else:
state_len = width - 1
np2_statelen = triton.next_power_of_2(state_len)
def grid(META):

View File

@@ -1,23 +1,30 @@
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple
import torch
import torch.nn as nn
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.custom_op import CustomOp
from sglang.srt.configs.mamba_utils import (
Mamba2CacheParams,
extra_groups_for_head_shards,
)
from sglang.srt.distributed import (
divide,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
from sglang.srt.layers.attention.mamba.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_fn as causal_conv1d_fn_triton,
)
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
causal_conv1d_update as causal_conv1d_update_triton,
)
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
from sglang.srt.layers.attention.mamba.ops import (
mamba_chunk_scan_combined,
selective_state_update,
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.mem_cache.memory_pool import MambaPool
from sglang.srt.model_loader.weight_utils import (
composed_weight_loader,
sharded_weight_loader,
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
return loader
class Mixer2RMSNormGated(CustomOp):
    """Gated (grouped) RMSNorm used by the Mamba2 mixer.

    The input is first multiplied by ``silu(gate)``; when ``use_rms_norm`` is
    enabled the product is RMS-normalized per group of ``group_size`` features
    and scaled by a per-rank weight of size ``full_hidden_size // tp_size``.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.full_hidden_size = full_hidden_size
        # Number of features normalized together in one group.
        self.group_size = full_hidden_size // full_n_groups
        # Slice of the hidden dimension owned by this tensor-parallel rank.
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        # Recomputed from group_size (integer division), not copied from
        # full_n_groups.
        self.n_groups = full_hidden_size // self.group_size
        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if self.use_rms_norm:
            # Register norm weight only if we're actually applying RMSNorm
            self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
        else:
            # Avoid checkpoint mismatch by skipping unused parameter
            self.register_parameter("weight", None)
        # NOTE(review): this check runs after per_rank_hidden_size has already
        # been derived with floor division above.
        assert (
            self.full_hidden_size % self.tp_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference implementation: gate with SiLU, then optional RMSNorm.

        Returns a tensor in the dtype of the incoming ``x``.
        """
        # Three tensor-parallel cases:
        # 1. n_groups is 1
        # In this case we parallelize along the reduction dim.
        # Each rank computes a local sum of squares followed by AllReduce
        # 2. tp_size divides n_groups
        # Each rank only reduces within its local group(s).
        # No collective ops necessary.
        # 3. The general case can be pretty complicated so we AllGather
        # the input and then redundantly compute the RMSNorm.
        input_dtype = x.dtype
        # The gate is upcast to float32 before SiLU is applied.
        x = x * nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return x.to(input_dtype)
        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
                global_sums = tensor_model_parallel_all_reduce(local_sums)
                # Calculate the variance
                # The divisor is the local last-dim size times tp_size.
                count = self.tp_size * x.shape[-1]
                variance = global_sums / count
            else:
                variance = x.pow(2).mean(-1, keepdim=True)
            x = x * torch.rsqrt(variance + self.variance_epsilon)
        else:
            redundant_tp: bool = self.n_groups % self.tp_size != 0
            if redundant_tp:
                # To handle the general case, redundantly apply the variance
                x = tensor_model_parallel_all_gather(x, -1)
            # Normalize each group of group_size features independently.
            *prefix_dims, hidden_dim = x.shape
            group_count = hidden_dim // self.group_size
            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
            variance = x_grouped.pow(2).mean(-1, keepdim=True)
            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
            x = x_grouped.view(*prefix_dims, hidden_dim)
            if redundant_tp:
                # Keep only this rank's slice of the gathered result.
                start = self.per_rank_hidden_size * self.tp_rank
                end = start + self.per_rank_hidden_size
                x = x[..., start:end]
        return self.weight * x.to(input_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: uses ``layernorm_fn`` when possible, else ``forward_native``."""
        input_dtype = x.dtype
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            # Only the SiLU result is cast back to input_dtype; the product
            # itself is not cast.
            return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
        # NOTE(review): this condition sends every n_groups != 1 configuration
        # to forward_native, so layernorm_fn is reached only when n_groups == 1
        # and 1 % tp_size == 0 (i.e. tp_size == 1) - confirm this is intended.
        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
            return self.forward_native(x, gate)
        return layernorm_fn(
            x,
            self.weight.data,
            bias=None,
            z=gate,
            eps=self.variance_epsilon,
            norm_before_gate=False,
        )
class MambaMixer2(torch.nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
@@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module):
def __init__(
self,
cache_params: Mamba2CacheParams,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
use_conv_bias: bool,
use_bias: bool,
chunk_size: int,
layer_id: int,
n_groups: int = 1,
num_heads: int = 128,
head_dim: int = 64,
rms_norm_eps: float = 1e-5,
activation: str = "silu",
use_rms_norm: bool = True,
model_config: Optional[ModelConfig] = None,
# cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.num_heads = num_heads = cache_params.shape.num_heads
self.head_dim = cache_params.shape.head_dim
assert (
num_heads % self.tp_size == 0
), "Tensor parallel world size must divide num heads."
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
"then num_groups must equal 1."
)
self.ssm_state_size = ssm_state_size
self.conv_kernel_size = conv_kernel_size
assert (
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
), (
"Tensor parallel currently supported for quantized models only "
"if tensor parallel world size divides num groups."
)
self.ssm_state_size = cache_params.shape.ssm_state_size
self.activation = activation
self.layer_id = layer_id
self.intermediate_size = intermediate_size
self.head_dim = head_dim
self.num_heads = num_heads
self.chunk_size = chunk_size
conv_kernel_size = cache_params.shape.conv_kernel
self.intermediate_size = intermediate_size = (
cache_params.shape.intermediate_size
)
self.n_groups = n_groups
if n_groups % self.tp_size != 0:
# - for TP we shard conv_dim by sharding on n_groups,
# - but if n_groups cannot divide tp_size, we need to
# extend some extra groups
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
n_groups, self.tp_size
)
groups = extra_groups_for_head_shards(n_groups, self.tp_size)
self.n_groups = n_groups + groups
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
self.conv_dim = cache_params.shape.conv_dim
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
if n_groups % self.tp_size == 0:
self.conv1d = MergedColumnParallelLinear(
input_size=conv_kernel_size,
output_sizes=[
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
],
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
prefix=f"{prefix}.in_proj",
)
if n_groups % self.tp_size != 0:
self.in_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[
intermediate_size,
intermediate_size,
self.groups_ssm_state_size,
self.groups_ssm_state_size,
self.num_heads,
],
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
else:
# This is the n_groups == 1 case,
# where we need to duplicate groups if TP>1.
self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=self.conv_dim,
bias=use_conv_bias,
quant_config=None,
prefix=f"{prefix}.conv1d",
)
self.in_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias,
quant_config=quant_config,
prefix=f"{prefix}.in_proj",
)
# - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding
# - use the custom weight loader mamba_v2_sharded_weight_loader
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
)
# The tuple is (conv_state, ssm_state)
self.kv_cache = (torch.tensor([]), torch.tensor([]))
self.model_config = model_config
self.prefix = prefix
def forward_native(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mup_vector: Optional[torch.Tensor] = None,
):
pass
def forward(
self,
*,
hidden_states: torch.Tensor,
output: torch.Tensor,
forward_batch: ForwardBatch,
layer_cache: MambaPool.State,
metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
use_triton_causal_conv: bool = False,
):
# attn_backend_list[-1] gives access to MambaAttnBackend
mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
attn_metadata = mamba_backend.forward_metadata
state_indices_tensor = attn_metadata.mamba_cache_indices
chunk_size = self.chunk_size
# metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill
# modes; they are computed at top-level model forward since they
# stay the same and reused for all mamba layers in the same iteration
state_indices_tensor = metadata.mamba_cache_indices
conv_state = layer_cache.conv
ssm_state = layer_cache.temporal
conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
self.layer_id
)
assert (
ssm_state.size(1) == self.ssm_state_size
), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
query_start_loc = attn_metadata.query_start_loc
chunk_size = self.chunk_size
# TODO: properly support this
prep_initial_states = False
query_start_loc = metadata.query_start_loc
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
dim=-1,
)
num_prefills = metadata.num_prefills # request count
num_decodes = metadata.num_decodes # token count (=request)
num_prefill_tokens = metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
assert num_actual_tokens == projected_states.shape[0]
# NOTE: V0 put prefill before decode
# Separate prefill and decode by splitting varlen input
# Split along token dimension
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
hidden_states_B_C,
[num_prefill_tokens, num_decodes],
dim=0,
)
dt_p, dt_d = torch.split(
dt,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Split along batch dimension
state_indices_tensor_p, state_indices_tensor_d = torch.split(
state_indices_tensor,
[num_prefills, num_decodes],
dim=0,
)
query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
preallocated_ssm_out = torch.empty(
[
projected_states.shape[0],
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
dtype=hidden_states.dtype,
device=hidden_states.device,
)
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
preallocated_ssm_out,
[num_prefill_tokens, num_decodes],
dim=0,
)
# Process prefill requests
if forward_batch.forward_mode.is_extend():
if has_prefill:
mixed_metadata = metadata.mixed_metadata
assert mixed_metadata is not None
# 2. Convolution sequence transformation
# - "cache_indices" updates the conv_state cache in positions
# pointed to by "state_indices_tensor"
num_prefill_tokens = forward_batch.extend_num_tokens or 0
has_initial_states = forward_batch.extend_prefix_lens > 0
cache_indices = attn_metadata.mamba_cache_indices
x = hidden_states_B_C.transpose(
has_initial_states_p = mixed_metadata.has_initial_states
prep_initial_states = mixed_metadata.prep_initial_states
cache_indices = state_indices_tensor_p
x = hidden_states_B_C_p.transpose(
0, 1
) # this is the form that causal-conv see
hidden_states_B_C = causal_conv1d_fn(
ccfn = (
causal_conv1d_fn
if not use_triton_causal_conv
else causal_conv1d_fn_triton
)
hidden_states_B_C_p = ccfn(
x,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=has_initial_states,
has_initial_state=has_initial_states_p,
cache_indices=cache_indices,
query_start_loc=query_start_loc,
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
).transpose(0, 1)
query_start_loc=query_start_loc_p,
seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
).transpose(0, 1)[:num_prefill_tokens]
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
# 3. State Space Model sequence transformation
initial_states = None
if has_initial_states is not None and prep_initial_states:
if has_initial_states_p is not None and prep_initial_states:
initial_states = torch.where(
has_initial_states[:, None, None, None],
ssm_state[state_indices_tensor],
has_initial_states_p[:, None, None, None],
ssm_state[state_indices_tensor_p],
0,
)
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states.view(
hidden_states_p.view(
1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
),
dt.unsqueeze(0),
dt_p.unsqueeze(0),
self.A,
B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=chunk_size,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
chunk_size=mixed_metadata.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
cu_seqlens=query_start_loc,
seq_idx=mixed_metadata.seq_idx,
chunk_indices=mixed_metadata.chunk_indices,
chunk_offsets=mixed_metadata.chunk_offsets,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
out=preallocated_ssm_out_p.view(
1, num_prefill_tokens, -1, self.head_dim
),
state_dtype=ssm_state.dtype,
)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
elif forward_batch.forward_mode.is_decode():
num_decodes = len(query_start_loc) - 1
ssm_state[state_indices_tensor_p] = varlen_state
# Process decode requests
if has_decode:
# 2. Convolution sequence transformation
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
ccu = (
causal_conv1d_update
if not use_triton_causal_conv
else causal_conv1d_update_triton
)
hidden_states_B_C_d = ccu(
hidden_states_B_C_d,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=state_indices_tensor,
conv_state_indices=state_indices_tensor_d,
)
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
# 3. State Space Model sequence transformation
n_groups = self.n_groups // self.tp_size
A = (
A_d = (
self.A[:, None, ...][:, :, None]
.expand(-1, self.head_dim, self.ssm_state_size)
.to(dtype=torch.float32)
)
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(-1, n_groups, B.shape[1] // n_groups)
C = C.view(-1, n_groups, C.shape[1] // n_groups)
hidden_states = hidden_states.view(
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
hidden_states_d = hidden_states_d.view(
-1, self.num_heads // self.tp_size, self.head_dim
)
# - the hidden is reshaped into (bs, num_heads, head_dim)
# - mamba_cache_params.ssm_state's slots will be selected
# - layer_state.ssm_state's slots will be selected
# using state_indices_tensor_d
# NOTE: final output is an in-place update of out tensor
selective_state_update(
ssm_state.permute(0, 3, 2, 1),
hidden_states,
dt,
A,
B,
C,
D,
ssm_state,
hidden_states_d,
dt_d,
A_d,
B_d,
C_d,
D_d,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
state_batch_indices=state_indices_tensor,
out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
state_batch_indices=state_indices_tensor_d,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
)
elif forward_batch.forward_mode.is_idle():
preallocated_ssm_out = preallocated_ssm_out
# 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(preallocated_ssm_out, gate)
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
# 5. Final linear projection
output[:], _ = self.out_proj(hidden_states)
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
@property
def mamba_type(self) -> str:

View File

@@ -0,0 +1,211 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
import math
from dataclasses import dataclass
import torch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@dataclass(kw_only=True)
class ForwardMetadata:
query_start_loc: torch.Tensor
mamba_cache_indices: torch.Tensor
@dataclass(kw_only=True)
class Mamba2Metadata(ForwardMetadata):
"""stable metadata across all mamba2 layers in the forward pass"""
num_prefills: int
num_prefill_tokens: int
num_decodes: int
@dataclass(kw_only=True, frozen=True)
class MixedMetadata:
has_initial_states: torch.Tensor
prep_initial_states: bool
chunk_size: int
seq_idx: torch.Tensor
chunk_indices: torch.Tensor
chunk_offsets: torch.Tensor
extend_seq_lens_cpu: list[int]
mixed_metadata: MixedMetadata | None = None
"""`mixed_metadata` is used for extend/mixed requests"""
@staticmethod
def _query_start_loc_to_chunk_indices_offsets(
    query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
            lengths, shape (num_seqs + 1,).
            The first element should be 0. Each entry represents the starting
            index of a sequence in the flattened token array.
        chunk_size (int): The size of each physical mamba chunk
            (number of tokens per chunk).
        total_seqlens (int): The total number of tokens in the batch.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - chunk_indices (torch.Tensor): 1D tensor of indices
                indicating the physical chunk for each logical chunk.
            - chunk_offsets (torch.Tensor): 1D tensor of offsets
                indicating the starting index of each logical chunk within
                its physical chunk.
        Both are ``torch.int`` tensors on ``query_start_loc.device``.

    This function computes the chunk indices and offsets for the given
    query_start_loc and chunk_size. Both are tensors of integers with length N,
    where N is the number of logical (pseudo) chunks.
    A logical chunk is a sequence of tokens that are all part of the same
    sequence and are all in the same physical mamba chunk.
    In other words, a logical chunk changes every time we cross a sequence
    boundary or a physical mamba chunk boundary.
    Logical chunks are needed to handle batched requests with initial states
    (see _state_passing_fwd and _chunk_scan_fwd).
    The chunk_indices tensor contains the index of the physical chunk for each
    logical chunk.
    The chunk_offsets tensor contains the offset (AKA starting index) of the
    logical chunk in the physical chunk.

    Example:
    query_start_loc = [0, 5, 10]
    chunk_size = 8
    total_seqlens = 10
    -> chunk_indices = [0, 0, 1]
    -> chunk_offsets = [0, 5, 0]

    In this example, we have 2 sequences, each with 5 tokens. The physical
    chunk size is 8 tokens.
    We have three logical chunks:
    - the first logical chunk starts at token 0 in the first physical chunk
        and contains all 5 tokens from the first sequence
    - the second logical chunk starts at token 5 in the first physical chunk
        and contains first 3 tokens from the second sequence
    - the third logical chunk starts at token 0 in the second physical chunk
        and contains the remaining 2 tokens from the second sequence
    """
    # Read the sequence boundaries to the host exactly once. Iterating over
    # 0-d tensors and using them as slice bounds would otherwise force a
    # separate device->host synchronization for every sequence.
    seq_ends = query_start_loc[1:].tolist()  # remove prepended 0
    inner_starts = seq_ends[:-1]

    # outputs will have length expansion of chunks that do not divide
    # chunk_size
    N = math.ceil(total_seqlens / chunk_size) + sum(
        1 for s in inner_starts if s % chunk_size > 0
    )
    indices = list(range(N))
    offsets = [0] * N

    p = 0  # num of insertions
    for s, e in zip(inner_starts, seq_ends[1:]):
        # if does not divide chunk_size, then there is one chunk insertion
        p += s % chunk_size > 0

        # get the dimensions
        # - the + 1 for _e is to shift the boundary by one chunk
        # - this shifting is not needed if chunk_size divides e
        _s = s // chunk_size + p
        _e = e // chunk_size + p + (e % chunk_size > 0)

        # adjust indices and offsets
        indices[_s:_e] = [value - p for value in indices[_s:_e]]
        offsets[_s] = s % chunk_size

    # Materialize both results on the target device in one transfer each.
    device = query_start_loc.device
    chunk_indices = torch.tensor(indices, dtype=torch.int, device=device)
    chunk_offsets = torch.tensor(offsets, dtype=torch.int, device=device)
    return chunk_indices, chunk_offsets
@staticmethod
def prepare_decode(
    query_start_loc: torch.Tensor,
    mamba_cache_indices: torch.Tensor,
    seq_lens: torch.Tensor,
) -> "Mamba2Metadata":
    """Build decode-only metadata.

    This path is run during CUDA graph capture, i.e. decode only, so the
    prefill counters are fixed at zero and every entry of ``seq_lens`` is
    counted as one decode request.
    """
    decode_only_fields = dict(
        query_start_loc=query_start_loc,
        mamba_cache_indices=mamba_cache_indices,
        num_decodes=len(seq_lens),
        num_prefills=0,
        num_prefill_tokens=0,
    )
    return Mamba2Metadata(**decode_only_fields)
@classmethod
def prepare_mixed(
    cls,
    query_start_loc: torch.Tensor,
    mamba_cache_indices: torch.Tensor,
    chunk_size: int,
    forward_batch: ForwardBatch,
) -> "Mamba2Metadata":
    """Build metadata for a batch that may contain extend requests.

    This path cannot run with CUDA graph, as it contains extend requests.
    When ``forward_batch.extend_num_tokens`` is ``None`` the batch is treated
    as decode-only and the call is delegated to ``prepare_decode``.
    """
    prefill_token_count = forward_batch.extend_num_tokens
    if prefill_token_count is None:
        return cls.prepare_decode(
            query_start_loc, mamba_cache_indices, forward_batch.seq_lens
        )

    prefill_count = len(forward_batch.extend_seq_lens)
    decode_count = len(forward_batch.seq_lens) - prefill_count

    prefix_lens = forward_batch.extend_prefix_lens
    assert prefix_lens is not None

    # Resolve the "any prefill has a prefix" flag once, up front, so that
    # later code does not have to sync with the device for it.
    has_initial_states = prefix_lens > 0
    prep_initial_states = torch.any(has_initial_states[:prefill_count]).item()

    # Keep only the prefill part of the cumulative start offsets.
    prefill_start_loc = query_start_loc[: prefill_count + 1]

    # For each prefill token, the index of the request it belongs to, with a
    # leading dimension of size 1 added in place.
    request_ids = torch.arange(
        prefill_count, dtype=torch.int32, device=prefill_start_loc.device
    )
    seq_idx = torch.repeat_interleave(
        request_ids,
        prefill_start_loc.diff(),
        output_size=prefill_token_count,
    ).unsqueeze_(0)

    # Chunked-prefill metadata is computed once at the top-level model
    # forward and reused by the mamba layers; kernels ignore it when unused.
    if prep_initial_states:
        chunk_indices, chunk_offsets = cls._query_start_loc_to_chunk_indices_offsets(
            prefill_start_loc, chunk_size, prefill_token_count
        )
    else:
        chunk_indices = chunk_offsets = None

    return Mamba2Metadata(
        query_start_loc=prefill_start_loc,
        mamba_cache_indices=mamba_cache_indices,
        num_prefills=prefill_count,
        num_prefill_tokens=prefill_token_count,
        num_decodes=decode_count,
        mixed_metadata=cls.MixedMetadata(
            has_initial_states=has_initial_states,
            prep_initial_states=prep_initial_states,
            chunk_size=chunk_size,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
        ),
    )

View File

@@ -1,81 +0,0 @@
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
from sglang.srt.distributed.utils import divide
class MambaStateShapeCalculator:
    """Per-rank cache state shapes for linear-attention / mamba style layers."""

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """Return a 1-tuple holding ``(num_heads // tp_size, head_dim, head_dim)``."""
        heads_per_rank = num_heads // tp_size
        return ((heads_per_rank, head_dim, head_dim),)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for Mamba1.

        The conv state is laid out as ``(conv_kernel - 1, local_dim)`` and the
        temporal state as ``(local_dim, state_size)``, where ``local_dim`` is
        ``intermediate_size`` split across ``tp_world_size`` via ``divide``.
        """
        local_dim = divide(intermediate_size, tp_world_size)
        return (conv_kernel - 1, local_dim), (local_dim, state_size)

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Return ``(conv_state_shape, temporal_state_shape)`` for Mamba2."""
        # When n_groups is not divisible by the world size the group count is
        # adjusted so that the groups needed by a head are sharded with it.
        effective_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size
        )
        # The conv input covers the intermediate channels plus two
        # group-sized state blocks; it is split across ranks.
        conv_dim = intermediate_size + 2 * effective_groups * state_size
        # Contiguous along the 'dim' axis.
        conv_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
        # Heads are split across ranks; e.g. a full
        # (num_heads, head_dim, state_size) could be (128, 64, 128).
        temporal_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return conv_shape, temporal_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        """Return a 1-tuple holding ``(conv_kernel - 1, local_dim)``."""
        local_dim = divide(intermediate_size, tp_world_size)
        return ((conv_kernel - 1, local_dim),)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards.

        Returns 0 when ``tp_size`` divides ``ngroups``; otherwise returns
        ``tp_size - ngroups`` (for ``ngroups == 1`` that is exactly the number
        of extra replicas needed).
        """
        return 0 if ngroups % tp_size == 0 else tp_size - ngroups

View File

@@ -0,0 +1,120 @@
from typing import Union
import torch
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed.communication_op import (
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
from sglang.srt.model_loader.weight_utils import sharded_weight_loader
from sglang.srt.utils.common import set_weight_attrs
class Mixer2RMSNormGated(CustomOp):
    """Gated RMSNorm whose weight is sharded across tensor-parallel ranks.

    The input is first multiplied by ``silu(gate)``; when ``use_rms_norm`` is
    set, the product is RMS-normalized within groups of ``group_size``
    channels and scaled by the per-rank ``weight``.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        super().__init__()
        world_size = get_tensor_model_parallel_world_size()
        self.tp_size = world_size
        self.tp_rank = get_tensor_model_parallel_rank()
        self.full_hidden_size = full_hidden_size
        self.group_size = full_hidden_size // full_n_groups
        self.per_rank_hidden_size = full_hidden_size // world_size
        self.n_groups = full_hidden_size // self.group_size

        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if use_rms_norm:
            # The norm weight only exists when RMSNorm is actually applied;
            # it is loaded sharded along dim 0.
            weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
            self.weight = weight
            set_weight_attrs(weight, {"weight_loader": sharded_weight_loader(0)})
        else:
            # Skip the unused parameter to avoid a checkpoint mismatch.
            self.register_parameter("weight", None)
        assert (
            full_hidden_size % world_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Pure-PyTorch forward.

        Three tensor-parallel cases are handled:
          1. ``n_groups == 1``: the reduction dim itself is sharded, so each
             rank computes a local sum of squares followed by an all-reduce.
          2. ``tp_size`` divides ``n_groups``: each rank reduces only within
             its local group(s); no collective is needed.
          3. Otherwise: the input is all-gathered, the norm is computed
             redundantly on every rank, and the local slice is kept.
        """
        input_dtype = x.dtype
        gated = x * torch.nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return gated.to(input_dtype)

        if self.n_groups == 1:
            normed = self._norm_single_group(gated)
        else:
            normed = self._norm_per_group(gated)
        return self.weight * normed.to(input_dtype)

    def _norm_single_group(self, x: torch.Tensor) -> torch.Tensor:
        """RMS-normalize over the full (possibly sharded) last dimension."""
        if self.tp_size > 1:
            # Local sum of squares, reduced across ranks to a global sum,
            # then divided by the global element count.
            local_sums = x.pow(2).sum(dim=-1, keepdim=True)
            global_sums = tensor_model_parallel_all_reduce(local_sums)
            variance = global_sums / (self.tp_size * x.shape[-1])
        else:
            variance = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(variance + self.variance_epsilon)

    def _norm_per_group(self, x: torch.Tensor) -> torch.Tensor:
        """RMS-normalize each ``group_size``-wide group of the last dimension."""
        needs_gather = self.n_groups % self.tp_size != 0
        if needs_gather:
            # General case: gather the full hidden dim and normalize redundantly.
            x = tensor_model_parallel_all_gather(x, -1)

        *lead_dims, hidden_dim = x.shape
        grouped = x.view(*lead_dims, hidden_dim // self.group_size, self.group_size)
        variance = grouped.pow(2).mean(-1, keepdim=True)
        grouped = grouped * torch.rsqrt(variance + self.variance_epsilon)
        x = grouped.view(*lead_dims, hidden_dim)

        if needs_gather:
            # Keep only this rank's slice of the gathered result.
            lo = self.per_rank_hidden_size * self.tp_rank
            x = x[..., lo : lo + self.per_rank_hidden_size]
        return x

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA forward: uses ``rms_norm_gated`` when possible, else the native path."""
        input_dtype = x.dtype
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            activated_gate = torch.nn.functional.silu(gate.to(torch.float32))
            return x * activated_gate.to(input_dtype)

        # The fused kernel is only taken when tp_size divides n_groups and
        # there is a single group; everything else goes through forward_native.
        fused_kernel_applies = (
            self.n_groups % self.tp_size == 0 and self.n_groups == 1
        )
        if not fused_kernel_applies:
            return self.forward_native(x, gate)

        return rms_norm_gated(
            x=x,
            weight=self.weight.data,
            bias=None,
            z=gate,
            eps=self.variance_epsilon,
            norm_before_gate=False,
            is_rms_norm=True,
        )

View File

@@ -15,56 +15,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "K", "IS_CAUSAL"],
# )
@triton.jit
def _bmm_chunk_fwd_kernel(
# Pointers to matrices

View File

@@ -16,66 +16,6 @@ from packaging import version
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
# )
@triton.jit
def _chunk_scan_fwd_kernel(
# Pointers to matrices

View File

@@ -17,17 +17,6 @@ import triton.language as tl
from .mamba_ssm import softplus
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE_H": 2}),
# triton.Config({"BLOCK_SIZE_H": 4}),
# triton.Config({"BLOCK_SIZE_H": 8}),
# triton.Config({"BLOCK_SIZE_H": 16}),
# triton.Config({"BLOCK_SIZE_H": 32}),
# triton.Config({"BLOCK_SIZE_H": 64}),
# ],
# key=["chunk_size", "nheads"],
# )
@triton.jit
def _chunk_cumsum_fwd_kernel(
# Pointers to matrices
@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_fwd_kernel(
# Pointers to matrices
@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
tl.store(states_ptrs, states, mask=c_mask)
# @triton.autotune(
# configs=[
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
# num_stages=3,
# num_warps=8,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=4,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=5,
# num_warps=2,
# ),
# triton.Config(
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
# num_stages=4,
# num_warps=2,
# ),
# ],
# key=["hdim", "dstate", "chunk_size"],
# )
@triton.jit
def _chunk_state_varlen_kernel(
# Pointers to matrices

View File

@@ -13,17 +13,6 @@ import triton
import triton.language as tl
# @triton.autotune(
# configs=[
# triton.Config({"BLOCK_SIZE": 64}),
# triton.Config({"BLOCK_SIZE": 128}),
# triton.Config({"BLOCK_SIZE": 256}),
# triton.Config({"BLOCK_SIZE": 512}),
# triton.Config({"BLOCK_SIZE": 1024}),
# triton.Config({"BLOCK_SIZE": 2048}),
# ],
# key=["dim"],
# )
@triton.jit
def _state_passing_fwd_kernel(
# Pointers to matrices

View File

@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
if model_runner.is_hybrid_gdn:
if model_runner.hybrid_gdn_config is not None:
# For hybrid linear models, layer_id = 0 may not be full attention
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
else:

View File

@@ -1770,7 +1770,7 @@ class Scheduler(
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
# chunked request keeps its rid but will get a new req_pool_idx
if self.tp_worker.worker.model_runner.is_hybrid_gdn:
if self.tp_worker.worker.model_runner.mambaish_config is not None:
self.req_to_token_pool.free(
self.chunked_req.req_pool_idx, free_mamba_cache=False
)

View File

@@ -15,6 +15,9 @@ limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -109,17 +112,38 @@ class ReqToTokenPool:
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: torch.Tensor
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self):
return sum(get_tensor_size_bytes(t) for t in vars(self).values())
# State variant that carries two additional intermediate caches on top of the
# conv/temporal tensors inherited from State.
# NOTE(review): presumably only built when speculative decoding is enabled —
# confirm against the branch in MambaPool.__init__.
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
    # Intermediate SSM state cache.
    intermediate_ssm: torch.Tensor
    # Intermediate conv window cache.
    intermediate_conv_window: torch.Tensor
def __init__(
self,
*,
size: int,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
num_mamba_layers: int,
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
cache_params: "Mamba2CacheParams",
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
conv_state_shape = cache_params.shape.conv
temporal_state_shape = cache_params.shape.temporal
conv_dtype = cache_params.dtype.conv
ssm_dtype = cache_params.dtype.temporal
num_mamba_layers = len(cache_params.layers)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
@@ -158,11 +182,11 @@ class MambaPool:
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = (
conv_state,
temporal_state,
intermediate_ssm_state_cache,
intermediate_conv_window_cache,
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
@@ -172,7 +196,7 @@ class MambaPool:
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
self.mamba_cache = (conv_state, temporal_state)
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
@@ -180,16 +204,14 @@ class MambaPool:
)
self.size = size
self.free_slots = list(range(size))
self.mem_usage = self.get_mamba_size() / GB
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
def get_mamba_params_all_layers(self):
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
    """Return the pool's cache, asserting that it is a ``SpeculativeState``."""
    cache = self.mamba_cache
    assert isinstance(cache, self.SpeculativeState)
    return cache
def get_mamba_params(self, layer_id: int):
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
def get_mamba_size(self):
return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
def mamba2_layer_cache(self, layer_id: int):
    """Return the cache restricted to layer ``layer_id`` (via ``at_layer_idx``)."""
    cache = self.mamba_cache
    return cache.at_layer_idx(layer_id)
def available_size(self):
return len(self.free_slots)
@@ -208,7 +230,9 @@ class MambaPool:
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
def clear(self):
self.free_slots = list(range(self.size))
@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
def __init__(
self,
*,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
conv_dtype: torch.dtype,
ssm_dtype: torch.dtype,
mamba_layers: List[int],
conv_state_shape: Tuple[int, int],
temporal_state_shape: Tuple[int, int],
speculative_num_draft_tokens: int,
cache_params: "Mamba2CacheParams",
speculative_num_draft_tokens: int = None,
):
super().__init__(
size=size,
@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
)
self.mamba_pool = MambaPool(
size,
conv_dtype,
ssm_dtype,
len(mamba_layers),
conv_state_shape,
temporal_state_shape,
device,
speculative_num_draft_tokens,
size=size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
self.device = device
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
return self.req_index_to_mamba_index_mapping[req_indices]
def get_mamba_params(self, layer_id: int):
def mamba2_layer_cache(self, layer_id: int):
assert layer_id in self.mamba_map
return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
def get_mamba_params_all_layers(self):
return self.mamba_pool.get_mamba_params_all_layers()
def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
return self.mamba_pool.get_speculative_mamba2_params_all_layers()
# For chunk prefill, we can not free mamba cache, we need use it in the future
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):

View File

@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import (
@@ -354,8 +355,9 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
if config := self.mambaish_config:
class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
@@ -364,6 +366,7 @@ class ModelRunner:
)
else:
self.server_args.max_mamba_cache_size = 512
if self.hybrid_gdn_config is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
@@ -1267,8 +1270,8 @@ class ModelRunner:
"num_nextn_predict_layers",
self.num_effective_layers,
)
elif self.is_hybrid_gdn:
num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
elif config := self.mambaish_config:
num_layers = len(config.full_attention_layer_ids)
else:
num_layers = self.num_effective_layers
if self.use_mla_backend:
@@ -1288,22 +1291,32 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if self.is_hybrid_gdn:
if config := self.mambaish_config:
rest_memory -= (
self.server_args.max_mamba_cache_size
* self.model_config.hf_config.mamba_cache_per_req
* config.mamba2_cache_params.mamba_cache_per_req
/ (1 << 30)
)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
@property
def is_hybrid_gdn(self):
return self.model_config.hf_config.architectures[0] in [
"Qwen3NextForCausalLM",
"Qwen3NextForCausalLMMTP",
"FalconH1ForCausalLM",
]
def hybrid_gdn_config(self):
config = self.model_config.hf_config
if isinstance(config, Qwen3NextConfig):
return config
return None
@property
def mamba2_config(self):
    """Return the HF config if it is a FalconH1 or NemotronH config, else ``None``."""
    hf_config = self.model_config.hf_config
    is_mamba2 = isinstance(hf_config, (FalconH1Config, NemotronHConfig))
    return hf_config if is_mamba2 else None
@property
def mambaish_config(self):
    """Return ``mamba2_config`` if it is truthy, otherwise ``hybrid_gdn_config``."""
    mamba2 = self.mamba2_config
    if mamba2:
        return mamba2
    return self.hybrid_gdn_config
def set_num_token_hybrid(self):
if (
@@ -1438,7 +1451,7 @@ class ModelRunner:
),
4096,
)
if self.is_hybrid_gdn:
if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
@@ -1519,26 +1532,14 @@ class ModelRunner:
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
elif self.is_hybrid_gdn:
config = self.model_config.hf_config
(
conv_state_shape,
temporal_state_shape,
conv_dtype,
ssm_dtype,
mamba_layers,
) = config.hybrid_gdn_params
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
conv_state_shape=conv_state_shape,
temporal_state_shape=temporal_state_shape,
conv_dtype=conv_dtype,
ssm_dtype=ssm_dtype,
mamba_layers=mamba_layers,
cache_params=config.mamba2_cache_params,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)
else:
@@ -1640,7 +1641,7 @@ class ModelRunner:
enable_kvcache_transpose=False,
device=self.device,
)
elif self.is_hybrid_gdn:
elif config := self.mambaish_config:
self.token_to_kv_pool = HybridLinearKVPool(
page_size=self.page_size,
size=self.max_total_num_tokens,
@@ -1651,9 +1652,7 @@ class ModelRunner:
head_dim=self.model_config.head_dim,
# if draft worker, we only need 1 attention layer's kv pool
full_attention_layer_ids=(
[0]
if self.is_draft_worker
else self.model_config.hf_config.full_attention_layer_ids
[0] if self.is_draft_worker else config.full_attention_layer_ids
),
enable_kvcache_transpose=False,
device=self.device,
@@ -1681,7 +1680,8 @@ class ModelRunner:
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and (
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
self.server_args.attention_backend == "ascend"
or self.hybrid_gdn_config is not None
):
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,

View File

@@ -8,6 +8,10 @@ from torch import nn
from sglang.srt.configs.falcon_h1 import FalconH1Config
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
)
self.mamba = MambaMixer2(
cache_params=config.mamba2_cache_params,
hidden_size=config.hidden_size,
ssm_state_size=config.mamba_d_state,
conv_kernel_size=config.mamba_d_conv,
intermediate_size=self.d_ssm,
use_conv_bias=config.mamba_conv_bias,
use_bias=config.mamba_proj_bias,
n_groups=config.mamba_n_groups,
num_heads=config.mamba_n_heads,
layer_id=layer_id,
head_dim=config.mamba_d_head,
rms_norm_eps=config.rms_norm_eps,
chunk_size=config.mamba_chunk_size,
activation=config.hidden_act,
use_rms_norm=config.mamba_rms_norm,
prefix=f"{prefix}.mixer",
@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
)
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
attn_backend = forward_batch.attn_backend
assert isinstance(attn_backend, HybridLinearAttnBackend)
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
# Mamba block
mamba_hidden_states = torch.empty_like(hidden_states)
self.mamba(
attn_backend.linear_attn_backend.forward(
self.mamba,
hidden_states * self.ssm_in_multiplier,
mamba_hidden_states,
forward_batch=forward_batch,
layer_id=self.layer_id,
mup_vector=self.mup_vector,
)
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier

View File

@@ -0,0 +1,514 @@
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
"""Inference-only NemotronH model."""
from collections.abc import Iterable
from typing import Optional, Union
import torch
from torch import nn
from sglang.srt.configs import NemotronHConfig
from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.activation import ReLU2
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
HybridLinearAttnBackend,
Mamba2AttnBackend,
)
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE,
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, make_layers_non_pp
from sglang.utils import logger
class NemotronHMLP(nn.Module):
    """Feed-forward block: up projection, ReLU2 activation, down projection."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Zero-based position of this layer among the "-" entries of the
        # hybrid pattern, counting up to and including layer_idx.
        pattern_up_to_layer = config.hybrid_override_pattern[: layer_idx + 1]
        mlp_index = pattern_up_to_layer.count("-") - 1

        # intermediate_size may be a scalar, a single-element list shared by
        # all MLP layers, or a list indexed by mlp_index.
        intermediate_size = config.intermediate_size
        if isinstance(intermediate_size, list):
            position = 0 if len(intermediate_size) == 1 else mlp_index
            intermediate_size = intermediate_size[position]

        self.up_proj = ColumnParallelLinear(
            input_size=config.hidden_size,
            output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=config.hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = ReLU2()

    def forward(self, x: torch.Tensor):
        """Apply ``down_proj(act_fn(up_proj(x)))``, discarding the second outputs of the projections."""
        projected, _ = self.up_proj(x)
        activated = self.act_fn(projected)
        output, _ = self.down_proj(activated)
        return output
class NemotronHMLPDecoderLayer(nn.Module):
    """Decoder layer made of an RMSNorm followed by a ``NemotronHMLP`` mixer."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        mixer_prefix = f"{prefix}.mixer"
        self.mixer = NemotronHMLP(
            config,
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=mixer_prefix,
            layer_idx=layer_idx,
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize, run the MLP, and return ``(mlp_output, residual)``.

        With no incoming residual the un-normalized input becomes the
        residual; otherwise the norm is called with both tensors and returns
        the updated pair. ``forward_batch`` is accepted but not used here.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        mlp_output = self.mixer.forward(hidden_states)
        return mlp_output, residual
class NemotronHMambaDecoderLayer(nn.Module):
    """Decoder layer made of an RMSNorm followed by a ``MambaMixer2``.

    The mixer is not called directly: it is executed through the
    ``Mamba2AttnBackend`` attached to the forward batch's hybrid backend.
    """

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.layer_id = layer_idx
        self.mixer = MambaMixer2(
            cache_params=config.mamba2_cache_params,
            hidden_size=config.hidden_size,
            use_conv_bias=config.use_conv_bias,
            use_bias=config.use_bias,
            n_groups=config.mamba_n_groups,
            rms_norm_eps=config.rms_norm_eps,
            activation=config.mamba_hidden_act,
            quant_config=quant_config,
            # Forward the module path (previously accepted but dropped), the
            # same way NemotronHMLPDecoderLayer names its mixer, so the mixer
            # receives its own module path instead of the empty default.
            prefix=f"{prefix}.mixer",
        )

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalize, run the Mamba2 mixer, and return ``(output, residual)``.

        With no incoming residual the un-normalized input becomes the
        residual; otherwise the norm is called with both tensors and returns
        the updated pair. The mixer result is written into a freshly
        allocated ``output`` tensor shaped like ``hidden_states``.
        """
        if residual is None:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)
        else:
            hidden_states, residual = self.norm(hidden_states, residual)

        output = torch.empty_like(hidden_states)
        attn_backend = forward_batch.attn_backend
        # This layer only works with the hybrid backend wrapping a Mamba2 backend.
        assert isinstance(attn_backend, HybridLinearAttnBackend)
        assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
        attn_backend.linear_attn_backend.forward(
            mixer=self.mixer,
            layer_id=self.layer_id,
            hidden_states=hidden_states,
            output=output,
            use_triton_causal_conv=True,  # TODO: investigate need of `use_triton_causal_conv`
        )
        return output, residual
class NemotronHAttention(nn.Module):
    """Self-attention mixer: fused QKV projection, RadixAttention, output projection."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Derive per-rank head counts from the TP world size and build the layers."""
        super().__init__()
        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads are always split evenly across tensor-parallel ranks.
        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads < tp_size:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        else:
            # Number of KV heads is at least the TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        # Prefer an explicit head_dim from the config; otherwise derive it.
        head_dim = getattr(config, "head_dim", None)
        if head_dim is None:
            head_dim = config.hidden_size // self.total_num_heads
        self.head_dim = head_dim

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            config.hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            config.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_idx,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Project to Q/K/V, run attention, and apply the output projection."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        attn_output = self.attn.forward(q, k, v, forward_batch)
        projected, _ = self.o_proj(attn_output)
        return projected
class NemotronHAttentionDecoderLayer(nn.Module):
    """Decoder layer made of an RMSNorm followed by a NemotronHAttention mixer."""

    def __init__(
        self,
        config: NemotronHConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the attention mixer (named ``<prefix>.mixer``) and the input norm."""
        super().__init__()
        mixer_prefix = f"{prefix}.mixer"
        self.mixer = NemotronHAttention(
            config,
            layer_idx,
            quant_config,
            prefix=mixer_prefix,
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        *,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Normalise the input, apply attention and return ``(output, residual)``.

        When ``residual`` is None the incoming ``hidden_states`` becomes the
        residual; otherwise the norm is called with both tensors and returns
        the updated pair.
        """
        if residual is not None:
            hidden_states, residual = self.norm(hidden_states, residual)
        else:
            residual = hidden_states
            hidden_states = self.norm(hidden_states)

        attn_out = self.mixer.forward(
            hidden_states=hidden_states, forward_batch=forward_batch
        )
        return attn_out, residual
# Union of the concrete decoder-layer classes. Used as the value type of the
# registry below and in the isinstance() check of NemotronHModel.forward.
Layers = (
NemotronHAttentionDecoderLayer
| NemotronHMLPDecoderLayer
| NemotronHMambaDecoderLayer
)
# Registry from a layer-type key (the ATTENTION / MLP / MAMBA constants) to the
# class implementing it; NemotronHModel indexes it with each entry of
# config.hybrid_override_pattern.
ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
ATTENTION: NemotronHAttentionDecoderLayer,
MLP: NemotronHMLPDecoderLayer,
MAMBA: NemotronHMambaDecoderLayer,
}
class NemotronHModel(nn.Module):
    """NemotronH backbone: token embedding, a stack of hybrid layers, final norm.

    The layer stack is described by ``config.hybrid_override_pattern``: one
    entry per layer, each entry selecting a class from
    ``ALL_DECODER_LAYER_TYPES``.
    """

    def __init__(
        self,
        *,
        config: NemotronHConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the embedding table, every decoder layer and the final norm."""
        super().__init__()
        # No LoRA configuration is plumbed in here, so lora_vocab below is 0.
        lora_config = None
        self.config = config
        lora_vocab = (
            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
            if lora_config
            else 0
        )
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )

        def get_layer(idx: int, prefix: str):
            # The pattern entry at position idx picks the layer class.
            layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
            return layer_class(config, idx, quant_config=quant_config, prefix=prefix)

        self.layers = make_layers_non_pp(
            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
        )
        self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run every layer and return final hidden states or PP proxy tensors.

        On the first pipeline rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``, and the residual starts as
        None. On later ranks both hidden states and residual are taken from
        ``pp_proxy_tensors``. Ranks other than the last return a
        ``PPProxyTensors`` carrying ``hidden_states`` and ``residual``; the
        last rank applies the final norm and returns the hidden states.
        ``positions`` is not read by this method.

        Raises:
            ValueError: if a layer is not one of the known decoder layer types.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            # Keep the residual handed over by the previous stage; it must not
            # be reset to None before entering this rank's layers.
            residual = pp_proxy_tensors["residual"]

        for layer in self.layers:
            if not isinstance(layer, Layers):
                raise ValueError(f"Unknown layer type: {type(layer)}")
            hidden_states, residual = layer.forward(
                hidden_states=hidden_states,
                residual=residual,
                forward_batch=forward_batch,
            )

        if not get_pp_group().is_last_rank:
            return PPProxyTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states
class NemotronHForCausalLM(nn.Module):
    """Causal-LM wrapper: NemotronHModel backbone, LM head and logits processor."""

    # Checkpoint-name rewrites applied by load_weights before parameter lookup:
    # a leading prefix is swapped, then each listed substring is replaced.
    remap_prefix = {"backbone": "model"}
    remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        *,
        config: NemotronHConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the backbone, the LM head (tied or separate) and the logits processor."""
        super().__init__()
        # No LoRA configuration is plumbed in here, so the LoRA branches below
        # are never taken.
        lora_config = None
        self.config = config
        self.model = self._init_model(
            config=config, quant_config=quant_config, prefix=prefix
        )
        if self.config.tie_word_embeddings:
            # Tied weights: reuse the input embedding as the output head.
            self.lm_head = self.model.embed_tokens
        else:
            self.unpadded_vocab_size = config.vocab_size
            if lora_config:
                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                padding_size=(
                    DEFAULT_VOCAB_PADDING_SIZE
                    # We need bigger padding if using lora for kernel
                    # compatibility
                    if not lora_config
                    else lora_config.lora_vocab_padding_size
                ),
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)

    def _init_model(
        self,
        config: NemotronHConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Construct the backbone; kept separate so it can be overridden."""
        return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the backbone."""
        return self.model.get_input_embeddings(input_ids)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        """Run the backbone and hand its output to the logits processor."""
        # The backbone takes (pp_proxy_tensors, inputs_embeds) in that order,
        # which is the reverse of this method's own parameter order.
        hidden_states = self.model.forward(
            input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
        """Forward to ``self.mamba_cache.copy_inputs_before_cuda_graphs``."""
        # NOTE(review): `self.mamba_cache` is never assigned inside this class;
        # confirm something else sets it before relying on this method.
        return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)

    def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
        """Forward to ``self.mamba_cache.get_seqlen_agnostic_capture_inputs``."""
        # NOTE(review): same caveat as above regarding `self.mamba_cache`.
        return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        """Load checkpoint tensors into this module's parameters.

        Names are first rewritten with ``remap_prefix`` / ``remap_substr``.
        Separate ``q_proj`` / ``k_proj`` / ``v_proj`` tensors are then routed
        into the fused ``qkv_proj`` parameter with their shard id; everything
        else is loaded by name. Unknown ``.bias`` entries are skipped silently
        and other unknown names are logged as warnings.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        updated_weights = []
        for name, loaded_weight in weights:
            for prefix, new_key in self.remap_prefix.items():
                if name.startswith(prefix):
                    # Only the leading occurrence is the prefix; later
                    # occurrences of the same text must stay untouched.
                    name = name.replace(prefix, new_key, 1)
            for substr, new_key in self.remap_substr.items():
                if substr in name:
                    name = name.replace(substr, new_key)
            updated_weights.append((name, loaded_weight))

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in updated_weights:
            if "scale" in name:
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not loaded through the stacked mapping: load directly by name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")
# Model classes exported by this module (presumably picked up by the model
# registry via this name -- TODO confirm against the loader).
EntryClass = [NemotronHForCausalLM]

View File

@@ -866,7 +866,7 @@ class EAGLEWorker(TpModelWorker):
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
# QQ: can be optimized
if self.target_worker.model_runner.is_hybrid_gdn:
if self.target_worker.model_runner.hybrid_gdn_config is not None:
# res.draft_input.accept_length is on GPU but may be empty for last verify?
accepted_length = (
torch.tensor(

View File

@@ -518,6 +518,24 @@ def make_layers(
return modules, start_layer, end_layer
def make_layers_non_pp(
    num_hidden_layers: int,
    layer_fn: LayerFn,
    prefix: str = "",
) -> torch.nn.ModuleList:
    """Build every layer ``0..num_hidden_layers-1`` into one ``ModuleList``.

    Each layer is created by ``layer_fn(idx=..., prefix=...)`` where the
    prefix is ``add_prefix(idx, prefix)``. The layers are produced lazily and
    passed through the offloader's ``wrap_modules`` before being collected.
    """
    from sglang.srt.offloader import get_offloader

    offloader = get_offloader()
    lazy_layers = (
        layer_fn(idx=layer_no, prefix=add_prefix(layer_no, prefix))
        for layer_no in range(num_hidden_layers)
    )
    return torch.nn.ModuleList(offloader.wrap_modules(lazy_layers))
cmo_stream = None

View File

@@ -45,6 +45,7 @@ from sglang.srt.configs import (
KimiVLConfig,
LongcatFlashConfig,
MultiModalityConfig,
NemotronHConfig,
Qwen3NextConfig,
Step3VLConfig,
)
@@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
FalconH1Config.model_type: FalconH1Config,
DotsVLMConfig.model_type: DotsVLMConfig,
DotsOCRConfig.model_type: DotsOCRConfig,
NemotronHConfig.model_type: NemotronHConfig,
}
for name, cls in _CONFIG_REGISTRY.items():

View File

@@ -0,0 +1,375 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py
from typing import Optional
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
PAD_SLOT_ID,
causal_conv1d_fn,
causal_conv1d_update,
)
def causal_conv1d_ref(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    initial_states: Optional[torch.Tensor] = None,
    return_final_states: bool = False,
    final_states_out: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
):
    """
    Pure-PyTorch reference for the causal depthwise 1-D convolution.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    initial_states: (batch, dim, width - 1)
    final_states_out: (batch, dim, width - 1)

    Returns ``(out, final_states)`` where out is (batch, dim, seqlen) in the
    dtype of ``x``. The second element is None unless ``return_final_states``
    is set; when ``final_states_out`` is supplied it is filled in place and
    returned.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    in_dtype = x.dtype
    x = x.to(weight.dtype)
    n_tokens = x.shape[-1]
    channels, kernel_width = weight.shape
    depthwise_kernel = weight.unsqueeze(1)

    if initial_states is not None:
        # The carried-in state supplies the left context, so no zero padding.
        x = torch.cat([initial_states, x], dim=-1)
        pad = 0
    else:
        pad = kernel_width - 1
    conv_out = F.conv1d(x, depthwise_kernel, bias, padding=pad, groups=channels)
    conv_out = conv_out[..., :n_tokens]

    if return_final_states:
        # A negative left pad crops x down to its last (width - 1) positions.
        tail = F.pad(x, (kernel_width - 1 - x.shape[-1], 0)).to(
            in_dtype
        )  # (batch, dim, width - 1)
        if final_states_out is None:
            final_states_out = tail
        else:
            final_states_out.copy_(tail)

    if activation is not None:
        conv_out = F.silu(conv_out)
    conv_out = conv_out.to(dtype=in_dtype)

    if return_final_states:
        return conv_out, final_states_out
    return conv_out, None
def causal_conv1d_update_ref(
    x, conv_state, weight, bias=None, activation=None, cache_seqlens=None
):
    """
    Pure-PyTorch reference for the incremental causal conv update.

    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the
        conv_state starting at the index
        @cache_seqlens % state_len before performing the convolution.

    out: (batch, dim) or (batch, dim, seqlen)

    ``conv_state`` is updated in place.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    in_dtype = x.dtype
    squeeze_output = x.dim() == 2
    if squeeze_output:
        x = x.unsqueeze(-1)

    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    state_len = conv_state.shape[-1]
    assert conv_state.shape == (batch, dim, state_len)
    assert weight.shape == (dim, width)

    if cache_seqlens is not None:

        def ring_index(offsets):
            # Per-sequence positions wrapped into the buffer, broadcast over dim.
            positions = offsets.unsqueeze(0) + cache_seqlens.unsqueeze(1)
            wrapped = torch.remainder(positions, state_len)
            return wrapped.unsqueeze(1).expand(-1, dim, -1)

        history_idx = ring_index(
            torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device)
        )
        window = torch.cat([conv_state.gather(2, history_idx), x], dim=-1).to(
            weight.dtype
        )
        write_idx = ring_index(
            torch.arange(seqlen, dtype=torch.long, device=x.device)
        )
        conv_state.scatter_(2, write_idx, x)
    else:
        window = torch.cat([conv_state, x], dim=-1).to(
            weight.dtype
        )  # (batch, dim, state_len + seqlen)
        conv_state.copy_(window[:, :, -state_len:])

    out = F.conv1d(window, weight.unsqueeze(1), bias, padding=0, groups=dim)
    out = out[:, :, -seqlen:]
    if squeeze_output:
        out = out.squeeze(-1)
    if activation is not None:
        out = F.silu(out)
    return out.to(dtype=in_dtype)
def causal_conv1d_opcheck_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    cu_seq_len: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    Validate ``activation`` and normalise the local input layout.

    This is a helper, not a test: it is not collected by pytest, so it
    carries no ``parametrize`` marks. It performs no operator check and
    returns None; the contiguous copies it makes are local only.

    x: (batch, dim, seqlen)
    weight: (dim, width)
    bias: (dim,)
    cu_seq_len, cache_indices, has_initial_state, conv_states, pad_slot_id:
        accepted but not read by this function.
    activation: either None or "silu" or "swish"

    Raises NotImplementedError for any other ``activation`` value.
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype):
    """Compare the conv update kernel with the PyTorch reference on a small batch."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3

    # set seed
    torch.manual_seed(0)
    batch = 2
    tensor_opts = {"device": device, "dtype": itype}

    # Creation order is kept fixed so the seeded random stream is reproducible.
    x = torch.randn(batch, dim, seqlen, **tensor_opts)
    x_ref = x.clone()
    conv_state = torch.randn(batch, dim, width - 1, **tensor_opts)
    weight = torch.randn(dim, width, **tensor_opts)
    bias = torch.randn(dim, **tensor_opts) if has_bias else None
    conv_state_ref = conv_state.detach().clone()
    activation = "silu" if silu_activation else None

    out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation)
    out_ref = causal_conv1d_update_ref(
        x_ref, conv_state_ref, weight, bias, activation=activation
    )

    # The state must match exactly; the output only within tolerance.
    assert torch.equal(conv_state, conv_state_ref)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3])
@pytest.mark.parametrize("width", [3, 4])
@pytest.mark.parametrize("dim", [2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather(
    batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype
):
    """Check the conv update kernel when states are gathered by cache-slot index.

    Real sequences point at random cache lines; optional padded entries use
    ``PAD_SLOT_ID``. Cache lines no sequence points at must stay untouched.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3

    # set seed
    torch.manual_seed(0)

    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    # total_entries = number of cache line
    total_entries = 10 * batch_size

    # x will be (batch, dim, seqlen) with contiguous along dim-axis
    x = torch.randn(
        padded_batch_size, seqlen, dim, device=device, dtype=itype
    ).transpose(1, 2)
    x_ref = x.clone()

    conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[conv_state_indices] = False
    pad_slots = torch.as_tensor(
        [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device
    )
    padded_state_indices = torch.concat([conv_state_indices, pad_slots], dim=0)

    # conv_state will be (cache_lines, dim, state_len)
    # with contiguous along dim-axis
    conv_state = torch.randn(
        total_entries, width - 1, dim, device=device, dtype=itype
    ).transpose(1, 2)
    conv_state_for_padding_test = conv_state.clone()

    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
    activation = "silu" if silu_activation else None

    out = causal_conv1d_update(
        x,
        conv_state,
        weight,
        bias,
        activation=activation,
        conv_state_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
    )
    # The reference only sees the real (non-padded) rows.
    out_ref = causal_conv1d_update_ref(
        x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation
    )

    assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
    assert torch.equal(
        conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool]
    )
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096])
@pytest.mark.parametrize("dim", [64, 4096])
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch", [4, 10])
def test_causal_conv1d_varlen(
    batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype
):
    """Check the variable-length conv kernel against a per-sequence reference.

    One packed stream of ``seqlen`` tokens is cut into random-length
    sequences. The kernel processes the packed stream in a single call; the
    reference runs ``causal_conv1d_ref`` on every non-padded sequence and the
    results are concatenated for comparison.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    torch.cuda.empty_cache()
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 3e-3, 5e-3

    # set seed
    torch.manual_seed(0)

    batch_size = batch
    padding = 3 if with_padding else 0
    padded_batch_size = batch_size + padding
    nsplits = padded_batch_size - 1

    # Sorted random cut points; consecutive differences are the lengths.
    eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
    boundaries = torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])
    seq_lengths = torch.diff(boundaries).tolist()
    assert sum(seq_lengths) == seqlen
    assert all(length > 0 for length in seq_lengths)

    total_entries = batch_size * 10
    cumsum = torch.cumsum(torch.tensor(seq_lengths), dim=0).to(torch.int32)
    cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0)

    # x is a channel slice of a wider buffer, i.e. a non-contiguous view.
    wide = torch.randn(1, seqlen, 4096 + dim + 64, device=device, dtype=itype)
    x = rearrange(wide, "b s d -> b d s")[:, 4096 : 4096 + dim, :]
    weight = torch.randn(dim, width, device=device, dtype=itype)
    bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
    x_ref = x.clone()
    weight_ref = weight.clone()
    bias_ref = bias.clone() if bias is not None else None
    activation = "silu" if silu_activation else None

    final_states = torch.randn(
        total_entries, width - 1, dim, device=x.device, dtype=x.dtype
    ).transpose(1, 2)
    final_states_ref = final_states.clone()
    has_initial_states = torch.randint(
        0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device
    )
    state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[
        :batch_size
    ]
    pad_slots = torch.as_tensor(
        [PAD_SLOT_ID] * padding, dtype=torch.int32, device=device
    )
    padded_state_indices = torch.concat([state_indices, pad_slots], dim=-1)

    out = causal_conv1d_fn(
        x.squeeze(0),
        weight,
        bias=bias,
        conv_states=final_states,
        query_start_loc=cumsum.cuda(),
        seq_lens_cpu=torch.tensor(seq_lengths),
        cache_indices=padded_state_indices,
        has_initial_state=has_initial_states,
        activation=activation,
        pad_slot_id=PAD_SLOT_ID,
    )

    # Reference: one call per real sequence, each updating its own cache line.
    per_seq_inputs = torch.split(x_ref[0], seq_lengths, dim=-1)
    ref_results = []
    for i, seq_input in enumerate(per_seq_inputs):
        slot = padded_state_indices[i]
        if slot == PAD_SLOT_ID:
            continue
        ref_results.append(
            causal_conv1d_ref(
                seq_input.unsqueeze(0),
                weight_ref,
                bias_ref,
                activation=activation,
                return_final_states=True,
                final_states_out=final_states_ref[slot].unsqueeze(0),
                initial_states=(
                    final_states_ref[slot].unsqueeze(0)
                    if has_initial_states[i]
                    else None
                ),
            )
        )
    out_ref_tensor = torch.cat([seq_out for seq_out, _ in ref_results], dim=2)

    assert torch.allclose(
        final_states[state_indices],
        final_states_ref[state_indices],
        rtol=rtol,
        atol=atol,
    )
    # Only the columns covered by the reference (non-padded part) are compared.
    unpadded_out = out[:, : out_ref_tensor.shape[-1]]
    assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)

View File

@@ -0,0 +1,138 @@
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py
from unittest.mock import patch
import pytest
import torch
from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
update_environment_variables,
)
from sglang.srt.distributed.parallel_state import (
init_distributed_environment,
initialize_model_parallel,
)
# Number of worker processes (one per GPU) spawned by the test below.
NUM_GPUS = 2


@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize(
    "hidden_size_n_groups",
    [
        (64, 1),  # hidden_size be divisible by num_gpus
        (100, 4),  # and n_groups must divide hidden_size
    ],
)
@pytest.mark.parametrize("dtype", [torch.float16])
def test_mixer2_gated_norm_multi_gpu(
    batch_size: int,
    seq_len: int,
    hidden_size_n_groups: tuple[int, int],
    dtype: torch.dtype,
    device: str = "cuda",
):
    """Spawn ``NUM_GPUS`` workers comparing the TP gated norm with a 1-GPU reference.

    The test is skipped when CUDA is unavailable or the host has fewer than
    ``NUM_GPUS`` devices; additional devices are simply left unused.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")
    # Skip instead of failing on hosts whose GPU count is not exactly NUM_GPUS.
    available_gpus = torch.cuda.device_count()
    if available_gpus < NUM_GPUS:
        pytest.skip(f"Requires {NUM_GPUS} CUDA devices, found {available_gpus}")

    hidden_size, n_groups = hidden_size_n_groups
    num_processes = NUM_GPUS

    def run_torch_spawn(fn, nprocs):
        # need to use torch.mp.spawn otherwise will have problems with
        # torch.distributed and cuda
        torch.multiprocessing.spawn(
            fn,
            args=(
                num_processes,
                batch_size,
                seq_len,
                hidden_size,
                n_groups,
                dtype,
                device,
            ),
            nprocs=nprocs,
        )

    run_torch_spawn(mixer2_gated_norm_tensor_parallel, NUM_GPUS)
def mixer2_gated_norm_tensor_parallel(
    local_rank: int,
    world_size: int,
    batch_size: int,
    seq_len: int,
    hidden_size: int,
    n_groups: int,
    dtype: torch.dtype,
    device: str,
):
    """Worker body: compare a TP-sharded gated norm against an unsharded one.

    Every rank seeds identically, so all ranks draw the same weight and
    inputs. The TP mixer receives this rank's slice of the last dimension and
    its output is compared with the same slice of the reference output. The
    incoming ``device`` argument is replaced by ``cuda:<local_rank>``.
    """
    torch.manual_seed(0)

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    torch.set_default_device(device)
    torch.set_default_dtype(dtype)

    rank_str = str(local_rank)
    update_environment_variables(
        {
            "RANK": rank_str,
            "LOCAL_RANK": rank_str,
            "WORLD_SIZE": str(world_size),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12345",
        }
    )

    # initialize distributed
    init_distributed_environment(
        world_size=world_size, rank=local_rank, local_rank=local_rank
    )
    initialize_model_parallel(tensor_model_parallel_size=world_size)

    # create random weights and inputs
    weight = torch.rand((hidden_size,), dtype=dtype, device=device)
    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    gate_states = torch.randn(batch_size, seq_len, hidden_size)

    import sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated as m2
    import sglang.srt.model_loader.weight_utils as wu

    norm_kwargs = dict(full_hidden_size=hidden_size, full_n_groups=n_groups)

    # Convenience: Avoid calling initialize_dp_attention
    with patch.object(wu, "get_attention_tp_rank", return_value=local_rank):
        # create gated-norm with TP
        mixer = m2.Mixer2RMSNormGated(**norm_kwargs)
        mixer.weight.weight_loader(mixer.weight, weight)

    with (
        patch.object(m2, "get_tensor_model_parallel_world_size", return_value=1),
        patch.object(m2, "get_tensor_model_parallel_rank", return_value=0),
    ):
        # create gated-norm without TP to compute reference
        mixer_single_gpu = m2.Mixer2RMSNormGated(**norm_kwargs)
    # assign weight to single-gpu mixer
    mixer_single_gpu.weight.data = weight

    # generate and compare
    N = hidden_size // world_size
    shard = slice(local_rank * N, (local_rank + 1) * N)
    output = mixer(hidden_states[..., shard], gate_states[..., shard])
    ref_output = mixer_single_gpu(hidden_states, gate_states)
    torch.testing.assert_close(
        output,
        ref_output[..., shard],
        atol=5e-3,
        rtol=1e-3,
    )

View File

@@ -0,0 +1,291 @@
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import PAD_SLOT_ID
from sglang.srt.layers.attention.mamba.ops import selective_state_update
def selective_state_update_ref(
    state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False
):
    """
    Pure-PyTorch reference for a single selective-state-update step.

    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim) or (batch, nheads, dim)
        dt: (batch, dim) or (batch, nheads, dim)
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate) or (batch, ngroups, dstate)
        C: (batch, dstate) or (batch, ngroups, dstate)
        D: (dim,) or (nheads, dim)
        z: (batch, dim) or (batch, nheads, dim)
        dt_bias: (dim,) or (nheads, dim)
    Return:
        out: (batch, dim) or (batch, nheads, dim)

    ``state`` is updated in place.
    """
    has_heads = state.dim() > 3

    # Promote every head-less input to the headed layout with a size-1 axis.
    if state.dim() == 3:
        state = state.unsqueeze(1)
    if x.dim() == 2:
        x = x.unsqueeze(1)
    if dt.dim() == 2:
        dt = dt.unsqueeze(1)
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if B.dim() == 2:
        B = B.unsqueeze(1)
    if C.dim() == 2:
        C = C.unsqueeze(1)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if z is not None and z.dim() == 2:
        z = z.unsqueeze(1)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    batch, nheads, dim, dstate = state.shape
    assert x.shape == (batch, nheads, dim)
    assert dt.shape == x.shape
    assert A.shape == (nheads, dim, dstate)
    ngroups = B.shape[1]
    assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
    assert B.shape == (batch, ngroups, dstate)
    assert C.shape == B.shape
    if D is not None:
        assert D.shape == (nheads, dim)
    if z is not None:
        assert z.shape == x.shape

    if dt_bias is not None:
        assert dt_bias.shape == (nheads, dim)
        dt = dt + dt_bias
    if dt_softplus:
        dt = F.softplus(dt)

    dt_col = rearrange(dt, "b h d -> b h d 1")
    dA = torch.exp(dt_col * A)  # (batch, nheads, dim, dstate)

    # Expand each group so every head has its own B / C row.
    heads_per_group = nheads // ngroups
    B = repeat(B, "b g n -> b (g h) n", h=heads_per_group)  # (batch, nheads, dstate)
    C = repeat(C, "b g n -> b (g h) n", h=heads_per_group)  # (batch, nheads, dstate)
    dB = dt_col * rearrange(B, "b h n -> b h 1 n")  # (batch, nheads, dim, dstate)

    x_col = rearrange(x, "b h d -> b h d 1")
    state.copy_(state * dA + dB * x_col)  # (batch, nheads, dim, dstate)

    out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C)
    if D is not None:
        out += (x * D).to(out.dtype)
    if z is not None:
        out = out * F.silu(z)
    out = out.to(x.dtype)

    return out if has_heads else out.squeeze(1)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
def test_selective_state_update(dim, dstate, has_z, itype):
    """Compare ``selective_state_update`` with the PyTorch reference (no heads)."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    if itype == torch.bfloat16:
        rtol, atol = 1e-2, 5e-2
        if torch.version.hip:
            atol *= 2
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 5e-3, 1e-2

    # set seed
    torch.manual_seed(0)
    batch_size = 1

    # Creation order is kept fixed so the seeded random stream is reproducible.
    state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device)
    x = torch.randn(batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(batch_size, dstate, device=device)
    C = torch.randn(batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state.detach().clone()

    shared_kwargs = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)
    selective_state_update(state, x, dt, A, B, C, out=out, **shared_kwargs)
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **shared_kwargs)

    assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [True])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
def test_selective_state_update_with_batch_indices(
    with_padding, dim, dstate, has_z, itype
):
    """Check ``selective_state_update`` when states are addressed by slot index.

    Real sequences point at random cache lines; optional padded entries use
    ``PAD_SLOT_ID``. Cache lines that no real sequence points at, and the
    padded rows of the inputs, must be left unchanged by the kernel.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")
    device = "cuda"
    rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
        if torch.version.hip:
            atol *= 2
    # set seed
    torch.random.manual_seed(0)
    batch_size = 3
    padding = 5 if with_padding else 0
    padded_batch_size = batch_size + padding
    total_entries = 10 * batch_size
    state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
    state_indices = torch.randperm(total_entries)[:batch_size].to(
        dtype=torch.int32, device=device
    )
    unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device)
    unused_states_bool[state_indices] = False
    padded_state_indices = torch.concat(
        [
            state_indices,
            torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
        ],
        dim=0,
    )
    x = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    out = torch.empty_like(x)
    dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype)
    dt_bias = torch.rand(dim, device=device) - 4.0
    A = -torch.rand(dim, dstate, device=device) - 1.0
    B = torch.randn(padded_batch_size, dstate, device=device)
    C = torch.randn(padded_batch_size, dstate, device=device)
    D = torch.randn(dim, device=device)
    z = torch.randn_like(x) if has_z else None
    state_ref = state[state_indices, :].clone()
    state_before = state.clone()
    # Snapshot the inputs so the "unchanged" checks below compare against the
    # pre-kernel values instead of comparing each tensor with itself.
    x_before = x.clone()
    dt_before = dt.clone()
    B_before = B.clone()
    C_before = C.clone()
    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        D=D,
        z=z,
        dt_bias=dt_bias,
        dt_softplus=True,
        state_batch_indices=padded_state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
    )
    out_ref = selective_state_update_ref(
        state_ref,
        x[:batch_size],
        dt[:batch_size],
        A,
        B[:batch_size],
        C[:batch_size],
        D=D,
        z=z[:batch_size] if z is not None else None,
        dt_bias=dt_bias,
        dt_softplus=True,
    )

    print("Output diff max", (out[:batch_size] - out_ref).max())
    print("Output diff mean", (out[:batch_size] - out_ref).mean())
    print("Output state diff max", (state[state_indices, :] - state_ref).max())
    print("Output state diff mean", (state[state_indices, :] - state_ref).mean())
    # test padded entries stay the same
    if with_padding:
        assert torch.equal(state_before[unused_states_bool], state[unused_states_bool])
        assert torch.equal(x[batch_size:], x_before[batch_size:])
        assert torch.equal(dt[batch_size:], dt_before[batch_size:])
        assert torch.equal(B[batch_size:], B_before[batch_size:])
        assert torch.equal(C[batch_size:], C_before[batch_size:])

    # test "real" entries
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("has_z", [False, True])
@pytest.mark.parametrize("tie_hdim", [False, True])
@pytest.mark.parametrize("ngroups", [1, 2, 4])
@pytest.mark.parametrize("dstate", [16, 32, 64])
@pytest.mark.parametrize("dim", [2048, 4096])
def test_selective_state_update_with_heads_with_batch_indices(
    dim, dstate, ngroups, has_z, tie_hdim, itype
):
    """Check the multi-head `selective_state_update` against the reference
    when the state slots to update are selected via `state_batch_indices`.

    A pool of `10 * batch_size` state slots is created and `batch_size` of
    them are picked at random. The kernel is run on the full pool with the
    chosen indices, the reference is run on a clone of just those slots, and
    both the outputs and the selected slots are compared afterwards.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    device = "cuda"
    # Tolerances widen as the input precision drops.
    if itype == torch.bfloat16:
        rtol, atol = 1e-1, 1e-1
    elif itype == torch.float32:
        rtol, atol = 3e-4, 1e-3
    else:
        rtol, atol = 5e-3, 3e-2

    torch.random.manual_seed(0)
    batch_size, headdim = 3, 64
    nheads = dim // headdim
    total_entries = 10 * batch_size

    # NOTE: the random draws below are kept in a fixed order so that the
    # seeded inputs stay reproducible.
    state = torch.randn(
        total_entries, nheads, headdim, dstate, dtype=itype, device=device
    )
    slot_permutation = torch.randperm(total_entries)
    state_indices = slot_permutation[:batch_size].to(dtype=torch.int32, device=device)
    x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
    out = torch.empty_like(x)

    if tie_hdim:
        # one value per head, broadcast across the head dimension
        per_head_dt = torch.randn(batch_size, nheads, device=device, dtype=itype)
        dt = repeat(per_head_dt, "b h -> b h p", p=headdim)
        per_head_bias = torch.rand(nheads, device=device) - 4.0
        dt_bias = repeat(per_head_bias, "h -> h p", p=headdim)
        per_head_A = -torch.rand(nheads, device=device) - 1.0
        A = repeat(per_head_A, "h -> h p n", p=headdim, n=dstate)
        per_head_D = torch.randn(nheads, device=device)
        D = repeat(per_head_D, "h -> h p", p=headdim)
    else:
        dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
        dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
        A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
        D = torch.randn(nheads, headdim, device=device)

    B = torch.randn(batch_size, ngroups, dstate, device=device)
    C = torch.randn(batch_size, ngroups, dstate, device=device)
    z = torch.randn_like(x) if has_z else None

    # snapshot of the selected slots, handed to the reference implementation
    state_ref = state[state_indices, :].detach().clone()

    # keyword arguments common to the kernel and the reference
    common_kwargs = dict(D=D, z=z, dt_bias=dt_bias, dt_softplus=True)

    selective_state_update(
        state,
        x,
        dt,
        A,
        B,
        C,
        state_batch_indices=state_indices,
        pad_slot_id=PAD_SLOT_ID,
        out=out,
        **common_kwargs,
    )
    out_ref = selective_state_update_ref(state_ref, x, dt, A, B, C, **common_kwargs)

    abs_err = (out - out_ref).abs()
    print(f"Output max diff: {abs_err.max().item()}")
    print(f"Output mean diff: {abs_err.mean().item()}")

    # the test expects both calls to have advanced their state argument
    assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol)
    assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)

View File

@@ -0,0 +1,581 @@
# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined
# Added by the IBM Team, 2024
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py
# TODO: These tests take a long time to run - we should trim the parametrization matrix.
# this is the segsum implementation taken from above
def segsum(x):
    """Compute pairwise segment sums of `x` along its last axis.

    For a last axis of length T the result gains one trailing axis of the
    same length. Entry ``[..., i, j]`` holds ``sum(x[..., j + 1 : i + 1])``
    when ``i >= j`` (so the diagonal is 0) and ``-inf`` when ``i < j``.
    """
    seq_len = x.size(-1)
    # copy every element across a new trailing axis: [..., d, e] = x[..., d]
    tiled = repeat(x, "... d -> ... d e", e=seq_len)
    all_true = torch.ones(seq_len, seq_len, device=x.device, dtype=bool)

    # keep only the strictly-lower triangle so that the running sum down the
    # rows accumulates x[j + 1 .. i] in position (i, j)
    strictly_lower = torch.tril(all_true, diagonal=-1)
    running = tiled.masked_fill(~strictly_lower, 0).cumsum(dim=-2)

    # everything above the diagonal is marked with -inf
    lower_with_diag = torch.tril(all_true, diagonal=0)
    return running.masked_fill(~lower_with_diag, -torch.inf)
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """Naive chunked reference for the discrete SSD computation.

    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        block_len: chunk length; must divide `length`.
        initial_states: optional state placed before the first chunk; when
            omitted, zeros shaped like one chunk of states are used.
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    def split_chunks(t):
        # (b, c * l, ...) -> (b, c, l, ...)
        return rearrange(t, "b (c l) ... -> b c l ...", l=block_len)

    X = split_chunks(X)
    B = split_chunks(B)
    C = split_chunks(C)
    A = rearrange(split_chunks(A), "b c l h -> b h c l")
    cum_A = A.cumsum(dim=-1)

    # 1. Intra-chunk output (diagonal blocks).
    intra_decay = segsum(A).exp()
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, intra_decay, X)

    # 2. State contributed by each chunk
    #    (right term of the low-rank factorization of off-diagonal blocks; B terms).
    decay_to_chunk_end = (cum_A[..., -1:] - cum_A).exp()
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_to_chunk_end, X)

    # 3. Inter-chunk recurrence, giving the SSM states at chunk boundaries
    #    (middle term of the factorization of off-diagonal blocks; A terms).
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    padded_chunk_totals = F.pad(cum_A[..., -1], (1, 0))
    inter_chunk_decay = segsum(padded_chunk_totals).exp()
    propagated = torch.einsum("bhzc,bchpn->bzhpn", inter_chunk_decay, states)
    states = propagated[:, :-1]
    final_state = propagated[:, -1]

    # 4. State -> output conversion per chunk
    #    (left term of the low-rank factorization of off-diagonal blocks; C terms).
    decay_from_chunk_start = cum_A.exp()
    Y_off = torch.einsum(
        "bclhn,bchpn,bhcl->bclhp", C, states, decay_from_chunk_start
    )

    # Sum the intra-chunk and inter-chunk terms and merge the chunks back.
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"):
    """Build a seeded random ``(A, dt, X, B, C)`` tuple for the SSD tests.

    The calling test is skipped when CUDA is unavailable. The RNG is reseeded
    with 0 on every call, and the tensors are drawn in the order A, dt, X, B, C.
    `A` has shape ``(n_heads,)``, `dt` has shape ``(batch_size, seqlen,
    n_heads)`` and `X`, `B`, `C` all have shape ``(batch_size, seqlen,
    n_heads, d_head)``.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    torch.manual_seed(0)
    tensor_opts = dict(dtype=itype, device=device)
    per_token_shape = (batch_size, seqlen, n_heads, d_head)

    # A = -exp(u) with u drawn from [0, 1), so it is always negative
    A = -torch.rand(n_heads, **tensor_opts).exp()
    # softplus keeps dt positive; the -4 shift biases it towards small values
    dt = F.softplus(torch.randn(batch_size, seqlen, n_heads, **tensor_opts) - 4)
    X = torch.randn(per_token_shape, **tensor_opts)
    B = torch.randn(per_token_shape, **tensor_opts)
    C = torch.randn(per_token_shape, **tensor_opts)
    return A, dt, X, B, C
def generate_continuous_batched_examples(
    example_lens_by_batch,
    num_examples,
    full_length,
    last_taken,
    exhausted,
    n_heads,
    d_head,
    itype,
    device="cuda",
    return_naive_ref=True,
):
    """Yield continuous batches cut from random full-length examples.

    `num_examples` random examples of length `full_length` are generated
    once, then sliced according to each tuple in `example_lens_by_batch` and
    concatenated along the sequence axis to form one continuous batch per
    tuple.

    Args:
        example_lens_by_batch: iterable of tuples; entry ``i`` of a tuple is
            the number of tokens to take from example ``i`` in that batch.
        num_examples: number of full-length examples to generate.
        full_length: sequence length of every generated example.
        last_taken: dict mutated in place; maps example index to the offset
            where the next slice starts (wraps to 0 at `full_length`).
        exhausted: dict mutated in place; maps example index to whether the
            most recent slice wrapped that example back to offset 0.
        n_heads, d_head, itype: forwarded to `generate_random_inputs`.
        device: device for the generated inputs and the metadata tensors.
        return_naive_ref: if True, `ssd_minimal_discrete` is run on the
            full-length examples and matching reference slices are yielded.

    Yields:
        ``(Y_ref, cu_seqlens, seq_idx, (A, dt, X, B, C))`` where `Y_ref` is a
        list with one reference slice per example (or None when
        `return_naive_ref` is False), `cu_seqlens` holds the cumulative
        sequence boundaries of the batch, and `seq_idx` (with a leading
        axis of size 1) gives the example index of every token.
    """
    # generate the full-length examples on the requested device, so the data
    # lives on the same device as the metadata tensors built below
    A, dt, X, B, C = generate_random_inputs(
        num_examples, full_length, n_heads, d_head, itype, device=device
    )

    if return_naive_ref:
        Y_min, final_state_min = ssd_minimal_discrete(
            X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4
        )

    # internal function that outputs a cont batch of examples
    # given a tuple of lengths for each example in the batch
    # e.g., example_lens=(8, 4) means take 8 samples from first eg,
    #       4 samples from second eg, etc
    def get_continuous_batch(example_lens: tuple[int, ...]):
        indices = []
        for i, x in enumerate(example_lens):
            c = last_taken.get(i, 0)
            indices.append((c, c + x))
            # advance the per-example cursor; reaching full_length wraps to 0
            last_taken[i] = (c + x) % full_length
            exhausted[i] = last_taken[i] == 0

        return (
            torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0)
            for x in (dt, X, B, C)
        )

    # internal function that maps "n" to the appropriate right boundary
    # value when forming continuous batches from examples of length given
    # by "full_length".
    # - e.g., when n > full_length, returns n % full_length
    #         when n == full_length, returns full_length
    def end_boundary(n: int):
        return n - ((n - 1) // full_length) * full_length

    IND_E = None
    for spec in example_lens_by_batch:
        # get the (maybe partial) example seen in this cont batch
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # get the metadata
        cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0)
        seq_idx = torch.zeros(
            cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device
        )
        for i, (srt, end) in enumerate(
            zip(
                cu_seqlens,
                cu_seqlens[1:],
            )
        ):
            seq_idx[srt:end] = i

        # for cont batch: start where the previous batch ended (wrapping)
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield (
            (
                [Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)]
                if return_naive_ref
                else None
            ),
            cu_seqlens,
            seq_idx.unsqueeze(0),
            (A, dt2, X2, B2, C2),
        )
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype):
    """Compare `mamba_chunk_scan_combined` with `ssd_minimal_discrete` on a
    single, un-batched example.

    Only the last sequence position of the output and the last head of the
    final state are compared.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    # TODO: the bfloat16 case requires higher thresholds. To be investigated
    atol, rtol = (5e-2, 5e-2) if itype == torch.bfloat16 else (8e-3, 5e-3)
    tolerances = dict(atol=atol, rtol=rtol)

    # ssd_minimal_discrete needs chunk_size to divide seqlen. That only
    # constrains how the reference is produced; it is not an operational
    # limitation.
    seqlen, chunk_size = seq_len_chunk_size
    batch_size = 1
    A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype)

    # reference result from the naive torch implementation
    scaled_X = X * dt.unsqueeze(-1)
    scaled_A = A * dt
    Y_min, final_state_min = ssd_minimal_discrete(scaled_X, scaled_A, B, C, chunk_size)

    # result from the kernel, written into a preallocated buffer
    Y = torch.empty_like(X)
    final_state = mamba_chunk_scan_combined(
        X, dt, A, B, C, chunk_size, D=None, return_final_states=True, out=Y
    )

    # just test the last in sequence
    torch.testing.assert_close(Y[:, -1], Y_min[:, -1], **tolerances)

    # just test the last head
    # NOTE, in the kernel we always cast states to fp32
    expected_last_head_state = final_state_min[:, -1].to(torch.float32)
    torch.testing.assert_close(
        final_state[:, -1], expected_last_head_state, **tolerances
    )
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
    "seq_len_chunk_size_cases",
    [
        # each case: (seqlen, chunk_size, num_examples, per-batch length tuples)
        # small-ish chunk_size (8)
        (64, 8, 2, [(64, 32), (64, 32)]),
        (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
        (64, 8, 2, [(8, 8), (8, 8), (8, 8)]),  # chunk size boundary
        (
            64,
            8,
            2,
            [(4, 4), (4, 4), (4, 4), (4, 4)],
        ),  # chunk_size larger than cont batches
        (
            64,
            8,
            5,
            [
                (64, 32, 16, 8, 8),
                (8, 16, 32, 16, 8),
                (8, 8, 16, 32, 16),
            ],
        ),  # more examples with varied lengths
        # large-ish chunk_size (256)
        (64, 256, 1, [(5,), (1,), (1,), (1,)]),  # irregular sizes with small sequences
        (
            64,
            256,
            2,
            [(5, 30), (1, 2), (1, 2), (1, 2)],
        ),  # irregular sizes with small sequences
        # we also need to test some large seqlen
        # to catch errors with init states decay
        (768, 128, 2, [(138, 225), (138, 225)]),
    ],
)
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype):
    """Run the kernel over a series of continuous batches (i.e. chunked
    prefill) and compare each batch with the naive reference.

    States returned by one batch are fed to the next as initial states, and
    are zeroed for any example that has been fully consumed.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")

    seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

    # This test can have larger error for longer sequences
    atol, rtol = (1e-2, 5e-3) if seqlen > 256 else (5e-3, 5e-3)

    # bookkeeping shared with the generator, so we know when an example has
    # been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted

    batches = generate_continuous_batched_examples(
        cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype
    )

    states = None
    for Y_min, cu_seqlens, seq_idx, (A, dt, X, B, C) in batches:
        chunk_indices, chunk_offsets = (
            Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
                cu_seqlens, chunk_size, cu_seqlens[-1]
            )
        )

        Y = torch.empty_like(X)
        new_states = mamba_chunk_scan_combined(
            X,
            dt,
            A,
            B,
            C,
            chunk_size,
            D=None,
            cu_seqlens=cu_seqlens,
            seq_idx=seq_idx,
            chunk_indices=chunk_indices,
            chunk_offsets=chunk_offsets,
            return_varlen_states=True,
            initial_states=states,
            out=Y,
        )

        # per example, compare only index 0 of the last two output axes
        for eg in range(num_examples):
            seq_start, seq_stop = cu_seqlens[eg], cu_seqlens[eg + 1]
            kernel_slice = Y[0, seq_start:seq_stop, 0, 0]
            reference_slice = Y_min[eg][:, 0, 0]
            torch.testing.assert_close(
                kernel_slice, reference_slice, atol=atol, rtol=rtol
            )

        # carry the states into the next batch, zeroing those of examples
        # that wrapped around so they restart from a clean state
        states = new_states
        finished = [eg for eg, is_done in exhausted.items() if is_done]
        for eg in finished:
            states[eg].fill_(0.0)
            exhausted[eg] = False
@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize(
    "seqlens",
    [
        (16, 2, 8, 13),
        (270, 88, 212, 203),
        (16, 20),
    ],
)
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
    """Check that running the kernel in two prefill pieces matches one pass.

    Each sequence is split into a first part of ``seqlen // 2`` tokens and a
    remainder. The kernel is run on the full batch, then on the first parts,
    then on the remainders seeded with the states returned for the first
    parts. Outputs and final states of the two-step run are compared with the
    single full-sequence run.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA device not available")
    # This test verifies the correctness of the chunked prefill implementation
    # in the mamba2 ssd kernels, by comparing concatenation (in the sequence
    # dimension) of chunked results with the full sequence result.
    # It is different from test_mamba_chunk_scan_cont_batch by:
    # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
    #    reference outputs. Instead, it compares chunked kernel outputs to full
    #    sequence kernel outputs. This is the most straightforward way to
    #    assert chunked prefill correctness.
    # 2. It focuses on cases where sequences change in the middle of mamba
    #    chunks, and not necessarily on chunk boundaries.
    max_seqlen = max(seqlens)
    # This test can have larger error for longer sequences
    if max_seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3
    num_sequences = len(seqlens)
    n_heads = 16
    d_head = 64
    itype = torch.float32
    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
    # Only one continuous batch is needed: the whole of `seqlens` at once.
    # The naive reference is skipped (first tuple element is None).
    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
        generate_continuous_batched_examples(
            [seqlens],
            num_sequences,
            max_seqlen,
            last_taken,
            exhausted,
            n_heads,
            d_head,
            itype,
            return_naive_ref=False,
        )
    )
    # from here on `seqlens` is a tensor so it can be used in tensor arithmetic
    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
    device = X.device
    ## full seqlen computation
    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            cu_seqlens, chunk_size, cu_seqlens[-1]
        )
    )
    Y_ref = torch.empty_like(X)
    state_ref = mamba_chunk_scan_combined(
        X,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        cu_seqlens=cu_seqlens,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=None,
        out=Y_ref,
    )
    ## chunked seqlen computation
    # first chunk: the leading seqlen // 2 tokens of every sequence
    chunked_seqlens = seqlens // 2
    chunked_cu_seqlens = torch.cat(
        [torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], dim=0
    )
    # per-token sequence index for the first-part batch
    chunked_seq_idx = (
        torch.repeat_interleave(
            torch.arange(len(chunked_seqlens), device=device),
            chunked_seqlens,
            output_size=chunked_cu_seqlens[-1],
        )
        .unsqueeze(0)
        .to(torch.int32)
    )
    chunked_input_seq_len = chunked_cu_seqlens[-1]
    # zero-filled buffers sized for the first-part batch, filled in below
    X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
    dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
    B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
    C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
    for i in range(num_sequences):
        # fmt: off
        # slice of sequence i covering its first chunked_seqlens[i] tokens
        chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
        X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
        dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
        B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
        C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
        # fmt: on
    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1]
        )
    )
    Y_partial = torch.empty_like(X_chunked)
    # first pass: no initial states; the returned states seed the second pass
    partial_state = mamba_chunk_scan_combined(
        X_chunked,
        dt_chunked,
        A,
        B_chunked,
        C_chunked,
        chunk_size,
        D=None,
        cu_seqlens=chunked_cu_seqlens,
        seq_idx=chunked_seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=None,
        out=Y_partial,
    )
    # remaining chunk: whatever is left of every sequence after the first part
    remaining_chunked_seqlens = seqlens - chunked_seqlens
    remaining_chunked_cu_seqlens = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(remaining_chunked_seqlens, dim=0),
        ],
        dim=0,
    )
    # per-token sequence index for the remainder batch
    remaining_chunked_seq_idx = (
        torch.repeat_interleave(
            torch.arange(len(remaining_chunked_seqlens), device=device),
            remaining_chunked_seqlens,
            output_size=remaining_chunked_cu_seqlens[-1],
        )
        .unsqueeze(0)
        .to(torch.int32)
    )
    remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
    # fmt: off
    remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
    for i in range(num_sequences):
        # slice of sequence i from the end of its first part to its end
        remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501
        remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
        remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
        remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
        remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501
    # assert input chunking is correct
    # concat_chunk_f glues part 1 and part 2 of sequence i back together;
    # concat_batch_f does that for every sequence and joins them in order
    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
        pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
        pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
        ],
        dim=1)
    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1)  # noqa: E501
    # fmt: on
    assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
    assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
    assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
    assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
    chunk_indices, chunk_offsets = (
        Mamba2Metadata._query_start_loc_to_chunk_indices_offsets(
            remaining_chunked_cu_seqlens, chunk_size, remaining_chunked_cu_seqlens[-1]
        )
    )
    Y_chunked = torch.empty_like(remaining_X_chunked)
    # second pass: continue from the states produced by the first pass
    state_chunked = mamba_chunk_scan_combined(
        remaining_X_chunked,
        remaining_dt_chunked,
        A,
        remaining_B_chunked,
        remaining_C_chunked,
        chunk_size,
        D=None,
        cu_seqlens=remaining_chunked_cu_seqlens,
        seq_idx=remaining_chunked_seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=partial_state,
        out=Y_chunked,
    )
    # stitch the two partial outputs back into the original token order
    Y = concat_batch_f(Y_partial, Y_chunked)
    # kernel chunked is same as kernel overall
    for i in range(num_sequences):
        Y_seq = Y[:, cu_seqlens[i] : cu_seqlens[i + 1], ...]
        Y_ref_seq = Y_ref[:, cu_seqlens[i] : cu_seqlens[i + 1], ...]
        # the two halves are checked separately so a failure message says
        # which pass produced the mismatch
        torch.testing.assert_close(
            Y_seq[:, : chunked_seqlens[i], ...],
            Y_ref_seq[:, : chunked_seqlens[i], ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part1 " + x,
        )  # noqa: B023
        torch.testing.assert_close(
            Y_seq[:, chunked_seqlens[i] :, ...],
            Y_ref_seq[:, chunked_seqlens[i] :, ...],
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} output part2 " + x,
        )  # noqa: B023
        # final state after both passes vs. final state of the single pass
        state_seq = state_chunked[i]
        state_seq_ref = state_ref[i]
        torch.testing.assert_close(
            state_seq,
            state_seq_ref,
            atol=atol,
            rtol=rtol,
            msg=lambda x: f"seq{i} state " + x,
        )  # noqa: B023

View File

@@ -91,6 +91,11 @@ ALL_MODELS = [
trust_remote_code=True,
skip_long_prompt=True,
),
ModelCase(
"nvidia/NVIDIA-Nemotron-Nano-9B-v2",
trust_remote_code=True,
skip_long_prompt=True,
),
ModelCase(
"swiss-ai/Apertus-8B",
trust_remote_code=True,

View File

@@ -0,0 +1,44 @@
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestNvidiaNemotronNanoV2(CustomTestCase):
    """Few-shot GSM8K accuracy check for nvidia/NVIDIA-Nemotron-Nano-9B-v2.

    A server is launched once for the whole class and torn down afterwards.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server process shared by every test in the class."""
        cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
        cls.base_url = DEFAULT_URL_FOR_TEST
        launch_flags = ["--max-mamba-cache-size", "256"]
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_flags,
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the server process together with its children."""
        kill_process_tree(cls.process.pid)

    def test_gsm8k(self):
        """Run the GSM8K eval against the server and require accuracy > 0.87."""
        # the port is whatever follows the last ":" in the base URL
        server_port = int(self.base_url.rsplit(":", 1)[-1])
        eval_settings = {
            "num_shots": 5,
            "data_path": None,
            "num_questions": 200,
            "max_new_tokens": 512,
            "parallel": 128,
            "host": "http://127.0.0.1",
            "port": server_port,
        }
        metrics = run_eval(SimpleNamespace(**eval_settings))
        print(f"{metrics=}")
        self.assertGreater(metrics["accuracy"], 0.87)

View File

@@ -127,6 +127,10 @@ suites = {
TestFile("test_vlm_input_format.py", 300),
TestFile("test_vision_openai_server_a.py", 724),
TestFile("test_vision_openai_server_b.py", 446),
TestFile("layers/attention/mamba/test_causal_conv1d.py", 85),
TestFile("layers/attention/mamba/test_mamba_ssm.py", 85),
TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 220),
TestFile("models/test_nvidia_nemotron_nano_v2.py", 180),
TestFile("test_modelopt_loader.py", 30),
],
"per-commit-2-gpu": [
@@ -142,6 +146,7 @@ suites = {
TestFile("hicache/test_hicache_storage_file_backend.py", 200),
TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400),
TestFile("hicache/test_hicache_storage_3fs_backend.py", 200),
TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110),
],
"per-commit-4-gpu": [
TestFile("test_gpt_oss_4gpu.py", 300),