From d6837aea4d2c1e32b19706ecd4d807df82dacfce Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Wed, 8 Oct 2025 19:37:38 +0300 Subject: [PATCH] model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909) Signed-off-by: Netanel Haber --- docs/supported_models/generative_models.md | 5 +- python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/falcon_h1.py | 70 +-- python/sglang/srt/configs/mamba_utils.py | 117 ++++ python/sglang/srt/configs/nemotron_h.py | 286 +++++++++ python/sglang/srt/configs/qwen3_next.py | 54 +- .../layers/attention/attention_registry.py | 54 +- .../layers/attention/fla/layernorm_gated.py | 77 ++- .../attention/hybrid_linear_attn_backend.py | 224 +++++-- .../layers/attention/mamba/causal_conv1d.py | 2 +- .../attention/mamba/causal_conv1d_triton.py | 13 +- .../srt/layers/attention/mamba/mamba.py | 428 ++++++------- .../layers/attention/mamba/mamba2_metadata.py | 211 +++++++ .../srt/layers/attention/mamba/mamba_utils.py | 81 --- .../attention/mamba/mixer2_rms_norm_gated.py | 120 ++++ .../srt/layers/attention/mamba/ops/ssd_bmm.py | 50 -- .../attention/mamba/ops/ssd_chunk_scan.py | 60 -- .../attention/mamba/ops/ssd_chunk_state.py | 111 ---- .../attention/mamba/ops/ssd_state_passing.py | 11 - .../srt/layers/attention/triton_backend.py | 2 +- python/sglang/srt/managers/scheduler.py | 2 +- python/sglang/srt/mem_cache/memory_pool.py | 95 +-- .../sglang/srt/model_executor/model_runner.py | 64 +- python/sglang/srt/models/falcon_h1.py | 20 +- python/sglang/srt/models/nemotron_h.py | 514 ++++++++++++++++ python/sglang/srt/speculative/eagle_worker.py | 2 +- python/sglang/srt/utils/common.py | 18 + .../sglang/srt/utils/hf_transformers_utils.py | 2 + .../attention/mamba/test_causal_conv1d.py | 375 +++++++++++ .../attention/mamba/test_mamba2_mixer.py | 138 +++++ .../layers/attention/mamba/test_mamba_ssm.py | 291 +++++++++ .../attention/mamba/test_mamba_ssm_ssd.py 
| 581 ++++++++++++++++++ test/srt/models/test_generation_models.py | 5 + .../models/test_nvidia_nemotron_nano_v2.py | 44 ++ test/srt/run_suite.py | 5 + 35 files changed, 3280 insertions(+), 854 deletions(-) create mode 100644 python/sglang/srt/configs/mamba_utils.py create mode 100644 python/sglang/srt/configs/nemotron_h.py create mode 100644 python/sglang/srt/layers/attention/mamba/mamba2_metadata.py delete mode 100644 python/sglang/srt/layers/attention/mamba/mamba_utils.py create mode 100644 python/sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py create mode 100644 python/sglang/srt/models/nemotron_h.py create mode 100644 test/srt/layers/attention/mamba/test_causal_conv1d.py create mode 100644 test/srt/layers/attention/mamba/test_mamba2_mixer.py create mode 100644 test/srt/layers/attention/mamba/test_mamba_ssm.py create mode 100644 test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py create mode 100644 test/srt/models/test_nvidia_nemotron_nano_v2.py diff --git a/docs/supported_models/generative_models.md b/docs/supported_models/generative_models.md index f127ef118..7733bedc8 100644 --- a/docs/supported_models/generative_models.md +++ b/docs/supported_models/generative_models.md @@ -53,6 +53,7 @@ in the GitHub search bar. | **Ling** (16.8B–290B) | `inclusionAI/Ling-lite`, `inclusionAI/Ling-plus` | InclusionAI’s open MoE models. Ling-Lite has 16.8B total / 2.75B active parameters, and Ling-Plus has 290B total / 28.8B active parameters. They are designed for high performance on NLP and complex reasoning tasks. | | **Granite 3.0, 3.1** (IBM) | `ibm-granite/granite-3.1-8b-instruct` | IBM's open dense foundation models optimized for reasoning, code, and business AI use cases. Integrated with Red Hat and watsonx systems. | | **Granite 3.0 MoE** (IBM) | `ibm-granite/granite-3.0-3b-a800m-instruct` | IBM’s Mixture-of-Experts models offering strong performance with cost-efficiency. MoE expert routing designed for enterprise deployment at scale. 
| -| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. | -| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family builds on the strongest open models in the ecosystem by enhancing them with greater accuracy, efficiency, and transparency using NVIDIA open synthetic datasets, advanced techniques, and tools. This enables the creation of practical, right-sized, and high-performing AI agents. | +| **Llama Nemotron Super** (v1, v1.5, NVIDIA) | `nvidia/Llama-3_3-Nemotron-Super-49B-v1`, `nvidia/Llama-3_3-Nemotron-Super-49B-v1_5` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. | +| **Llama Nemotron Ultra** (v1, NVIDIA) | `nvidia/Llama-3_1-Nemotron-Ultra-253B-v1` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. | +| **NVIDIA Nemotron Nano 2.0** | `nvidia/NVIDIA-Nemotron-Nano-9B-v2` | The [NVIDIA Nemotron](https://www.nvidia.com/en-us/ai-data-science/foundation-models/nemotron/) family of multimodal models provides state-of-the-art reasoning models specifically designed for enterprise-ready AI agents. 
`Nemotron-Nano-9B-v2` is a hybrid Mamba-Transformer language model designed to increase throughput for reasoning workloads while achieving state-of-the-art accuracy compared to similarly-sized models. | | **StarCoder2** (3B-15B) | `bigcode/starcoder2-7b` | StarCoder2 is a family of open large language models (LLMs) specialized for code generation and understanding. It is the successor to StarCoder, jointly developed by the BigCode project (a collaboration between Hugging Face, ServiceNow Research, and other contributors). | diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 8a8a3bdeb..fb5a4d6d2 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -9,6 +9,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig from sglang.srt.configs.longcat_flash import LongcatFlashConfig +from sglang.srt.configs.nemotron_h import NemotronHConfig from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.configs.step3_vl import ( Step3TextConfig, @@ -32,4 +33,5 @@ __all__ = [ "DotsVLMConfig", "DotsOCRConfig", "FalconH1Config", + "NemotronHConfig", ] diff --git a/python/sglang/srt/configs/falcon_h1.py b/python/sglang/srt/configs/falcon_h1.py index 368404bd0..d323b056d 100644 --- a/python/sglang/srt/configs/falcon_h1.py +++ b/python/sglang/srt/configs/falcon_h1.py @@ -15,16 +15,12 @@ """Falcon-H1 model configuration""" import enum -import os -import numpy as np -import torch from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging -from sglang.srt.distributed.utils import divide -from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape from 
sglang.srt.layers.dp_attention import ( get_attention_tp_size, get_tensor_model_parallel_world_size, @@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig): self.rope_scaling = None self.rope_scaling = rope_scaling self.projectors_bias = projectors_bias - mamba_intermediate = ( + self.mamba_intermediate = mamba_intermediate = ( mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm ) @@ -294,18 +290,6 @@ class FalconH1Config(PretrainedConfig): def layers_block_type(self): return ["falcon_h1" for i in range(self.num_hidden_layers)] - @property - def mamba_cache_per_req(self): - conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = ( - self.hybrid_gdn_params - ) - mamba_layers_len = len(mamba_layers) - - return ( - int(np.prod(conv_state_shape)) * conv_dtype.itemsize - + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize - ) * mamba_layers_len - @property def full_attention_layer_ids(self): # For Falcon-H1, we do have attention on all layers @@ -317,44 +301,14 @@ class FalconH1Config(PretrainedConfig): return range(self.num_hidden_layers) @property - def hybrid_gdn_params(self): - world_size = get_tensor_model_parallel_world_size() - - n_groups = self.mamba_n_groups - if self.mamba_n_groups % world_size != 0: - # - for TP we shard conv_dim by sharding on n_groups, - # - but if n_groups cannot divide tp_size, we need to - # extend some extra groups - extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards( - self.mamba_n_groups, world_size - ) - n_groups += extra_groups - - conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state - - conv_state_shape = ( - divide(conv_dim, world_size), - self.mamba_d_conv - 1, - ) - - # we TP-ize on the heads dimension - temporal_state_shape = ( - self.mamba_d_state, - self.mamba_d_head, - divide(self.mamba_n_heads, world_size), - ) - conv_dtype = torch.bfloat16 - dtype_map = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - } - ssm_dtype = 
dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] - mamba_layers = self.linear_layer_ids - - return ( - conv_state_shape, - temporal_state_shape, - conv_dtype, - ssm_dtype, - mamba_layers, + def mamba2_cache_params(self): + shape = Mamba2StateShape.create( + tp_world_size=get_tensor_model_parallel_world_size(), + intermediate_size=self.mamba_intermediate, + n_groups=self.mamba_n_groups, + num_heads=self.mamba_n_heads, + head_dim=self.mamba_d_head, + state_size=self.mamba_d_state, + conv_kernel=self.mamba_d_conv, ) + return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids) diff --git a/python/sglang/srt/configs/mamba_utils.py b/python/sglang/srt/configs/mamba_utils.py new file mode 100644 index 000000000..3199c0461 --- /dev/null +++ b/python/sglang/srt/configs/mamba_utils.py @@ -0,0 +1,117 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc.""" + +import os +from dataclasses import dataclass, field + +import numpy as np +import torch + +from sglang.srt.distributed.utils import divide + + +def extra_groups_for_head_shards(ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups + + +@dataclass(kw_only=True, frozen=True) +class Mamba2StateShape: + conv: tuple[int, int] + temporal: tuple[int, int, int] + + intermediate_size: int + conv_dim: int + ssm_state_size: int + num_heads: int + head_dim: int + state_size: int + conv_kernel: int + + @staticmethod + def create( + *, + tp_world_size: int, + intermediate_size: int, + n_groups: int, + num_heads: int, + head_dim: int, + state_size: int, + conv_kernel: int, + ) -> "Mamba2StateShape": + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + if n_groups % tp_world_size != 0: + extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size) + n_groups += extra_groups + # heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * state_size + + # contiguous along 'dim' axis + conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1 + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) + temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) + return Mamba2StateShape( + conv=conv_state_shape, + temporal=temporal_state_shape, + intermediate_size=intermediate_size, + conv_dim=conv_dim, + ssm_state_size=state_size, + num_heads=num_heads, + head_dim=head_dim, + state_size=state_size, + 
conv_kernel=conv_kernel, + ) + + +@dataclass(kw_only=True, frozen=True) +class Mamba2StateDType: + conv: torch.dtype + temporal: torch.dtype + + +CONV_DTYPE = torch.bfloat16 + + +def mamba2_state_dtype() -> Mamba2StateDType: + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] + return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype) + + +@dataclass(kw_only=True, frozen=True) +class Mamba2CacheParams: + shape: Mamba2StateShape + dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype) + layers: list[int] + + @property + def mamba_cache_per_req(self) -> int: + return ( + int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize + + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize + ) * len(self.layers) diff --git a/python/sglang/srt/configs/nemotron_h.py b/python/sglang/srt/configs/nemotron_h.py new file mode 100644 index 000000000..9e156f6a7 --- /dev/null +++ b/python/sglang/srt/configs/nemotron_h.py @@ -0,0 +1,286 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py + +"""NemotronH model configuration""" + +import regex as re +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape +from sglang.srt.layers.dp_attention import get_attention_tp_size + +logger = logging.get_logger(__name__) + +MAMBA = "M" +ATTENTION = "*" +MLP = "-" + + +class NemotronHConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`NemotronHModel`]. It is used to instantiate a NemotronH model according + to the specified arguments, defining the model architecture. Instantiating + a configuration with the defaults will yield a similar configuration to + that of the NemotronH-v0.1 model. + Args: + vocab_size (`int`, *optional*, defaults to 131072): + Vocabulary size of the NemotronH model. Defines the number of + different tokens that can be represented by the `inputs_ids` + passed when calling [`NemotronHModel`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be + tied. Note that this is only relevant if the model has an output + word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 21504): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 52): + Number of hidden layers in the Transformer encoder. + hybrid_override_pattern (`str`, *optional*, defaults to + `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`): + The pattern of the hybrid model. 
The pattern is a string of + characters where each character represents + M: Mamba2, *: Attention, -: MLP + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer encoder. + head_dim (`int`, *optional*, defaults to 128): + Dimension of each attention head. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use + Multi Head Attention (MHA), if `num_key_value_heads=1` the model + will use Multi Query Attention (MQA) otherwise GQA is used. + mlp_hidden_act (`str`, *optional*, defaults to "relu2"): + The non-linear activation function in the MLP layers. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in attention layers. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in MLP layers. + use_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the model. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + residual_in_fp32 (`bool`, *optional*, defaults to `False`): + Whether or not residuals should be in `float32`. If set to `False` + residuals will keep the same `dtype` as the rest of the model. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, + all logits will be calculated. 
If an integer value, only last + `num_logits_to_keep` logits will be calculated. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + sliding_window (`int`, *optional*, defaults to None): + Sliding window attention window size. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used + with. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + hidden_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the hidden states. + use_mamba_kernels (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use the fast mamba kernels. + These are available only if `mamba-ssm` and `causal-conv1d` + are installed, and the mamba modules are running on a CUDA device. + ssm_state_size (`int`, *optional*, defaults to 128): + The dimension of the mamba state space latents. + mamba_num_heads (`int`, *optional*, defaults to 128): + Number of heads in Mamba layers. + mamba_n_groups (`int`, *optional*, defaults to 8): + Number of groups in Mamba layers. + mamba_head_dim (`int`, *optional*, defaults to 64): + Dimension of each Mamba head. + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel. + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor used to determine the mamba intermediate size. + mamba_hidden_act (`str`, *optional*, defaults to "silu"): + The non-linear activation function in the Mamba layers. + mamba_dt_min (`float`, *optional*, defaults to 0.001): + Minimum value for the time step in Mamba. + mamba_dt_max (`float`, *optional*, defaults to 0.1): + Maximum value for the time step in Mamba. 
+ mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))): + Limits for the time step in Mamba. + mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4): + Floor value for time step initialization in Mamba. + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias in the convolution layer of the mamba mixer + block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the input and output projections of the + mamba mixer block. + mamba_chunk_size (`int`, *optional*, defaults to 256): + Size of chunks for Mamba processing. + rescale_prenorm_residual (`bool`, *optional*, defaults to `True`): + Whether to rescale the pre-normalization residual connections. + """ + + model_type = "nemotron_h" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=21504, + num_hidden_layers=52, + hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-", + num_attention_heads=32, + head_dim=128, + num_key_value_heads=8, # nemo: num_query_groups + mlp_hidden_act="relu2", + attention_bias=False, + mlp_bias=False, + use_bias=False, + initializer_range=0.02, # nemo: init_method_std + layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon + residual_in_fp32=False, # Megatron Core default value + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + sliding_window=None, + max_position_embeddings=4096, + attention_dropout=0.0, + hidden_dropout=0.0, # * ADDED + use_mamba_kernels=True, + ssm_state_size=128, # mamba_state_size + mamba_num_heads=128, + mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads + mamba_head_dim=64, + mamba_d_conv=4, + mamba_expand=2, + mamba_hidden_act="silu", + mamba_dt_min=0.001, + mamba_dt_max=0.1, + mamba_dt_limit=(0.0, float("inf")), + mamba_dt_init_floor=1e-4, + mamba_conv_bias=True, + mamba_proj_bias=False, + 
mamba_chunk_size=256, + rescale_prenorm_residual=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.tie_word_embeddings = tie_word_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.hybrid_override_pattern = hybrid_override_pattern + self.num_attention_heads = num_attention_heads + self.head_dim = head_dim + self.sliding_window = sliding_window + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.hidden_dropout = hidden_dropout + + # Validate hybrid_override_pattern + # M: Mamba2, *: Attention, -: MLP + assert len(self.hybrid_override_pattern) == self.num_hidden_layers, ( + "hybrid_override_pattern must have same length as " "num_hidden_layers" + ) + assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), ( + "hybrid_override_pattern must only contain characters " "'M', '*', or '-'" + ) + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.mlp_hidden_act = mlp_hidden_act + self.attention_bias = attention_bias + self.mlp_bias = mlp_bias + self.use_bias = use_bias + self.initializer_range = initializer_range + self.layer_norm_epsilon = layer_norm_epsilon + self.residual_in_fp32 = residual_in_fp32 + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.use_mamba_kernels = use_mamba_kernels + self.mamba_n_groups = mamba_n_groups + self.mamba_head_dim = mamba_head_dim + self.ssm_state_size = ssm_state_size + self.mamba_num_heads = mamba_num_heads + self.conv_kernel = mamba_d_conv + self.expand = mamba_expand + self.mamba_hidden_act = mamba_hidden_act + self.time_step_min = mamba_dt_min + self.time_step_max = mamba_dt_max + self.time_step_limit = mamba_dt_limit + self.time_step_floor = mamba_dt_init_floor + self.use_conv_bias = mamba_conv_bias + self.mamba_proj_bias = 
mamba_proj_bias + self.mamba_chunk_size = mamba_chunk_size + self.rescale_prenorm_residual = rescale_prenorm_residual + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def mamba_layer_ids(self): + return [ + i + for i in range(self.num_hidden_layers) + if self.hybrid_override_pattern[i] == MAMBA + ] + + @property + def full_attention_layer_ids(self): + return [ + i + for i in range(self.num_hidden_layers) + if self.hybrid_override_pattern[i] == ATTENTION + ] + + @property + def mamba2_cache_params(self) -> Mamba2CacheParams: + shape = Mamba2StateShape.create( + tp_world_size=get_attention_tp_size(), + intermediate_size=self.mamba_num_heads * self.mamba_head_dim, + n_groups=self.n_groups, + num_heads=self.mamba_num_heads, + head_dim=self.mamba_head_dim, + state_size=self.ssm_state_size, + conv_kernel=self.conv_kernel, + ) + + return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids) diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py index 099d14d41..62fd76f77 100644 --- a/python/sglang/srt/configs/qwen3_next.py +++ b/python/sglang/srt/configs/qwen3_next.py @@ -15,14 +15,12 @@ """Qwen3Hybrid model configuration""" import enum -import os -import numpy as np -import torch from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import rope_config_validation from transformers.utils import logging +from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape from sglang.srt.distributed.utils import divide from sglang.srt.layers.dp_attention import get_attention_tp_size @@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig): ] @property - def hybrid_gdn_params(self): - world_size = get_attention_tp_size() - conv_dim = ( - self.linear_key_head_dim * self.linear_num_key_heads * 2 - + self.linear_value_head_dim * 
self.linear_num_value_heads - ) - conv_state_shape = ( - divide(conv_dim, world_size), - self.linear_conv_kernel_dim - 1, + def mamba2_cache_params(self) -> Mamba2CacheParams: + shape = Mamba2StateShape.create( + tp_world_size=get_attention_tp_size(), + intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads, + n_groups=self.linear_num_key_heads, + num_heads=self.linear_num_value_heads, + head_dim=self.linear_value_head_dim, + state_size=self.linear_key_head_dim, + conv_kernel=self.linear_conv_kernel_dim, ) - temporal_state_shape = ( - divide(self.linear_num_value_heads, world_size), - self.linear_key_head_dim, - self.linear_value_head_dim, - ) - conv_dtype = torch.bfloat16 - dtype_map = { - "float32": torch.float32, - "bfloat16": torch.bfloat16, - } - ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] - mamba_layers = self.linear_layer_ids - return ( - conv_state_shape, - temporal_state_shape, - conv_dtype, - ssm_dtype, - mamba_layers, - ) - - @property - def mamba_cache_per_req(self): - conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = ( - self.hybrid_gdn_params - ) - mamba_layers_len = len(mamba_layers) - - return ( - int(np.prod(conv_state_shape)) * conv_dtype.itemsize - + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize - ) * mamba_layers_len + return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids) diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py index 77d8e2eb6..2bf271c29 100644 --- a/python/sglang/srt/layers/attention/attention_registry.py +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -1,7 +1,14 @@ import logging +from typing import TYPE_CHECKING logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + # evade circular imports + from sglang.srt.layers.attention.base_attn_backend import AttentionBackend + from sglang.srt.model_executor.model_runner import ModelRunner + 
ATTENTION_BACKENDS = {} @@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner): return DualChunkFlashAttentionBackend(runner) -def attn_backend_wrapper(runner, full_attn_backend): +def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"): """ Wrapper for special models like hybrid GDN, so we don't need to change the code of the original attention backend. """ assert not ( - runner.is_hybrid_gdn and runner.use_mla_backend + runner.hybrid_gdn_config is not None and runner.use_mla_backend ), "hybrid_gdn can only be used with non-MLA models." - # wrap for hybrid GDN models - if runner.is_hybrid_gdn: + if cfg := runner.mambaish_config: + from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( + GDNAttnBackend, + HybridLinearAttnBackend, + Mamba2AttnBackend, + ) from sglang.srt.utils import is_blackwell, is_npu - if is_blackwell(): - assert ( - runner.server_args.attention_backend == "triton" - or runner.server_args.attention_backend == "trtllm_mha" - ), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend." - if is_npu(): - assert ( - runner.server_args.attention_backend == "ascend" - ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend." 
- logger.info(f"Using hybrid linear attention backend for hybrid GDN models.") - from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( - HybridLinearAttnBackend, - MambaAttnBackend, - ) - - linear_attn_backend = MambaAttnBackend(runner) - full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids + if runner.hybrid_gdn_config is not None: + if is_blackwell(): + assert ( + runner.server_args.attention_backend == "triton" + ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend." + if is_npu(): + assert ( + runner.server_args.attention_backend == "ascend" + ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend." + logger.info(f"Using hybrid linear attention backend for hybrid GDN models.") + linear_attn_backend = GDNAttnBackend(runner) + elif runner.mamba2_config is not None: + linear_attn_backend = Mamba2AttnBackend(runner) + else: + raise ValueError( + "Expected hybrid GDN or NemotronH models, but got unknown model." 
+ ) + full_attn_layers = cfg.full_attention_layer_ids return HybridLinearAttnBackend( full_attn_backend, linear_attn_backend, full_attn_layers ) diff --git a/python/sglang/srt/layers/attention/fla/layernorm_gated.py b/python/sglang/srt/layers/attention/fla/layernorm_gated.py index 89482245b..50b7244c6 100644 --- a/python/sglang/srt/layers/attention/fla/layernorm_gated.py +++ b/python/sglang/srt/layers/attention/fla/layernorm_gated.py @@ -181,6 +181,45 @@ def _layer_norm_fwd( return out, mean, rstd +def rms_norm_gated( + *, + x, + weight, + bias, + z=None, + eps=1e-6, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" + + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, mean, rstd = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=is_rms_norm, + ) + return y.reshape(x_shape_og) + + class LayerNormFn(torch.autograd.Function): @staticmethod @@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function): norm_before_gate=True, is_rms_norm=False, ): - """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - - x_shape_og = x.shape - # reshape input data into 2D tensor - x = x.reshape(-1, x.shape[-1]) - if x.stride(-1) != 1: - x = x.contiguous() - if z is not None: - assert z.shape == x_shape_og - z = z.reshape(-1, z.shape[-1]) - if z.stride(-1) != 1: - z = z.contiguous() - weight = weight.contiguous() - if bias is not None: - bias = bias.contiguous() - y, mean, rstd = _layer_norm_fwd( - x, - weight, - bias, - eps, + return rms_norm_gated( + x=x, 
+ weight=weight, + bias=bias, + eps=eps, z=z, group_size=group_size, norm_before_gate=norm_before_gate, is_rms_norm=is_rms_norm, ) - return y.reshape(x_shape_og) def layernorm_fn( @@ -238,14 +261,6 @@ def layernorm_fn( ) -def rmsnorm_fn( - x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True -): - return LayerNormFn.apply( - x, weight, bias, z, eps, group_size, norm_before_gate, True - ) - - class LayerNorm(torch.nn.Module): def __init__( @@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module): group_size=self.group_size, eps=self.eps, norm_before_gate=self.norm_before_gate, + is_rms_norm=False, ) @@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module): def forward(self, x, z=None): """If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))""" - return rmsnorm_fn( + return layernorm_fn( x, self.weight, self.bias, @@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module): eps=self.eps, group_size=self.group_size, norm_before_gate=self.norm_before_gate, + is_rms_norm=True, ) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index d405713b7..7f2e90255 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( fused_sigmoid_gating_delta_rule_update, ) from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( + PAD_SLOT_ID, causal_conv1d_fn, causal_conv1d_update, ) +from sglang.srt.layers.attention.mamba.mamba import MambaMixer2 +from sglang.srt.layers.attention.mamba.mamba2_metadata import ( + ForwardMetadata, + Mamba2Metadata, +) from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool +from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.models.qwen3_next import fused_gdn_gating +from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import is_cuda, is_npu @@ -47,18 +54,10 @@ elif is_npu(): causal_conv1d_update = causal_conv1d_update_npu -@dataclass -class ForwardMetadata: - query_start_loc: Optional[torch.Tensor] - mamba_cache_indices: torch.Tensor - - -class MambaAttnBackend(AttentionBackend): - """Attention backend using Mamba kernel.""" - +class MambaAttnBackendBase(AttentionBackend): def __init__(self, model_runner: ModelRunner): super().__init__() - self.pad_slot_id = -1 # Default pad slot id + self.pad_slot_id = PAD_SLOT_ID self.device = model_runner.device self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool self.forward_metadata: ForwardMetadata = None @@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend): self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None - def init_forward_metadata(self, forward_batch: ForwardBatch): + def _forward_metadata(self, forward_batch: ForwardBatch): bs = forward_batch.batch_size if forward_batch.forward_mode.is_decode_or_idle(): @@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend): mamba_cache_indices = self.req_to_token_pool.get_mamba_indices( forward_batch.req_pool_indices ) - self.forward_metadata = ForwardMetadata( + return ForwardMetadata( query_start_loc=query_start_loc, mamba_cache_indices=mamba_cache_indices, ) + def init_forward_metadata(self, forward_batch: ForwardBatch): + self.forward_metadata = self._forward_metadata(forward_batch) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: 
torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + self.forward_metadata = self._capture_metadata( + bs, req_pool_indices, forward_mode + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + self.forward_metadata = self._replay_metadata( + bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu + ) + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): assert ( max_num_tokens % max_bs == 0 @@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend): device=self.device, ) - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - num_tokens: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor], - forward_mode: ForwardMode, - spec_info: Optional[SpecInput], + def _capture_metadata( + self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode ): if forward_mode.is_decode_or_idle(): self.query_start_loc_list[bs - 1].copy_( @@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend): raise ValueError(f"Invalid forward mode: {forward_mode=}") mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) - self.forward_metadata = ForwardMetadata( + return ForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], ) - def init_forward_metadata_replay_cuda_graph( + def _replay_metadata( self, bs: int, req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, spec_info: 
Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], @@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend): else: raise ValueError(f"Invalid forward mode: {forward_mode=}") - self.forward_metadata = ForwardMetadata( + return ForwardMetadata( query_start_loc=self.query_start_loc_list[bs - 1], mamba_cache_indices=self.state_indices_list[bs - 1], ) @@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend): def get_cuda_graph_seq_len_fill_value(self): return 1 # Mamba attn does not use seq lens to index kv cache + +class GDNAttnBackend(MambaAttnBackendBase): + """Attention backend using Mamba kernel.""" + def forward_decode( self, q: torch.Tensor, @@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend): dt_bias = kwargs["dt_bias"] layer_id = kwargs["layer_id"] - conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params( - layer_id - ) + layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id) + conv_states = layer_cache.conv + ssm_states = layer_cache.temporal query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices @@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend): query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices + mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id) + conv_states = mamba_cache_params.conv + ssm_states = mamba_cache_params.temporal if is_target_verify: - ( - conv_states, - ssm_states, - intermediate_state_cache, - intermediate_conv_window_cache, - ) = self.req_to_token_pool.get_mamba_params(layer_id) + assert isinstance(mamba_cache_params, MambaPool.SpeculativeState) + intermediate_state_cache = mamba_cache_params.intermediate_ssm + intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window has_initial_states = torch.ones( seq_len // forward_batch.spec_info.draft_token_num, dtype=torch.bool, @@ -327,9 +352,6 @@ class 
MambaAttnBackend(AttentionBackend): ) conv_states_to_use = conv_states.clone() else: - conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params( - layer_id - ) has_initial_states = forward_batch.extend_prefix_lens > 0 conv_states_to_use = conv_states @@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend): return core_attn_out +class Mamba2AttnBackend(MambaAttnBackendBase): + """Attention backend wrapper for Mamba2Mixer kernels.""" + + def __init__(self, model_runner: ModelRunner): + super().__init__(model_runner) + config = model_runner.mamba2_config + assert config is not None + self.mamba_chunk_size = config.mamba_chunk_size + + def init_forward_metadata(self, forward_batch: ForwardBatch): + metadata = self._forward_metadata(forward_batch) + self.forward_metadata = Mamba2Metadata.prepare_mixed( + metadata.query_start_loc, + metadata.mamba_cache_indices, + self.mamba_chunk_size, + forward_batch, + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + metadata = self._capture_metadata(bs, req_pool_indices, forward_mode) + self.forward_metadata = Mamba2Metadata.prepare_decode( + metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + metadata = self._replay_metadata( + bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu + ) + self.forward_metadata = Mamba2Metadata.prepare_decode( + metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens + ) + + def 
forward( + self, + mixer: MambaMixer2, + hidden_states: torch.Tensor, + output: torch.Tensor, + layer_id: int, + mup_vector: Optional[torch.Tensor] = None, + use_triton_causal_conv: bool = False, + ): + assert isinstance(self.forward_metadata, Mamba2Metadata) + layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id) + return mixer.forward( + hidden_states=hidden_states, + output=output, + layer_cache=layer_cache, + metadata=self.forward_metadata, + mup_vector=mup_vector, + use_triton_causal_conv=use_triton_causal_conv, + ) + + def forward_decode(self, *args, **kwargs): + raise NotImplementedError( + "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode" + ) + + def forward_extend(self, *args, **kwargs): + raise NotImplementedError( + "Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode" + ) + + class HybridLinearAttnBackend(AttentionBackend): - """Support different backends for prefill and decode.""" + """Manages a full and linear attention backend""" def __init__( self, full_attn_backend: AttentionBackend, - linear_attn_backend: AttentionBackend, + linear_attn_backend: MambaAttnBackendBase, full_attn_layers: list[int], ): self.full_attn_layers = full_attn_layers + self.full_attn_backend = full_attn_backend + self.linear_attn_backend = linear_attn_backend self.attn_backend_list = [full_attn_backend, linear_attn_backend] def init_forward_metadata(self, forward_batch: ForwardBatch): @@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend): ) def get_cuda_graph_seq_len_fill_value(self): - return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value() + return self.full_attn_backend.get_cuda_graph_seq_len_fill_value() def forward_decode( self, @@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend): ): layer_id = layer.layer_id if layer else kwargs["layer_id"] if layer_id in 
self.full_attn_layers: - return self.attn_backend_list[0].forward_decode( + return self.full_attn_backend.forward_decode( q, k, v, layer, forward_batch, save_kv_cache, **kwargs ) - return self.attn_backend_list[1].forward_decode( + return self.linear_attn_backend.forward_decode( q, k, v, layer, forward_batch, save_kv_cache, **kwargs ) @@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend): ): layer_id = layer.layer_id if layer else kwargs["layer_id"] if layer_id in self.full_attn_layers: - return self.attn_backend_list[0].forward_extend( + return self.full_attn_backend.forward_extend( q, k, v, layer, forward_batch, save_kv_cache, **kwargs ) - return self.attn_backend_list[1].forward_extend( + return self.linear_attn_backend.forward_extend( q, k, v, layer, forward_batch, save_kv_cache, **kwargs ) @@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend): def update_mamba_state_after_mtp_verify(self, accepted_length, model): request_number = accepted_length.shape[0] - state_indices_tensor = self.attn_backend_list[ - 1 - ].forward_metadata.mamba_cache_indices[:request_number] + state_indices_tensor = ( + self.linear_attn_backend.forward_metadata.mamba_cache_indices[ + :request_number + ] + ) - mamba_caches = self.attn_backend_list[ - 1 - ].req_to_token_pool.get_mamba_params_all_layers() + mamba_caches = ( + self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers() + ) - ( - conv_states, - ssm_states, - intermediate_state_cache, - intermediate_conv_window_cache, - ) = mamba_caches + conv_states = mamba_caches.conv + ssm_states = mamba_caches.temporal + intermediate_state_cache = mamba_caches.intermediate_ssm + intermediate_conv_window_cache = mamba_caches.intermediate_conv_window # SSM state updates (chunked to reduce peak memory) valid_mask = accepted_length > 0 diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d.py index 
d9f63641d..071a0ee6f 100644 --- a/python/sglang/srt/layers/attention/mamba/causal_conv1d.py +++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d.py @@ -10,7 +10,7 @@ import torch from sgl_kernel import causal_conv1d_fwd from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel -PAD_SLOT_ID = -1 +from .causal_conv1d_triton import PAD_SLOT_ID def causal_conv1d_fn( diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py index 8c9d8bd7b..dbd9dac34 100644 --- a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py @@ -6,11 +6,11 @@ from typing import List, Optional, Union import numpy as np import torch - -PAD_SLOT_ID = -1 import triton import triton.language as tl +PAD_SLOT_ID = -1 + @triton.jit() def _causal_conv1d_fwd_kernel( # continuous batching @@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel( + (conv_state_batch_coord * stride_conv_state_seq) + conv_state_token_offset * stride_conv_state_tok + (idx_feats * stride_conv_state_dim)[None, :] - + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[ + :, None + ] ) # [BLOCK_M, BLOCK_N] mask = ( (conv_state_batch_coord < num_cache_lines) @@ -897,7 +899,10 @@ def causal_conv1d_update( stride_state_indices = ( conv_state_indices.stride(0) if conv_state_indices is not None else 0 ) - state_len = width - 1 + (seqlen - 1) # effective state_len needed + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 np2_statelen = triton.next_power_of_2(state_len) def grid(META): diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index b48ee694f..5d9fe23e3 100644 --- 
a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -1,23 +1,30 @@ -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple import torch import torch.nn as nn -from sglang.srt.configs.model_config import ModelConfig -from sglang.srt.custom_op import CustomOp +from sglang.srt.configs.mamba_utils import ( + Mamba2CacheParams, + extra_groups_for_head_shards, +) from sglang.srt.distributed import ( + divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, - tensor_model_parallel_all_gather, - tensor_model_parallel_all_reduce, ) from sglang.srt.distributed.utils import divide -from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn from sglang.srt.layers.attention.mamba.causal_conv1d import ( causal_conv1d_fn, causal_conv1d_update, ) -from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator +from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( + causal_conv1d_fn as causal_conv1d_fn_triton, +) +from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( + causal_conv1d_update as causal_conv1d_update_triton, +) +from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata +from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated from sglang.srt.layers.attention.mamba.ops import ( mamba_chunk_scan_combined, selective_state_update, @@ -28,7 +35,7 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.quantization.base_config import QuantizationConfig -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.mem_cache.memory_pool import MambaPool from sglang.srt.model_loader.weight_utils import ( composed_weight_loader, sharded_weight_loader, @@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader( return loader -class Mixer2RMSNormGated(CustomOp): - - def __init__( - 
self, - full_hidden_size: int, - full_n_groups: int, - use_rms_norm: bool = True, - eps: float = 1e-6, - ): - super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tensor_model_parallel_rank() - self.full_hidden_size = full_hidden_size - self.group_size = full_hidden_size // full_n_groups - self.per_rank_hidden_size = full_hidden_size // self.tp_size - self.n_groups = full_hidden_size // self.group_size - - self.variance_epsilon = eps - self.use_rms_norm = use_rms_norm - if self.use_rms_norm: - # Register norm weight only if we're actually applying RMSNorm - self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) - set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)}) - else: - # Avoid checkpoint mismatch by skipping unused parameter - self.register_parameter("weight", None) - assert ( - self.full_hidden_size % self.tp_size == 0 - ), "Tensor parallel world size must divide hidden size." - - def forward_native( - self, - x: torch.Tensor, - gate: torch.Tensor, - ): - # Three tensor-parallel cases: - # 1. n_groups is 1 - # In this case we parallelize along the reduction dim. - # Each rank computes a local sum of squares followed by AllReduce - # 2. tp_size divides n_groups - # Each rank only reduces within its local group(s). - # No collective ops necessary. - # 3. The general case can be pretty complicated so we AllGather - # the input and then redundantly compute the RMSNorm. 
- input_dtype = x.dtype - x = x * nn.functional.silu(gate.to(torch.float32)) - if not self.use_rms_norm: - return x.to(input_dtype) - - if self.n_groups == 1: - if self.tp_size > 1: - # Compute local sum and then reduce to obtain global sum - local_sums = x.pow(2).sum(dim=-1, keepdim=True) - global_sums = tensor_model_parallel_all_reduce(local_sums) - # Calculate the variance - count = self.tp_size * x.shape[-1] - variance = global_sums / count - - else: - variance = x.pow(2).mean(-1, keepdim=True) - x = x * torch.rsqrt(variance + self.variance_epsilon) - else: - redundant_tp: bool = self.n_groups % self.tp_size != 0 - if redundant_tp: - # To handle the general case, redundantly apply the variance - x = tensor_model_parallel_all_gather(x, -1) - - *prefix_dims, hidden_dim = x.shape - group_count = hidden_dim // self.group_size - x_grouped = x.view(*prefix_dims, group_count, self.group_size) - variance = x_grouped.pow(2).mean(-1, keepdim=True) - x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) - x = x_grouped.view(*prefix_dims, hidden_dim) - - if redundant_tp: - start = self.per_rank_hidden_size * self.tp_rank - end = start + self.per_rank_hidden_size - x = x[..., start:end] - - return self.weight * x.to(input_dtype) - - def forward_cuda( - self, - x: torch.Tensor, - gate: torch.Tensor, - ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: - input_dtype = x.dtype - if not self.use_rms_norm: - # Keep gate in float32 for numerical stability during silu - return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype) - - if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1: - return self.forward_native(x, gate) - - return layernorm_fn( - x, - self.weight.data, - bias=None, - z=gate, - eps=self.variance_epsilon, - norm_before_gate=False, - ) - - class MambaMixer2(torch.nn.Module): """ Compute ∆, A, B, C, and D the state space parameters and compute @@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module): def __init__( self, 
+ cache_params: Mamba2CacheParams, hidden_size: int, - ssm_state_size: int, - conv_kernel_size: int, - intermediate_size: int, use_conv_bias: bool, use_bias: bool, - chunk_size: int, - layer_id: int, n_groups: int = 1, - num_heads: int = 128, - head_dim: int = 64, rms_norm_eps: float = 1e-5, activation: str = "silu", use_rms_norm: bool = True, - model_config: Optional[ModelConfig] = None, - # cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ): @@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module): self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() + self.num_heads = num_heads = cache_params.shape.num_heads + self.head_dim = cache_params.shape.head_dim + assert ( num_heads % self.tp_size == 0 ), "Tensor parallel world size must divide num heads." @@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module): "then num_groups must equal 1." ) - self.ssm_state_size = ssm_state_size - self.conv_kernel_size = conv_kernel_size + assert ( + (n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None + ), ( + "Tensor parallel currently supported for quantized models only " + "if tensor parallel world size divides num groups." 
+ ) + + self.ssm_state_size = cache_params.shape.ssm_state_size self.activation = activation - self.layer_id = layer_id - - self.intermediate_size = intermediate_size - self.head_dim = head_dim - self.num_heads = num_heads - self.chunk_size = chunk_size + conv_kernel_size = cache_params.shape.conv_kernel + self.intermediate_size = intermediate_size = ( + cache_params.shape.intermediate_size + ) self.n_groups = n_groups if n_groups % self.tp_size != 0: # - for TP we shard conv_dim by sharding on n_groups, # - but if n_groups cannot divide tp_size, we need to # extend some extra groups - groups = MambaStateShapeCalculator.extra_groups_for_head_shards( - n_groups, self.tp_size - ) + groups = extra_groups_for_head_shards(n_groups, self.tp_size) self.n_groups = n_groups + groups - self.groups_ssm_state_size = self.n_groups * self.ssm_state_size - self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size + self.conv_dim = cache_params.shape.conv_dim - self.conv1d = MergedColumnParallelLinear( - input_size=conv_kernel_size, - output_sizes=[ - intermediate_size, - self.groups_ssm_state_size, - self.groups_ssm_state_size, - ], - bias=use_conv_bias, - quant_config=None, - prefix=f"{prefix}.conv1d", - ) + if n_groups % self.tp_size == 0: + self.conv1d = MergedColumnParallelLinear( + input_size=conv_kernel_size, + output_sizes=[ + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + ], + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) - self.in_proj = MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[ - intermediate_size, - intermediate_size, - self.groups_ssm_state_size, - self.groups_ssm_state_size, - self.num_heads, - ], - bias=use_bias, - prefix=f"{prefix}.in_proj", - ) - if n_groups % self.tp_size != 0: + self.in_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + intermediate_size, + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, 
+ self.num_heads, + ], + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + else: # This is the n_groups == 1 case, # where we need to duplicate groups if TP>1. + self.conv1d = ColumnParallelLinear( + input_size=conv_kernel_size, + output_size=self.conv_dim, + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) + + self.in_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size + self.conv_dim + self.num_heads, + bias=use_bias, + quant_config=quant_config, + prefix=f"{prefix}.in_proj", + ) + # - because in_proj is a concatenation of 3 weights, we # need to interleave them before sharding # - use the custom weight loader mamba_v2_sharded_weight_loader @@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module): intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps ) - # The tuple is (conv_state, ssm_state) - self.kv_cache = (torch.tensor([]), torch.tensor([])) - - self.model_config = model_config self.prefix = prefix - def forward_native( - self, - hidden_states: torch.Tensor, - output: torch.Tensor, - mup_vector: Optional[torch.Tensor] = None, - ): - pass - def forward( self, + *, hidden_states: torch.Tensor, output: torch.Tensor, - forward_batch: ForwardBatch, + layer_cache: MambaPool.State, + metadata: Mamba2Metadata, mup_vector: Optional[torch.Tensor] = None, + use_triton_causal_conv: bool = False, ): - # attn_backend_list[-1] gives access to MambaAttnBackend - mamba_backend = forward_batch.attn_backend.attn_backend_list[-1] - attn_metadata = mamba_backend.forward_metadata - state_indices_tensor = attn_metadata.mamba_cache_indices - chunk_size = self.chunk_size + # metadata contains metadata necessary for the mamba2 triton + # kernels to operate in continuous batching and in chunked prefill + # modes; they are computed at top-level model forward since they + # stay the same and reused for all mamba layers in the same iteration + state_indices_tensor = 
metadata.mamba_cache_indices + conv_state = layer_cache.conv + ssm_state = layer_cache.temporal - conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params( - self.layer_id - ) - - assert ( - ssm_state.size(1) == self.ssm_state_size - ), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}" - - query_start_loc = attn_metadata.query_start_loc - - chunk_size = self.chunk_size - - # TODO: properly support this - prep_initial_states = False + query_start_loc = metadata.query_start_loc # 1. Gated MLP's linear projection projected_states, _ = self.in_proj(hidden_states) @@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module): dim=-1, ) + num_prefills = metadata.num_prefills # request count + num_decodes = metadata.num_decodes # token count (=request) + num_prefill_tokens = metadata.num_prefill_tokens # token count + has_prefill = num_prefills > 0 + has_decode = num_decodes > 0 + num_actual_tokens = num_prefill_tokens + num_decodes + assert num_actual_tokens == projected_states.shape[0] + + # NOTE: V0 put prefill before decode + # Separate prefill and decode by splitting varlen input + # Split along token dimension + hidden_states_B_C_p, hidden_states_B_C_d = torch.split( + hidden_states_B_C, + [num_prefill_tokens, num_decodes], + dim=0, + ) + dt_p, dt_d = torch.split( + dt, + [num_prefill_tokens, num_decodes], + dim=0, + ) + # Split along batch dimension + state_indices_tensor_p, state_indices_tensor_d = torch.split( + state_indices_tensor, + [num_prefills, num_decodes], + dim=0, + ) + query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None + + # Preallocate output tensor to avoid memcpy cost for merging prefill + # and decode outputs + preallocated_ssm_out = torch.empty( [ projected_states.shape[0], @@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module): dtype=hidden_states.dtype, device=hidden_states.device, ) + preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split( + preallocated_ssm_out, + 
[num_prefill_tokens, num_decodes], + dim=0, + ) # Process prefill requests - if forward_batch.forward_mode.is_extend(): + if has_prefill: + mixed_metadata = metadata.mixed_metadata + assert mixed_metadata is not None # 2. Convolution sequence transformation # - "cache_indices" updates the conv_state cache in positions # pointed to by "state_indices_tensor" - num_prefill_tokens = forward_batch.extend_num_tokens or 0 - has_initial_states = forward_batch.extend_prefix_lens > 0 - cache_indices = attn_metadata.mamba_cache_indices - - x = hidden_states_B_C.transpose( + has_initial_states_p = mixed_metadata.has_initial_states + prep_initial_states = mixed_metadata.prep_initial_states + cache_indices = state_indices_tensor_p + x = hidden_states_B_C_p.transpose( 0, 1 ) # this is the form that causal-conv see - hidden_states_B_C = causal_conv1d_fn( + ccfn = ( + causal_conv1d_fn + if not use_triton_causal_conv + else causal_conv1d_fn_triton + ) + hidden_states_B_C_p = ccfn( x, conv_weights, self.conv1d.bias, activation=self.activation, conv_states=conv_state, - has_initial_state=has_initial_states, + has_initial_state=has_initial_states_p, cache_indices=cache_indices, - query_start_loc=query_start_loc, - seq_lens_cpu=forward_batch.extend_seq_lens_cpu, - ).transpose(0, 1) + query_start_loc=query_start_loc_p, + seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu, + ).transpose(0, 1)[:num_prefill_tokens] - hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C) + hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p) # 3. 
State Space Model sequence transformation initial_states = None - - if has_initial_states is not None and prep_initial_states: + if has_initial_states_p is not None and prep_initial_states: initial_states = torch.where( - has_initial_states[:, None, None, None], - ssm_state[state_indices_tensor], + has_initial_states_p[:, None, None, None], + ssm_state[state_indices_tensor_p], 0, ) # NOTE: final output is an in-place update of out tensor varlen_state = mamba_chunk_scan_combined( - hidden_states.view( + hidden_states_p.view( 1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim ), - dt.unsqueeze(0), + dt_p.unsqueeze(0), self.A, - B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), - C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), - chunk_size=chunk_size, + B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), + C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), + chunk_size=mixed_metadata.chunk_size, D=self.D, z=None, dt_bias=self.dt_bias, - cu_seqlens=query_start_loc, + seq_idx=mixed_metadata.seq_idx, + chunk_indices=mixed_metadata.chunk_indices, + chunk_offsets=mixed_metadata.chunk_offsets, + cu_seqlens=query_start_loc_p, initial_states=initial_states, return_varlen_states=True, return_final_states=False, dt_softplus=True, dt_limit=(0.0, float("inf")), - out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim), + out=preallocated_ssm_out_p.view( + 1, num_prefill_tokens, -1, self.head_dim + ), state_dtype=ssm_state.dtype, ) # update ssm states # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor - ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1) - elif forward_batch.forward_mode.is_decode(): - num_decodes = len(query_start_loc) - 1 + ssm_state[state_indices_tensor_p] = varlen_state + + # Process decode requests + if has_decode: # 2. 
Convolution sequence transformation - hidden_states_B_C = causal_conv1d_update( - hidden_states_B_C, + ccu = ( + causal_conv1d_update + if not use_triton_causal_conv + else causal_conv1d_update_triton + ) + hidden_states_B_C_d = ccu( + hidden_states_B_C_d, conv_state, conv_weights, self.conv1d.bias, self.activation, - conv_state_indices=state_indices_tensor, + conv_state_indices=state_indices_tensor_d, ) - hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C) + hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d) # 3. State Space Model sequence transformation n_groups = self.n_groups // self.tp_size - A = ( + A_d = ( self.A[:, None, ...][:, :, None] .expand(-1, self.head_dim, self.ssm_state_size) .to(dtype=torch.float32) ) - dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim) dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) - D = self.D[:, None, ...].expand(-1, self.head_dim) - B = B.view(-1, n_groups, B.shape[1] // n_groups) - C = C.view(-1, n_groups, C.shape[1] // n_groups) - hidden_states = hidden_states.view( + D_d = self.D[:, None, ...].expand(-1, self.head_dim) + B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups) + C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups) + hidden_states_d = hidden_states_d.view( -1, self.num_heads // self.tp_size, self.head_dim ) # - the hidden is reshaped into (bs, num_heads, head_dim) - # - mamba_cache_params.ssm_state's slots will be selected + # - layer_state.ssm_state's slots will be selected # using state_indices_tensor_d # NOTE: final output is an in-place update of out tensor selective_state_update( - ssm_state.permute(0, 3, 2, 1), - hidden_states, - dt, - A, - B, - C, - D, + ssm_state, + hidden_states_d, + dt_d, + A_d, + B_d, + C_d, + D_d, z=None, dt_bias=dt_bias, dt_softplus=True, - state_batch_indices=state_indices_tensor, - out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim), + 
state_batch_indices=state_indices_tensor_d, + out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), ) - elif forward_batch.forward_mode.is_idle(): - preallocated_ssm_out = preallocated_ssm_out # 4. gated MLP # GatedRMSNorm internally applying SiLU to the gate # SiLU is applied internally before normalization, unlike standard # norm usage - hidden_states = self.norm(preallocated_ssm_out, gate) + hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens]) # 5. Final linear projection - output[:], _ = self.out_proj(hidden_states) + output[:num_actual_tokens], _ = self.out_proj(hidden_states) @property def mamba_type(self) -> str: diff --git a/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py new file mode 100644 index 000000000..75f33cbba --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/mamba2_metadata.py @@ -0,0 +1,211 @@ +# Copyright 2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py + + +import math +from dataclasses import dataclass + +import torch + +from sglang.srt.model_executor.forward_batch_info import ForwardBatch + + +@dataclass(kw_only=True) +class ForwardMetadata: + query_start_loc: torch.Tensor + mamba_cache_indices: torch.Tensor + + +@dataclass(kw_only=True) +class Mamba2Metadata(ForwardMetadata): + """stable metadata across all mamba2 layers in the forward pass""" + + num_prefills: int + num_prefill_tokens: int + num_decodes: int + + @dataclass(kw_only=True, frozen=True) + class MixedMetadata: + has_initial_states: torch.Tensor + prep_initial_states: bool + + chunk_size: int + seq_idx: torch.Tensor + chunk_indices: torch.Tensor + chunk_offsets: torch.Tensor + + extend_seq_lens_cpu: list[int] + + mixed_metadata: MixedMetadata | None = None + """`mixed_metadata` is used for extend/mixed requests""" + + @staticmethod + def _query_start_loc_to_chunk_indices_offsets( + query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + query_start_loc (torch.Tensor): 1D tensor of cumulative sequence + lengths, shape (num_seqs + 1,). + The first element should be 0. Each entry represents the starting + index of a sequence in the flattened token array. + chunk_size (int): The size of each physical mamba chunk + (number of tokens per chunk). + total_seqlens (int): The total number of tokens in the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices + indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets + indicating the starting index of each logical chunk within + its physical chunk. 
+ + This function computes the chunk indices and offsets for the given + query_start_loc and chunk_size. Both are tensors of integers with length N, + where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same + sequence and are all in the same physical mamba chunk. + In other words, a logical chunk changes every time we cross a sequence + boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states + (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each + logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the + logical chunk in the physical chunk. + + Example: + query_start_loc = [0, 5, 10] + chunk_size = 8 + total_seqlens = 10 + -> chunk_indices = [0, 0, 1] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical + chunk size is 8 tokens. 
+ We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk + and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk + and contains first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk + and contains the remaining 2 tokens from the second sequence + """ + + cu_seqlens = query_start_loc[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = ( + math.ceil(total_seqlens / chunk_size) + + (cu_seqlens[:-1] % chunk_size > 0).sum() + ) + chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device) + chunk_offsets = torch.zeros( + (N,), dtype=torch.int, device=query_start_loc.device + ) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += s % chunk_size > 0 + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0) + + # adjust indices and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + @staticmethod + def prepare_decode( + query_start_loc: torch.Tensor, + mamba_cache_indices: torch.Tensor, + seq_lens: torch.Tensor, + ) -> "Mamba2Metadata": + """This path is run during CUDA graph capture, i.e. 
decode only, so `num_prefills` is 0""" + return Mamba2Metadata( + query_start_loc=query_start_loc, + mamba_cache_indices=mamba_cache_indices, + num_decodes=len(seq_lens), + num_prefills=0, + num_prefill_tokens=0, + ) + + @classmethod + def prepare_mixed( + cls, + query_start_loc: torch.Tensor, + mamba_cache_indices: torch.Tensor, + chunk_size: int, + forward_batch: ForwardBatch, + ) -> "Mamba2Metadata": + """This path cannot run with CUDA graph, as it contains extend requests.""" + if forward_batch.extend_num_tokens is None: + return cls.prepare_decode( + query_start_loc, mamba_cache_indices, forward_batch.seq_lens + ) + num_prefills = len(forward_batch.extend_seq_lens) + num_prefill_tokens = forward_batch.extend_num_tokens + num_decodes = len(forward_batch.seq_lens) - num_prefills + context_lens_tensor = forward_batch.extend_prefix_lens + assert context_lens_tensor is not None + # precompute flag to avoid device syncs later + has_initial_states = context_lens_tensor > 0 + prep_initial_states = torch.any(has_initial_states[:num_prefills]).item() + + query_start_loc = query_start_loc[: num_prefills + 1] + seq_idx = torch.repeat_interleave( + torch.arange( + num_prefills, dtype=torch.int32, device=query_start_loc.device + ), + query_start_loc.diff(), + output_size=num_prefill_tokens, + ) + seq_idx.unsqueeze_(0) + + # We compute metadata for chunked prefill once at the top level model + # forward and reuse them in mamba layers. If not needed, they will be + # ignored inside mamba kernels. 
+ chunk_offsets, chunk_indices = None, None + if prep_initial_states: + chunk_indices, chunk_offsets = ( + cls._query_start_loc_to_chunk_indices_offsets( + query_start_loc, chunk_size, num_prefill_tokens + ) + ) + + return Mamba2Metadata( + query_start_loc=query_start_loc, + mamba_cache_indices=mamba_cache_indices, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + num_decodes=num_decodes, + mixed_metadata=cls.MixedMetadata( + has_initial_states=has_initial_states, + prep_initial_states=prep_initial_states, + chunk_size=chunk_size, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + ), + ) diff --git a/python/sglang/srt/layers/attention/mamba/mamba_utils.py b/python/sglang/srt/layers/attention/mamba/mamba_utils.py deleted file mode 100644 index 7672934be..000000000 --- a/python/sglang/srt/layers/attention/mamba/mamba_utils.py +++ /dev/null @@ -1,81 +0,0 @@ -# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py -from sglang.srt.distributed.utils import divide - - -class MambaStateShapeCalculator: - - @classmethod - def linear_attention_state_shape( - cls, - num_heads: int, - tp_size: int, - head_dim: int, - ) -> tuple[tuple[int, int, int], ...]: - - state_shape = (num_heads // tp_size, head_dim, head_dim) - return (state_shape,) - - @classmethod - def mamba1_state_shape( - cls, - tp_world_size: int, - intermediate_size: int, - state_size: int, - conv_kernel: int, - ) -> tuple[tuple[int, int], tuple[int, int]]: - conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) - - temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) - - conv_state_shape = conv_state_shape[1], conv_state_shape[0] - - return conv_state_shape, temporal_state_shape - - @classmethod - def mamba2_state_shape( - cls, - tp_world_size: int, - intermediate_size: int, - n_groups: int, - 
num_heads: int, - head_dim: int, - state_size: int, - conv_kernel: int, - ) -> tuple[tuple[int, int], tuple[int, int, int]]: - # if n_groups is not divisible by world_size, need to extend the shards - # to ensure all groups needed by a head is sharded along with it - n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size) - # heads and n_groups are TP-ed - conv_dim = intermediate_size + 2 * n_groups * state_size - - # contiguous along 'dim' axis - conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) - - # These are not TP-ed as they depend on A, dt_bias, D - # - they are typically small - # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) - temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) - return conv_state_shape, temporal_state_shape - - @classmethod - def short_conv_state_shape( - cls, - tp_world_size: int, - intermediate_size: int, - conv_kernel: int, - ) -> tuple[tuple[int, int]]: - conv_dim = divide(intermediate_size, tp_world_size) - conv_state_shape = (conv_kernel - 1, conv_dim) - return (conv_state_shape,) - - @classmethod - def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): - """Compute the increase in group numbers to account for - replication in order to accompany the head shards.""" - - # in the case ngoups % tp_size == 0, this will be zero - if ngroups % tp_size == 0: - return 0 - - # for n_groups == 1, this is exactly tp_size - n_groups - return tp_size - ngroups diff --git a/python/sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py b/python/sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py new file mode 100644 index 000000000..271394c8e --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py @@ -0,0 +1,120 @@ +from typing import Union + +import torch + +from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed.communication_op import ( + tensor_model_parallel_all_gather, + 
tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated +from sglang.srt.model_loader.weight_utils import sharded_weight_loader +from sglang.srt.utils.common import set_weight_attrs + + +class Mixer2RMSNormGated(CustomOp): + def __init__( + self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.use_rms_norm = use_rms_norm + if self.use_rms_norm: + # Register norm weight only if we're actually applying RMSNorm + self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)}) + else: + # Avoid checkpoint mismatch by skipping unused parameter + self.register_parameter("weight", None) + assert ( + self.full_hidden_size % self.tp_size == 0 + ), "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. 
+ input_dtype = x.dtype + x = x * torch.nn.functional.silu(gate.to(torch.float32)) + if not self.use_rms_norm: + return x.to(input_dtype) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = global_sums / count + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + if not self.use_rms_norm: + # Keep gate in float32 for numerical stability during silu + return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype) + + if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1: + return self.forward_native(x, gate) + + return rms_norm_gated( + x=x, + weight=self.weight.data, + bias=None, + z=gate, + eps=self.variance_epsilon, + norm_before_gate=False, + is_rms_norm=True, + ) diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py index 
e618920ce..667d34afa 100644 --- a/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py @@ -15,56 +15,6 @@ import triton import triton.language as tl -# @triton.autotune( -# configs=[ -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, -# num_stages=3, -# num_warps=8, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=2, -# ), -# ], -# key=["chunk_size", "K", "IS_CAUSAL"], -# ) @triton.jit def _bmm_chunk_fwd_kernel( # Pointers to matrices diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py index b44f12089..52b197139 100644 --- a/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py @@ -16,66 +16,6 @@ from packaging import version TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") -# @triton.autotune( -# configs=[ -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, 
"BLOCK_SIZE_K": 64}, -# num_stages=3, -# num_warps=8, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=2, -# ), -# ], -# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], -# ) @triton.jit def _chunk_scan_fwd_kernel( # Pointers to matrices diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py index fc3946763..2dd583800 100644 --- a/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py @@ -17,17 +17,6 @@ import triton.language as tl from .mamba_ssm import softplus -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_SIZE_H": 2}), -# triton.Config({"BLOCK_SIZE_H": 4}), -# triton.Config({"BLOCK_SIZE_H": 8}), -# triton.Config({"BLOCK_SIZE_H": 
16}), -# triton.Config({"BLOCK_SIZE_H": 32}), -# triton.Config({"BLOCK_SIZE_H": 64}), -# ], -# key=["chunk_size", "nheads"], -# ) @triton.jit def _chunk_cumsum_fwd_kernel( # Pointers to matrices @@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel( ) -# @triton.autotune( -# configs=[ -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, -# num_stages=3, -# num_warps=8, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=2, -# ), -# ], -# key=["hdim", "dstate", "chunk_size"], -# ) @triton.jit def _chunk_state_fwd_kernel( # Pointers to matrices @@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel( tl.store(states_ptrs, states, mask=c_mask) -# @triton.autotune( -# configs=[ -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, -# num_stages=3, -# num_warps=8, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# 
), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=4, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=5, -# num_warps=2, -# ), -# triton.Config( -# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, -# num_stages=4, -# num_warps=2, -# ), -# ], -# key=["hdim", "dstate", "chunk_size"], -# ) @triton.jit def _chunk_state_varlen_kernel( # Pointers to matrices diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py index f0a8e0f6b..5e8c32385 100644 --- a/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py @@ -13,17 +13,6 @@ import triton import triton.language as tl -# @triton.autotune( -# configs=[ -# triton.Config({"BLOCK_SIZE": 64}), -# triton.Config({"BLOCK_SIZE": 128}), -# triton.Config({"BLOCK_SIZE": 256}), -# triton.Config({"BLOCK_SIZE": 512}), -# triton.Config({"BLOCK_SIZE": 1024}), -# triton.Config({"BLOCK_SIZE": 2048}), -# ], -# key=["dim"], -# ) @triton.jit def _state_passing_fwd_kernel( # Pointers to matrices diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 70e99c31f..1ef9274e5 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend): self.num_kv_head = 
model_runner.model_config.get_num_kv_heads( get_attention_tp_size() ) - if model_runner.is_hybrid_gdn: + if model_runner.hybrid_gdn_config is not None: # For hybrid linear models, layer_id = 0 may not be full attention self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim() else: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index eedb28e79..91c3c01b8 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1770,7 +1770,7 @@ class Scheduler( chunked_req_to_exclude.add(self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True) # chunked request keeps its rid but will get a new req_pool_idx - if self.tp_worker.worker.model_runner.is_hybrid_gdn: + if self.tp_worker.worker.model_runner.mambaish_config is not None: self.req_to_token_pool.free( self.chunked_req.req_pool_idx, free_mamba_cache=False ) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 69d07d41b..b577646a0 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -15,6 +15,9 @@ limitations under the License. 
from __future__ import annotations +from dataclasses import dataclass + +from sglang.srt.configs.mamba_utils import Mamba2CacheParams from sglang.srt.layers.attention.nsa import index_buf_accessor from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter @@ -109,17 +112,38 @@ class ReqToTokenPool: class MambaPool: + @dataclass(frozen=True, kw_only=True) + class State: + conv: torch.Tensor + temporal: torch.Tensor + + def at_layer_idx(self, layer: int): + return type(self)(**{k: v[layer] for k, v in vars(self).items()}) + + def mem_usage_bytes(self): + return sum(get_tensor_size_bytes(t) for t in vars(self).values()) + + @dataclass(frozen=True, kw_only=True) + class SpeculativeState(State): + intermediate_ssm: torch.Tensor + intermediate_conv_window: torch.Tensor + def __init__( self, + *, size: int, - conv_dtype: torch.dtype, - ssm_dtype: torch.dtype, - num_mamba_layers: int, - conv_state_shape: Tuple[int, int], - temporal_state_shape: Tuple[int, int], + cache_params: "Mamba2CacheParams", device: str, speculative_num_draft_tokens: Optional[int] = None, ): + conv_state_shape = cache_params.shape.conv + temporal_state_shape = cache_params.shape.temporal + conv_dtype = cache_params.dtype.conv + ssm_dtype = cache_params.dtype.temporal + num_mamba_layers = len(cache_params.layers) + + # assume conv_state = (dim, state_len) + assert conv_state_shape[0] > conv_state_shape[1] conv_state = torch.zeros( size=(num_mamba_layers, size + 1) + conv_state_shape, dtype=conv_dtype, @@ -158,11 +182,11 @@ class MambaPool: dtype=conv_dtype, device="cuda", ) - self.mamba_cache = ( - conv_state, - temporal_state, - intermediate_ssm_state_cache, - intermediate_conv_window_cache, + self.mamba_cache = self.SpeculativeState( + conv=conv_state, + temporal=temporal_state, + intermediate_ssm=intermediate_ssm_state_cache, + intermediate_conv_window=intermediate_conv_window_cache, ) logger.info( f"Mamba 
Cache is allocated. " @@ -172,7 +196,7 @@ class MambaPool: f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB " ) else: - self.mamba_cache = (conv_state, temporal_state) + self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state) logger.info( f"Mamba Cache is allocated. " f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, " @@ -180,16 +204,14 @@ class MambaPool: ) self.size = size self.free_slots = list(range(size)) - self.mem_usage = self.get_mamba_size() / GB + self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB - def get_mamba_params_all_layers(self): - return [self.mamba_cache[i] for i in range(len(self.mamba_cache))] + def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: + assert isinstance(self.mamba_cache, self.SpeculativeState) + return self.mamba_cache - def get_mamba_params(self, layer_id: int): - return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))] - - def get_mamba_size(self): - return sum(get_tensor_size_bytes(t) for t in self.mamba_cache) + def mamba2_layer_cache(self, layer_id: int): + return self.mamba_cache.at_layer_idx(layer_id) def available_size(self): return len(self.free_slots) @@ -208,7 +230,9 @@ class MambaPool: self.free_slots.append(free_index) else: self.free_slots.extend(free_index) - self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0 + self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[ + :, free_index + ] = 0 def clear(self): self.free_slots = list(range(self.size)) @@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool): def __init__( self, + *, size: int, max_context_len: int, device: str, enable_memory_saver: bool, - conv_dtype: torch.dtype, - ssm_dtype: torch.dtype, - mamba_layers: List[int], - conv_state_shape: Tuple[int, int], - temporal_state_shape: Tuple[int, int], - speculative_num_draft_tokens: int, + cache_params: "Mamba2CacheParams", + 
speculative_num_draft_tokens: Optional[int] = None, ): super().__init__( size=size, @@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool): ) self.mamba_pool = MambaPool( - size, - conv_dtype, - ssm_dtype, - len(mamba_layers), - conv_state_shape, - temporal_state_shape, - device, - speculative_num_draft_tokens, + size=size, + cache_params=cache_params, + device=device, + speculative_num_draft_tokens=speculative_num_draft_tokens, ) - self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)} + self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)} self.device = device self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros( @@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool): def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor: return self.req_index_to_mamba_index_mapping[req_indices] - def get_mamba_params(self, layer_id: int): + def mamba2_layer_cache(self, layer_id: int): assert layer_id in self.mamba_map - return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id]) + return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id]) - def get_mamba_params_all_layers(self): - return self.mamba_pool.get_mamba_params_all_layers() + def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState: + return self.mamba_pool.get_speculative_mamba2_params_all_layers() # For chunk prefill, we can not free mamba cache, we need use it in the future def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b3d2d1e67..73e6ccc7f 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union import torch import torch.distributed as dist +from sglang.srt.configs import FalconH1Config, NemotronHConfig, 
Qwen3NextConfig from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.model_config import ( @@ -354,8 +355,9 @@ class ModelRunner: if architectures and not any("Llama4" in arch for arch in architectures): self.is_hybrid = self.model_config.is_hybrid = True - if self.is_hybrid_gdn: - logger.warning("Hybrid GDN model detected, disable radix cache") + if config := self.mambaish_config: + class_name = config.__class__.__name__ + logger.warning(f"{class_name} model detected, disable radix cache") self.server_args.disable_radix_cache = True if self.server_args.max_mamba_cache_size is None: if self.server_args.max_running_requests is not None: @@ -364,6 +366,7 @@ class ModelRunner: ) else: self.server_args.max_mamba_cache_size = 512 + if self.hybrid_gdn_config is not None: self.server_args.max_mamba_cache_size = ( self.server_args.max_mamba_cache_size // ( @@ -1267,8 +1270,8 @@ class ModelRunner: "num_nextn_predict_layers", self.num_effective_layers, ) - elif self.is_hybrid_gdn: - num_layers = len(self.model_config.hf_config.full_attention_layer_ids) + elif config := self.mambaish_config: + num_layers = len(config.full_attention_layer_ids) else: num_layers = self.num_effective_layers if self.use_mla_backend: @@ -1288,22 +1291,32 @@ class ModelRunner: rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) - if self.is_hybrid_gdn: + if config := self.mambaish_config: rest_memory -= ( self.server_args.max_mamba_cache_size - * self.model_config.hf_config.mamba_cache_per_req + * config.mamba2_cache_params.mamba_cache_per_req / (1 << 30) ) max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token @property - def is_hybrid_gdn(self): - return self.model_config.hf_config.architectures[0] in [ - "Qwen3NextForCausalLM", - "Qwen3NextForCausalLMMTP", - "FalconH1ForCausalLM", - ] + def hybrid_gdn_config(self): + config = 
self.model_config.hf_config + if isinstance(config, Qwen3NextConfig): + return config + return None + + @property + def mamba2_config(self): + config = self.model_config.hf_config + if isinstance(config, FalconH1Config | NemotronHConfig): + return config + return None + + @property + def mambaish_config(self): + return self.mamba2_config or self.hybrid_gdn_config def set_num_token_hybrid(self): if ( @@ -1438,7 +1451,7 @@ class ModelRunner: ), 4096, ) - if self.is_hybrid_gdn: + if self.mambaish_config is not None: max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size) if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): @@ -1519,26 +1532,14 @@ class ModelRunner: enable_memory_saver=self.server_args.enable_memory_saver, pre_alloc_size=pre_alloc_size, ) - elif self.is_hybrid_gdn: - config = self.model_config.hf_config - ( - conv_state_shape, - temporal_state_shape, - conv_dtype, - ssm_dtype, - mamba_layers, - ) = config.hybrid_gdn_params + elif config := self.mambaish_config: self.req_to_token_pool = HybridReqToTokenPool( size=max_num_reqs, max_context_len=self.model_config.context_len + extra_max_context_len, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, - conv_state_shape=conv_state_shape, - temporal_state_shape=temporal_state_shape, - conv_dtype=conv_dtype, - ssm_dtype=ssm_dtype, - mamba_layers=mamba_layers, + cache_params=config.mamba2_cache_params, speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, ) else: @@ -1640,7 +1641,7 @@ class ModelRunner: enable_kvcache_transpose=False, device=self.device, ) - elif self.is_hybrid_gdn: + elif config := self.mambaish_config: self.token_to_kv_pool = HybridLinearKVPool( page_size=self.page_size, size=self.max_total_num_tokens, @@ -1651,9 +1652,7 @@ class ModelRunner: head_dim=self.model_config.head_dim, # if draft worker, we only need 1 attention layer's kv pool full_attention_layer_ids=( - [0] - if self.is_draft_worker - else 
self.model_config.hf_config.full_attention_layer_ids + [0] if self.is_draft_worker else config.full_attention_layer_ids ), enable_kvcache_transpose=False, device=self.device, @@ -1681,7 +1680,8 @@ class ModelRunner: need_sort = self.server_args.disaggregation_mode in ("decode", "prefill") if self.token_to_kv_pool_allocator is None: if _is_npu and ( - self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn + self.server_args.attention_backend == "ascend" + or self.hybrid_gdn_config is not None ): self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( self.max_total_num_tokens, diff --git a/python/sglang/srt/models/falcon_h1.py b/python/sglang/srt/models/falcon_h1.py index a035e0291..f05a395d9 100644 --- a/python/sglang/srt/models/falcon_h1.py +++ b/python/sglang/srt/models/falcon_h1.py @@ -8,6 +8,10 @@ from torch import nn from sglang.srt.configs.falcon_h1 import FalconH1Config from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( + HybridLinearAttnBackend, + Mamba2AttnBackend, +) from sglang.srt.layers.attention.mamba.mamba import MambaMixer2 from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes from sglang.srt.layers.dp_attention import ( @@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module): ) self.mamba = MambaMixer2( + cache_params=config.mamba2_cache_params, hidden_size=config.hidden_size, - ssm_state_size=config.mamba_d_state, - conv_kernel_size=config.mamba_d_conv, - intermediate_size=self.d_ssm, use_conv_bias=config.mamba_conv_bias, use_bias=config.mamba_proj_bias, n_groups=config.mamba_n_groups, - num_heads=config.mamba_n_heads, - layer_id=layer_id, - head_dim=config.mamba_d_head, rms_norm_eps=config.rms_norm_eps, - chunk_size=config.mamba_chunk_size, activation=config.hidden_act, use_rms_norm=config.mamba_rms_norm, 
prefix=f"{prefix}.mixer", @@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module): ) attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + attn_backend = forward_batch.attn_backend + assert isinstance(attn_backend, HybridLinearAttnBackend) + assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend) # Mamba block mamba_hidden_states = torch.empty_like(hidden_states) - self.mamba( + attn_backend.linear_attn_backend.forward( + self.mamba, hidden_states * self.ssm_in_multiplier, mamba_hidden_states, - forward_batch=forward_batch, + layer_id=self.layer_id, mup_vector=self.mup_vector, ) mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier diff --git a/python/sglang/srt/models/nemotron_h.py b/python/sglang/srt/models/nemotron_h.py new file mode 100644 index 000000000..9f0126c3f --- /dev/null +++ b/python/sglang/srt/models/nemotron_h.py @@ -0,0 +1,514 @@ +# Copyright 2023-2025 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py + +"""Inference-only NemotronH model.""" + +from collections.abc import Iterable +from typing import Optional, Union + +import torch +from torch import nn + +from sglang.srt.configs import NemotronHConfig +from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import ReLU2 +from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( + HybridLinearAttnBackend, + Mamba2AttnBackend, +) +from sglang.srt.layers.attention.mamba.mamba import MambaMixer2 +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) +from sglang.srt.utils import add_prefix, make_layers_non_pp +from sglang.utils import logger + + +class NemotronHMLP(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + + hybrid_override_pattern = config.hybrid_override_pattern + mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1 + if isinstance(config.intermediate_size, list): + if len(config.intermediate_size) 
== 1: + intermediate_size = config.intermediate_size[0] + else: + intermediate_size = config.intermediate_size[mlp_index] + else: + intermediate_size = config.intermediate_size + + self.up_proj = ColumnParallelLinear( + input_size=config.hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=config.hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = ReLU2() + + def forward(self, x: torch.Tensor): + x, _ = self.up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +class NemotronHMLPDecoderLayer(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + self.mixer = NemotronHMLP( + config, + quant_config=quant_config, + bias=config.mlp_bias, + prefix=f"{prefix}.mixer", + layer_idx=layer_idx, + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + *, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer.forward(hidden_states) + return hidden_states, residual + + +class NemotronHMambaDecoderLayer(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.layer_id = layer_idx + self.mixer = MambaMixer2( + cache_params=config.mamba2_cache_params, + hidden_size=config.hidden_size, + use_conv_bias=config.use_conv_bias, + 
use_bias=config.use_bias, + n_groups=config.mamba_n_groups, + rms_norm_eps=config.rms_norm_eps, + activation=config.mamba_hidden_act, + quant_config=quant_config, + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + *, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + output = torch.empty_like(hidden_states) + attn_backend = forward_batch.attn_backend + assert isinstance(attn_backend, HybridLinearAttnBackend) + assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend) + attn_backend.linear_attn_backend.forward( + mixer=self.mixer, + layer_id=self.layer_id, + hidden_states=hidden_states, + output=output, + use_triton_causal_conv=True, # TODO: investigate need of `use_triton_causal_conv` + ) + return output, residual + + +class NemotronHAttention(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + if hasattr(config, "head_dim") and config.head_dim is not None: + self.head_dim = config.head_dim + else: + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_idx, + quant_config=quant_config, + prefix=add_prefix("attn", prefix), + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn.forward(q, k, v, forward_batch) + output, _ = self.o_proj(attn_output) + return output + + +class NemotronHAttentionDecoderLayer(nn.Module): + def __init__( + self, + config: NemotronHConfig, + layer_idx: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.mixer = NemotronHAttention( + config, + layer_idx, + quant_config, + prefix=f"{prefix}.mixer", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + *, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + ) -> tuple[torch.Tensor, torch.Tensor]: + if residual is None: + residual = hidden_states + hidden_states = 
self.norm(hidden_states) + else: + hidden_states, residual = self.norm(hidden_states, residual) + + hidden_states = self.mixer.forward( + hidden_states=hidden_states, forward_batch=forward_batch + ) + return hidden_states, residual + + +Layers = ( + NemotronHAttentionDecoderLayer + | NemotronHMLPDecoderLayer + | NemotronHMambaDecoderLayer +) +ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = { + ATTENTION: NemotronHAttentionDecoderLayer, + MLP: NemotronHMLPDecoderLayer, + MAMBA: NemotronHMambaDecoderLayer, +} + + +class NemotronHModel(nn.Module): + def __init__( + self, + *, + config: NemotronHConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + lora_config = None + self.config = config + lora_vocab = ( + (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1)) + if lora_config + else 0 + ) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + def get_layer(idx: int, prefix: str): + layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]] + return layer_class(config, idx, quant_config=quant_config, prefix=prefix) + + self.layers = make_layers_non_pp( + len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers" + ) + self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = 
self.get_input_embeddings(input_ids) + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + # NOTE(review): removed stray `residual = None` here — it clobbered the residual received via pp_proxy_tensors on non-first PP ranks; the first rank already sets residual = None above. + for layer in self.layers: + if not isinstance(layer, Layers): + raise ValueError(f"Unknown layer type: {type(layer)}") + hidden_states, residual = layer.forward( + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + if not get_pp_group().is_last_rank: + return PPProxyTensors( + {"hidden_states": hidden_states, "residual": residual} + ) + hidden_states, _ = self.norm_f(hidden_states, residual) + return hidden_states + + +class NemotronHForCausalLM(nn.Module): + remap_prefix = {"backbone": "model"} + remap_substr = {"A_log": "A", "embeddings": "embed_tokens"} + + # LoRA specific attributes + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + *, + config: NemotronHConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + lora_config = None + self.config = config + self.model = self._init_model( + config=config, quant_config=quant_config, prefix=prefix + ) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config + else lora_config.lora_vocab_padding_size + ), + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + ) + self.logits_processor = LogitsProcessor(config) + + def _init_model( + self, + config: NemotronHConfig, + 
quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: Optional[torch.Tensor] = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ): + hidden_states = self.model.forward( + input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds + ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs): + return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs) + + def get_seqlen_agnostic_capture_inputs(self, batch_size: int): + return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + updated_weights = [] + for name, loaded_weight in weights: + for prefix, new_key in self.remap_prefix.items(): + if name.startswith(prefix): + name = name.replace(prefix, new_key) + for substr, new_key in self.remap_substr.items(): + if substr in name: + name = name.replace(substr, new_key) + updated_weights.append((name, loaded_weight)) + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in updated_weights: + if "scale" in name: + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for 
GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if name in params_dict.keys(): + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + else: + logger.warning(f"Parameter {name} not found in params_dict") + + +EntryClass = [NemotronHForCausalLM] diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index d23f9e3b0..430e38eb4 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -866,7 +866,7 @@ class EAGLEWorker(TpModelWorker): logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] # QQ: can be optimized - if self.target_worker.model_runner.is_hybrid_gdn: + if self.target_worker.model_runner.hybrid_gdn_config is not None: # res.draft_input.accept_length is on GPU but may be empty for last verify? 
accepted_length = ( torch.tensor( diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 0ab2783c3..0e6828c7e 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -518,6 +518,24 @@ def make_layers( return modules, start_layer, end_layer +def make_layers_non_pp( + num_hidden_layers: int, + layer_fn: LayerFn, + prefix: str = "", +) -> torch.nn.ModuleList: + from sglang.srt.offloader import get_offloader + + layers = torch.nn.ModuleList( + get_offloader().wrap_modules( + ( + layer_fn(idx=idx, prefix=add_prefix(idx, prefix)) + for idx in range(num_hidden_layers) + ) + ) + ) + return layers + + cmo_stream = None diff --git a/python/sglang/srt/utils/hf_transformers_utils.py b/python/sglang/srt/utils/hf_transformers_utils.py index 75e0a8b75..c967d8683 100644 --- a/python/sglang/srt/utils/hf_transformers_utils.py +++ b/python/sglang/srt/utils/hf_transformers_utils.py @@ -45,6 +45,7 @@ from sglang.srt.configs import ( KimiVLConfig, LongcatFlashConfig, MultiModalityConfig, + NemotronHConfig, Qwen3NextConfig, Step3VLConfig, ) @@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { FalconH1Config.model_type: FalconH1Config, DotsVLMConfig.model_type: DotsVLMConfig, DotsOCRConfig.model_type: DotsOCRConfig, + NemotronHConfig.model_type: NemotronHConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/test/srt/layers/attention/mamba/test_causal_conv1d.py b/test/srt/layers/attention/mamba/test_causal_conv1d.py new file mode 100644 index 000000000..c56b96b4f --- /dev/null +++ b/test/srt/layers/attention/mamba/test_causal_conv1d.py @@ -0,0 +1,375 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/mamba/test_causal_conv1d.py + + +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange + +from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( + PAD_SLOT_ID, + causal_conv1d_fn, + 
causal_conv1d_update, +) + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref( + x, conv_state, weight, bias=None, activation=None, cache_seqlens=None +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. 
+ + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to( + weight.dtype + ) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange( + -(width - 1), 0, dtype=torch.long, device=x.device + ).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = ( + torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + ) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze( + 0 + ) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[ + :, :, -seqlen: + ] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) +def causal_conv1d_opcheck_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + 
seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("seqlen", [1]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.manual_seed(0) + batch = 2 + x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + x_ref = x.clone() + conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) + + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + conv_state_ref = conv_state.detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation) + out_ref = causal_conv1d_update_ref( + x_ref, conv_state_ref, weight, bias, activation=activation + ) + + assert torch.equal(conv_state, conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, 
True]) +@pytest.mark.parametrize("seqlen", [1, 3]) +@pytest.mark.parametrize("width", [3, 4]) +@pytest.mark.parametrize("dim", [2048 + 16, 4096]) +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch_size", [3]) +def test_causal_conv1d_update_with_batch_gather( + batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype +): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + + # set seed + torch.manual_seed(0) + + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + # total_entries = number of cache line + total_entries = 10 * batch_size + + # x will be (batch, dim, seqlen) with contiguous along dim-axis + x = torch.randn( + padded_batch_size, seqlen, dim, device=device, dtype=itype + ).transpose(1, 2) + + x_ref = x.clone() + + conv_state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) + unused_states_bool[conv_state_indices] = False + padded_state_indices = torch.concat( + [ + conv_state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=0, + ) + + # conv_state will be (cache_lines, dim, state_len) + # with contiguous along dim-axis + conv_state = torch.randn( + total_entries, width - 1, dim, device=device, dtype=itype + ).transpose(1, 2) + + conv_state_for_padding_test = conv_state.clone() + + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + conv_state_ref = conv_state[conv_state_indices, :].detach().clone() + activation = None if not silu_activation else "silu" + + out = 
causal_conv1d_update( + x, + conv_state, + weight, + bias, + activation=activation, + conv_state_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + ) + out_ref = causal_conv1d_update_ref( + x_ref[:batch_size], conv_state_ref, weight, bias, activation=activation + ) + + assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) + assert torch.equal( + conv_state[unused_states_bool], conv_state_for_padding_test[unused_states_bool] + ) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize("seqlen", [8, 30, 249, 2049, 4096]) +@pytest.mark.parametrize("dim", [64, 4096]) +@pytest.mark.parametrize("with_padding", [True, False]) +@pytest.mark.parametrize("batch", [4, 10]) +def test_causal_conv1d_varlen( + batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype +): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + device = "cuda" + torch.cuda.empty_cache() + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.manual_seed(0) + seqlens = [] + batch_size = batch + padding = 3 if with_padding else 0 + padded_batch_size = batch_size + padding + nsplits = padded_batch_size - 1 + + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + + seqlens.append( + torch.diff( + torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])]) + ).tolist() + ) + assert sum(seqlens[-1]) == seqlen + assert all(s > 0 for s in seqlens[-1]) + + total_entries = batch_size * 10 + cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], dim=0) + x = rearrange( + torch.randn(1, seqlen, 4096 + dim + 64, 
device=device, dtype=itype), + "b s d -> b d s", + )[:, 4096 : 4096 + dim, :] + + weight = torch.randn(dim, width, device=device, dtype=itype) + + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + x_ref = x.clone() + weight_ref = weight.clone() + bias_ref = bias.clone() if bias is not None else None + activation = None if not silu_activation else "silu" + final_states = torch.randn( + total_entries, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + final_states_ref = final_states.clone() + has_initial_states = torch.randint( + 0, 2, (cumsum.shape[0] - 1,), dtype=torch.bool, device=x.device + ) + state_indices = torch.randperm(total_entries, dtype=torch.int32, device=x.device)[ + :batch_size + ] + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + dim=-1, + ) + out = causal_conv1d_fn( + x.squeeze(0), + weight, + bias=bias, + conv_states=final_states, + query_start_loc=cumsum.cuda(), + seq_lens_cpu=torch.tensor(seqlens[0]), + cache_indices=padded_state_indices, + has_initial_state=has_initial_states, + activation=activation, + pad_slot_id=PAD_SLOT_ID, + ) + + out_ref = [] + out_ref_b = [] + + splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] + for i in range(len(seqlens[0])): + x_s = [v[i].unsqueeze(0) for v in splits][0] + if padded_state_indices[i] == PAD_SLOT_ID: + continue + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight_ref, + bias_ref, + activation=activation, + return_final_states=True, + final_states_out=final_states_ref[padded_state_indices[i]].unsqueeze(0), + initial_states=( + final_states_ref[padded_state_indices[i]].unsqueeze(0) + if has_initial_states[i] + else None + ), + ) + ) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) + out_ref_tensor = torch.cat(out_ref, dim=0) + + assert torch.allclose( + final_states[state_indices], + final_states_ref[state_indices], + rtol=rtol, + 
atol=atol, + ) + unpadded_out = out[:, : out_ref_tensor.shape[-1]] + assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) diff --git a/test/srt/layers/attention/mamba/test_mamba2_mixer.py b/test/srt/layers/attention/mamba/test_mamba2_mixer.py new file mode 100644 index 000000000..aae477db5 --- /dev/null +++ b/test/srt/layers/attention/mamba/test_mamba2_mixer.py @@ -0,0 +1,138 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/tests/kernels/mamba/test_mamba_mixer2.py + +from unittest.mock import patch + +import pytest +import torch + +from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import ( + update_environment_variables, +) +from sglang.srt.distributed.parallel_state import ( + init_distributed_environment, + initialize_model_parallel, +) + +NUM_GPUS = 2 + + +@pytest.mark.parametrize("batch_size", [8]) +@pytest.mark.parametrize("seq_len", [128]) +@pytest.mark.parametrize( + "hidden_size_n_groups", + [ + (64, 1), # hidden_size be divisible by num_gpus + (100, 4), # and n_groups must divide hidden_size + ], +) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_mixer2_gated_norm_multi_gpu( + batch_size: int, + seq_len: int, + hidden_size_n_groups: tuple[int, int], + dtype: torch.dtype, + device: str = "cuda", +): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + assert torch.cuda.device_count() == NUM_GPUS + + hidden_size, n_groups = hidden_size_n_groups + num_processes = NUM_GPUS + + def run_torch_spawn(fn, nprocs): + # need to use torch.mp.spawn otherwise will have problems with + # torch.distributed and cuda + torch.multiprocessing.spawn( + fn, + args=( + num_processes, + batch_size, + seq_len, + hidden_size, + n_groups, + dtype, + device, + ), + nprocs=nprocs, + ) + + run_torch_spawn(mixer2_gated_norm_tensor_parallel, NUM_GPUS) + + +def mixer2_gated_norm_tensor_parallel( + local_rank: int, + world_size: int, + batch_size: 
int, + seq_len: int, + hidden_size: int, + n_groups: int, + dtype: torch.dtype, + device: str, +): + torch.manual_seed(0) + + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(device) + torch.set_default_device(device) + torch.set_default_dtype(dtype) + + update_environment_variables( + { + "RANK": str(local_rank), + "LOCAL_RANK": str(local_rank), + "WORLD_SIZE": str(world_size), + "MASTER_ADDR": "localhost", + "MASTER_PORT": "12345", + } + ) + + # initialize distributed + init_distributed_environment( + world_size=world_size, rank=local_rank, local_rank=local_rank + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + + # create random weights an inputs + weight = torch.rand((hidden_size,), dtype=dtype, device=device) + hidden_states = torch.randn(batch_size, seq_len, hidden_size) + gate_states = torch.randn(batch_size, seq_len, hidden_size) + + import sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated as m2 + import sglang.srt.model_loader.weight_utils as wu + + # Convenience: Avoid calling initialize_dp_attention + with patch.object(wu, "get_attention_tp_rank", return_value=local_rank): + # create gated-norm with TP + mixer = m2.Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + mixer.weight.weight_loader(mixer.weight, weight) + + with ( + patch.object(m2, "get_tensor_model_parallel_world_size", return_value=1), + patch.object(m2, "get_tensor_model_parallel_rank", return_value=0), + ): + # create gated-norm without TP to compute reference + mixer_single_gpu = m2.Mixer2RMSNormGated( + full_hidden_size=hidden_size, + full_n_groups=n_groups, + ) + # assign weight to single-gpu mixer + mixer_single_gpu.weight.data = weight + + # generate and compare + N = hidden_size // world_size + output = mixer( + hidden_states[..., local_rank * N : (local_rank + 1) * N], + gate_states[..., local_rank * N : (local_rank + 1) * N], + ) + ref_output = mixer_single_gpu(hidden_states, gate_states) + 
torch.testing.assert_close( + output, + ref_output[..., local_rank * N : (local_rank + 1) * N], + atol=5e-3, + rtol=1e-3, + ) diff --git a/test/srt/layers/attention/mamba/test_mamba_ssm.py b/test/srt/layers/attention/mamba/test_mamba_ssm.py new file mode 100644 index 000000000..3e983a00e --- /dev/null +++ b/test/srt/layers/attention/mamba/test_mamba_ssm.py @@ -0,0 +1,291 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from sglang.srt.layers.attention.mamba.causal_conv1d_triton import PAD_SLOT_ID +from sglang.srt.layers.attention.mamba.ops import selective_state_update + + +def selective_state_update_ref( + state, x, dt, A, B, C, D=None, z=None, dt_bias=None, dt_softplus=False +): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, 
dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp( + rearrange(dt, "b h d -> b h d 1") * A + ) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n" + ) # (batch, nheads, dim, dstate) + state.copy_( + state * dA + dB * rearrange(x, "b h d -> b h d 1") + ) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update(dim, dstate, has_z, itype): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + device = "cuda" + + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.manual_seed(0) + batch_size = 1 + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=itype) + out = torch.empty_like(x) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, 
device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state.detach().clone() + selective_state_update( + state, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True, out=out + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) + + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +# tests correctness in case subset of the sequences are padded +@pytest.mark.parametrize("with_padding", [True, False]) +def test_selective_state_update_with_batch_indices( + with_padding, dim, dstate, has_z, itype +): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 3 + padding = 5 if with_padding else 0 + padded_batch_size = batch_size + padding + total_entries = 10 * batch_size + state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device + ) + unused_states_bool = torch.ones(total_entries, dtype=torch.bool, device=device) + unused_states_bool[state_indices] = False + padded_state_indices = torch.concat( + [ + state_indices, + torch.as_tensor([PAD_SLOT_ID] * padding, dtype=torch.int32, device=device), + ], + 
dim=0, + ) + x = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + out = torch.empty_like(x) + dt = torch.randn(padded_batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(padded_batch_size, dstate, device=device) + C = torch.randn(padded_batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].clone() + state_before = state.clone() + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=padded_state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out, + ) + out_ref = selective_state_update_ref( + state_ref, + x[:batch_size], + dt[:batch_size], + A, + B[:batch_size], + C[:batch_size], + D=D, + z=z[:batch_size], + dt_bias=dt_bias, + dt_softplus=True, + ) + + print("Output diff max", (out[:batch_size] - out_ref).max()) + print("Output diff mean", (out[:batch_size] - out_ref).mean()) + print("Output state diff max", (state[state_indices, :] - state_ref).max()) + print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) + # test padded entries stay the same + if with_padding: + assert torch.equal(state_before[unused_states_bool], state[unused_states_bool]) + assert torch.equal(x[batch_size + 1 :], x[batch_size + 1 :]) + assert torch.equal(dt[batch_size + 1 :], dt[batch_size + 1 :]) + assert torch.equal(B[batch_size + 1 :], B[batch_size + 1 :]) + assert torch.equal(C[batch_size + 1 :], C[batch_size + 1 :]) + + # test "real" entries + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("tie_hdim", 
[False, True]) +@pytest.mark.parametrize("ngroups", [1, 2, 4]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +def test_selective_state_update_with_heads_with_batch_indices( + dim, dstate, ngroups, has_z, tie_hdim, itype +): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-1, 1e-1 + # set seed + torch.random.manual_seed(0) + batch_size = 3 + headdim = 64 + nheads = dim // headdim + + total_entries = 10 * batch_size + state = torch.randn( + total_entries, nheads, headdim, dstate, dtype=itype, device=device + ) + state_indices = torch.randperm(total_entries)[:batch_size].to( + dtype=torch.int32, device=device + ) + + x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) + out = torch.empty_like(x) + if not tie_hdim: + dt = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype) + dt_bias = torch.rand(nheads, headdim, device=device) - 4.0 + A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0 + D = torch.randn(nheads, headdim, device=device) + else: + dt = repeat( + torch.randn(batch_size, nheads, device=device, dtype=itype), + "b h -> b h p", + p=headdim, + ) + dt_bias = repeat(torch.rand(nheads, device=device) - 4.0, "h -> h p", p=headdim) + A = repeat( + -torch.rand(nheads, device=device) - 1.0, "h -> h p n", p=headdim, n=dstate + ) + D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim) + B = torch.randn(batch_size, ngroups, dstate, device=device) + C = torch.randn(batch_size, ngroups, dstate, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state[state_indices, :].detach().clone() + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices, + pad_slot_id=PAD_SLOT_ID, + out=out, 
+ ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py b/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py new file mode 100644 index 000000000..493a179ee --- /dev/null +++ b/test/srt/layers/attention/mamba/test_mamba_ssm_ssd.py @@ -0,0 +1,581 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/633f943e30a4444d890d26b81850f7217736f840/tests/kernels/mamba/test_mamba_ssm_ssd.py + + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata +from sglang.srt.layers.attention.mamba.ops import mamba_chunk_scan_combined + +# Added by the IBM Team, 2024 + +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py + +# TODO: These take a long time to run - we should cut down on some of the parameterized matrix. + + +# this is the segsum implementation taken from above +def segsum(x): + """Calculates segment sum.""" + T = x.size(-1) + x = repeat(x, "... d -> ... 
d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = ( + rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C) + ) + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at + # chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. 
Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms + # (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +def generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype, device="cuda"): + + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + torch.manual_seed(0) + A = -torch.exp(torch.rand(n_heads, dtype=itype, device=device)) + dt = F.softplus( + torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) - 4 + ) + X = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + B = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + C = torch.randn((batch_size, seqlen, n_heads, d_head), dtype=itype, device=device) + + return A, dt, X, B, C + + +def generate_continuous_batched_examples( + example_lens_by_batch, + num_examples, + full_length, + last_taken, + exhausted, + n_heads, + d_head, + itype, + device="cuda", + return_naive_ref=True, +): + + # this function generates a random examples of certain length + # and then cut according to "example_lens_by_batch" and feed + # them in continuous batches to the kernels. + # If if return_naive_ref=True, the naive torch implementation + # ssd_minimal_discrete will be used to compute and return + # reference output. 
+ + # generate the full-length example + A, dt, X, B, C = generate_random_inputs( + num_examples, full_length, n_heads, d_head, itype + ) + + if return_naive_ref: + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, block_len=full_length // 4 + ) + + # internal function that outputs a cont batch of examples + # given a tuple of lengths for each example in the batch + # e.g., example_lens=(8, 4) means take 8 samples from first eg, + # 4 examples from second eg, etc + def get_continuous_batch(example_lens: tuple[int, ...]): + + indices = [] + for i, x in enumerate(example_lens): + c = last_taken.get(i, 0) + indices.append((c, c + x)) + last_taken[i] = (c + x) % full_length + exhausted[i] = last_taken[i] == 0 + + return ( + torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)]).unsqueeze(0) + for x in (dt, X, B, C) + ) + + # internal function that maps "n" to the appropriate right boundary + # value when forming continuous batches from examples of length given + # by "full_length". 
+ # - e.g., when n > full_length, returns n % full_length + # when n == full_length, returns full_length + def end_boundary(n: int): + return n - ((n - 1) // full_length) * full_length + + IND_E = None + for spec in example_lens_by_batch: + + # get the (maybe partial) example seen in this cont batch + dt2, X2, B2, C2 = get_continuous_batch(spec) + + # get the metadata + cu_seqlens = torch.tensor((0,) + spec, device=device).cumsum(dim=0) + seq_idx = torch.zeros( + cu_seqlens[-1], dtype=torch.int32, device=cu_seqlens.device + ) + for i, (srt, end) in enumerate( + zip( + cu_seqlens, + cu_seqlens[1:], + ) + ): + seq_idx[srt:end] = i + + # for cont batch + if IND_E is None: + IND_S = [0 for _ in range(len(spec))] + else: + IND_S = [x % full_length for x in IND_E] + IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)] + + yield ( + ( + [Y_min[s, IND_S[s] : IND_E[s]] for s in range(num_examples)] + if return_naive_ref + else None + ), + cu_seqlens, + seq_idx.unsqueeze(0), + (A, dt2, X2, B2, C2), + ) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32]) +@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128]) +@pytest.mark.parametrize("seq_len_chunk_size", [(112, 16), (128, 32)]) +def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size, itype): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + # this tests the kernels on a single example (no batching) + + # TODO: the bfloat16 case requires higher thresholds. To be investigated + + if itype == torch.bfloat16: + atol, rtol = 5e-2, 5e-2 + else: + atol, rtol = 8e-3, 5e-3 + + # set seed + batch_size = 1 # batch_size + # ssd_minimal_discrete requires chunk_size divide seqlen + # - this is only required for generating the reference seqs, + # it is not an operational limitation. 
+ seqlen, chunk_size = seq_len_chunk_size + + A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads, d_head, itype) + + Y_min, final_state_min = ssd_minimal_discrete( + X * dt.unsqueeze(-1), A * dt, B, C, chunk_size + ) + Y = torch.empty_like(X) + final_state = mamba_chunk_scan_combined( + X, dt, A, B, C, chunk_size, D=None, return_final_states=True, out=Y + ) + + # just test the last in sequence + torch.testing.assert_close(Y[:, -1], Y_min[:, -1], atol=atol, rtol=rtol) + + # just test the last head + # NOTE, in the kernel we always cast states to fp32 + torch.testing.assert_close( + final_state[:, -1], + final_state_min[:, -1].to(torch.float32), + atol=atol, + rtol=rtol, + ) + + +@pytest.mark.parametrize("itype", [torch.float32, torch.float16]) +@pytest.mark.parametrize("n_heads", [4, 8, 13]) +@pytest.mark.parametrize("d_head", [5, 16, 21, 32]) +@pytest.mark.parametrize( + "seq_len_chunk_size_cases", + [ + # small-ish chunk_size (8) + (64, 8, 2, [(64, 32), (64, 32)]), + (64, 8, 2, [(32, 32), (32, 32), (32, 32)]), + (64, 8, 2, [(8, 8), (8, 8), (8, 8)]), # chunk size boundary + ( + 64, + 8, + 2, + [(4, 4), (4, 4), (4, 4), (4, 4)], + ), # chunk_size larger than cont batches + ( + 64, + 8, + 5, + [ + (64, 32, 16, 8, 8), + (8, 16, 32, 16, 8), + (8, 8, 16, 32, 16), + ], + ), # mode examples with varied lengths + # large-ish chunk_size (256) + (64, 256, 1, [(5,), (1,), (1,), (1,)]), # irregular sizes with small sequences + ( + 64, + 256, + 2, + [(5, 30), (1, 2), (1, 2), (1, 2)], + ), # irregular sizes with small sequences + # we also need to test some large seqlen + # to catch errors with init states decay + (768, 128, 2, [(138, 225), (138, 225)]), + ], +) +def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, itype): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + # this test with multiple examples in a continuous batch + # (i.e. 
chunked prefill) + + seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases + + # This test can have larger error for longer sequences + if seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted + + states = None + for ( + Y_min, + cu_seqlens, + seq_idx, + (A, dt, X, B, C), + ) in generate_continuous_batched_examples( + cases, num_examples, seqlen, last_taken, exhausted, n_heads, d_head, itype + ): + + chunk_indices, chunk_offsets = ( + Mamba2Metadata._query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1] + ) + ) + + Y = torch.empty_like(X) + new_states = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=states, + out=Y, + ) + + # just test the last in sequence + for i in range(num_examples): + + # just test one dim and dstate + Y_eg = Y[0, cu_seqlens[i] : cu_seqlens[i + 1], 0, 0] + Y_min_eg = Y_min[i][:, 0, 0] + torch.testing.assert_close(Y_eg, Y_min_eg, atol=atol, rtol=rtol) + + # update states + states = new_states + for i, clear in exhausted.items(): + if clear: + states[i].fill_(0.0) + exhausted[i] = False + + +@pytest.mark.parametrize("chunk_size", [8, 256]) +@pytest.mark.parametrize( + "seqlens", + [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), + ], +) +def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens): + if not torch.cuda.is_available(): + pytest.skip("CUDA device not available") + + # This test verifies the correctness of the chunked prefill implementation + # in the mamba2 ssd kernels, by comparing concatenation (in the sequence + # 
dimension) of chunked results with the full sequence result. + # It is different from test_mamba_chunk_scan_cont_batch by: + # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get + # reference outputs. Instead, it compares chunked kernel outputs to full + # sequence kernel outputs. This is the most straightforward way to + # assert chunked prefill correctness. + # 2. It focuses on cases where sequences change in the middle of mamba + # chunks, and not necessarily on chunk boundaries. + + max_seqlen = max(seqlens) + # This test can have larger error for longer sequences + if max_seqlen > 256: + atol, rtol = 1e-2, 5e-3 + else: + atol, rtol = 5e-3, 5e-3 + + num_sequences = len(seqlens) + n_heads = 16 + d_head = 64 + itype = torch.float32 + + # hold state during the cutting process so we know if an + # example has been exhausted and needs to cycle + last_taken: dict = {} # map: eg -> pointer to last taken sample + exhausted: dict = {} # map: eg -> boolean indicating example is exhausted + _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next( + generate_continuous_batched_examples( + [seqlens], + num_sequences, + max_seqlen, + last_taken, + exhausted, + n_heads, + d_head, + itype, + return_naive_ref=False, + ) + ) + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device) + device = X.device + + ## full seqlen computation + chunk_indices, chunk_offsets = ( + Mamba2Metadata._query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1] + ) + ) + Y_ref = torch.empty_like(X) + state_ref = mamba_chunk_scan_combined( + X, + dt, + A, + B, + C, + chunk_size, + D=None, + cu_seqlens=cu_seqlens, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_ref, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat( + [torch.tensor([0], device=device), torch.cumsum(chunked_seqlens, dim=0)], 
dim=0 + ) + chunked_seq_idx = ( + torch.repeat_interleave( + torch.arange(len(chunked_seqlens), device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1], + ) + .unsqueeze(0) + .to(torch.int32) + ) + chunked_input_seq_len = chunked_cu_seqlens[-1] + X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + for i in range(num_sequences): + # fmt: off + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501 + + X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501 + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501 + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501 + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(C, i) # noqa: E501 + # fmt: on + + chunk_indices, chunk_offsets = ( + Mamba2Metadata._query_start_loc_to_chunk_indices_offsets( + chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1] + ) + ) + Y_partial = torch.empty_like(X_chunked) + partial_state = mamba_chunk_scan_combined( + X_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size, + D=None, + cu_seqlens=chunked_cu_seqlens, + seq_idx=chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=None, + out=Y_partial, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat( + [ + torch.tensor([0], device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0), + ], + dim=0, + ) + remaining_chunked_seq_idx = ( + torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1], + ) + .unsqueeze(0) + .to(torch.int32) + ) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # fmt: off + remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501 + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501 + + remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501 + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(dt, i) # noqa: E501 + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501 + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501 + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat([ + pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + ], + dim=1) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501 + # fmt: on + + assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + + chunk_indices, chunk_offsets = ( + Mamba2Metadata._query_start_loc_to_chunk_indices_offsets( + remaining_chunked_cu_seqlens, chunk_size, remaining_chunked_cu_seqlens[-1] + ) + ) + + Y_chunked = torch.empty_like(remaining_X_chunked) + state_chunked = mamba_chunk_scan_combined( + remaining_X_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size, + D=None, + cu_seqlens=remaining_chunked_cu_seqlens, + seq_idx=remaining_chunked_seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + return_varlen_states=True, + initial_states=partial_state, + out=Y_chunked, + ) + Y = concat_batch_f(Y_partial, Y_chunked) + + # kernel chunked is same as kernel overall + for i in range(num_sequences): + Y_seq = Y[:, cu_seqlens[i] : cu_seqlens[i + 1], ...] + Y_ref_seq = Y_ref[:, cu_seqlens[i] : cu_seqlens[i + 1], ...] 
+ torch.testing.assert_close( + Y_seq[:, : chunked_seqlens[i], ...], + Y_ref_seq[:, : chunked_seqlens[i], ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part1 " + x, + ) # noqa: B023 + torch.testing.assert_close( + Y_seq[:, chunked_seqlens[i] :, ...], + Y_ref_seq[:, chunked_seqlens[i] :, ...], + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} output part2 " + x, + ) # noqa: B023 + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close( + state_seq, + state_seq_ref, + atol=atol, + rtol=rtol, + msg=lambda x: f"seq{i} state " + x, + ) # noqa: B023 diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py index ef930f88e..d6c576471 100644 --- a/test/srt/models/test_generation_models.py +++ b/test/srt/models/test_generation_models.py @@ -91,6 +91,11 @@ ALL_MODELS = [ trust_remote_code=True, skip_long_prompt=True, ), + ModelCase( + "nvidia/NVIDIA-Nemotron-Nano-9B-v2", + trust_remote_code=True, + skip_long_prompt=True, + ), ModelCase( "swiss-ai/Apertus-8B", trust_remote_code=True, diff --git a/test/srt/models/test_nvidia_nemotron_nano_v2.py b/test/srt/models/test_nvidia_nemotron_nano_v2.py new file mode 100644 index 000000000..840de091c --- /dev/null +++ b/test/srt/models/test_nvidia_nemotron_nano_v2.py @@ -0,0 +1,44 @@ +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestNvidiaNemotronNanoV2(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "nvidia/NVIDIA-Nemotron-Nano-9B-v2" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--max-mamba-cache-size", + "256", + ], + ) + + @classmethod + def tearDownClass(cls): 
+ kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.87) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 7eb82e36e..4c5a0779f 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -127,6 +127,10 @@ suites = { TestFile("test_vlm_input_format.py", 300), TestFile("test_vision_openai_server_a.py", 724), TestFile("test_vision_openai_server_b.py", 446), + TestFile("layers/attention/mamba/test_causal_conv1d.py", 85), + TestFile("layers/attention/mamba/test_mamba_ssm.py", 85), + TestFile("layers/attention/mamba/test_mamba_ssm_ssd.py", 220), + TestFile("models/test_nvidia_nemotron_nano_v2.py", 180), TestFile("test_modelopt_loader.py", 30), ], "per-commit-2-gpu": [ @@ -142,6 +146,7 @@ suites = { TestFile("hicache/test_hicache_storage_file_backend.py", 200), TestFile("hicache/test_hicache_storage_mooncake_backend.py", 400), TestFile("hicache/test_hicache_storage_3fs_backend.py", 200), + TestFile("layers/attention/mamba/test_mamba2_mixer.py", 110), ], "per-commit-4-gpu": [ TestFile("test_gpt_oss_4gpu.py", 300),