model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
|
||||
from sglang.srt.configs.nemotron_h import NemotronHConfig
|
||||
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
|
||||
from sglang.srt.configs.step3_vl import (
|
||||
Step3TextConfig,
|
||||
@@ -32,4 +33,5 @@ __all__ = [
|
||||
"DotsVLMConfig",
|
||||
"DotsOCRConfig",
|
||||
"FalconH1Config",
|
||||
"NemotronHConfig",
|
||||
]
|
||||
|
||||
@@ -15,16 +15,12 @@
|
||||
"""Falcon-H1 model configuration"""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_size,
|
||||
get_tensor_model_parallel_world_size,
|
||||
@@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig):
|
||||
self.rope_scaling = None
|
||||
self.rope_scaling = rope_scaling
|
||||
self.projectors_bias = projectors_bias
|
||||
mamba_intermediate = (
|
||||
self.mamba_intermediate = mamba_intermediate = (
|
||||
mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
|
||||
)
|
||||
|
||||
@@ -294,18 +290,6 @@ class FalconH1Config(PretrainedConfig):
|
||||
def layers_block_type(self):
|
||||
return ["falcon_h1" for i in range(self.num_hidden_layers)]
|
||||
|
||||
@property
|
||||
def mamba_cache_per_req(self):
|
||||
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
|
||||
self.hybrid_gdn_params
|
||||
)
|
||||
mamba_layers_len = len(mamba_layers)
|
||||
|
||||
return (
|
||||
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
|
||||
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
|
||||
) * mamba_layers_len
|
||||
|
||||
@property
|
||||
def full_attention_layer_ids(self):
|
||||
# For Falcon-H1, we do have attention on all layers
|
||||
@@ -317,44 +301,14 @@ class FalconH1Config(PretrainedConfig):
|
||||
return range(self.num_hidden_layers)
|
||||
|
||||
@property
|
||||
def hybrid_gdn_params(self):
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
n_groups = self.mamba_n_groups
|
||||
if self.mamba_n_groups % world_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
self.mamba_n_groups, world_size
|
||||
)
|
||||
n_groups += extra_groups
|
||||
|
||||
conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
|
||||
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.mamba_d_conv - 1,
|
||||
)
|
||||
|
||||
# we TP-ize on the heads dimension
|
||||
temporal_state_shape = (
|
||||
self.mamba_d_state,
|
||||
self.mamba_d_head,
|
||||
divide(self.mamba_n_heads, world_size),
|
||||
)
|
||||
conv_dtype = torch.bfloat16
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
|
||||
mamba_layers = self.linear_layer_ids
|
||||
|
||||
return (
|
||||
conv_state_shape,
|
||||
temporal_state_shape,
|
||||
conv_dtype,
|
||||
ssm_dtype,
|
||||
mamba_layers,
|
||||
def mamba2_cache_params(self):
|
||||
shape = Mamba2StateShape.create(
|
||||
tp_world_size=get_tensor_model_parallel_world_size(),
|
||||
intermediate_size=self.mamba_intermediate,
|
||||
n_groups=self.mamba_n_groups,
|
||||
num_heads=self.mamba_n_heads,
|
||||
head_dim=self.mamba_d_head,
|
||||
state_size=self.mamba_d_state,
|
||||
conv_kernel=self.mamba_d_conv,
|
||||
)
|
||||
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
|
||||
|
||||
117
python/sglang/srt/configs/mamba_utils.py
Normal file
117
python/sglang/srt/configs/mamba_utils.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed.utils import divide
|
||||
|
||||
|
||||
def extra_groups_for_head_shards(ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngoups % tp_size == 0, this will be zero
|
||||
if ngroups % tp_size == 0:
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class Mamba2StateShape:
|
||||
conv: tuple[int, int]
|
||||
temporal: tuple[int, int, int]
|
||||
|
||||
intermediate_size: int
|
||||
conv_dim: int
|
||||
ssm_state_size: int
|
||||
num_heads: int
|
||||
head_dim: int
|
||||
state_size: int
|
||||
conv_kernel: int
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
*,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
n_groups: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> "Mamba2StateShape":
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
if n_groups % tp_world_size != 0:
|
||||
extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)
|
||||
n_groups += extra_groups
|
||||
# heads and n_groups are TP-ed
|
||||
conv_dim = intermediate_size + 2 * n_groups * state_size
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
||||
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
||||
return Mamba2StateShape(
|
||||
conv=conv_state_shape,
|
||||
temporal=temporal_state_shape,
|
||||
intermediate_size=intermediate_size,
|
||||
conv_dim=conv_dim,
|
||||
ssm_state_size=state_size,
|
||||
num_heads=num_heads,
|
||||
head_dim=head_dim,
|
||||
state_size=state_size,
|
||||
conv_kernel=conv_kernel,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class Mamba2StateDType:
|
||||
conv: torch.dtype
|
||||
temporal: torch.dtype
|
||||
|
||||
|
||||
CONV_DTYPE = torch.bfloat16
|
||||
|
||||
|
||||
def mamba2_state_dtype() -> Mamba2StateDType:
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
|
||||
return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class Mamba2CacheParams:
|
||||
shape: Mamba2StateShape
|
||||
dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
|
||||
layers: list[int]
|
||||
|
||||
@property
|
||||
def mamba_cache_per_req(self) -> int:
|
||||
return (
|
||||
int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
|
||||
+ int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
|
||||
) * len(self.layers)
|
||||
286
python/sglang/srt/configs/nemotron_h.py
Normal file
286
python/sglang/srt/configs/nemotron_h.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
|
||||
|
||||
"""NemotronH model configuration"""
|
||||
|
||||
import regex as re
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAMBA = "M"
|
||||
ATTENTION = "*"
|
||||
MLP = "-"
|
||||
|
||||
|
||||
class NemotronHConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a
|
||||
[`NemotronHModel`]. It is used to instantiate a NemotronH model according
|
||||
to the specified arguments, defining the model architecture. Instantiating
|
||||
a configuration with the defaults will yield a similar configuration to
|
||||
that of the NemotronH-v0.1 model.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 131072):
|
||||
Vocabulary size of the NemotronH model. Defines the number of
|
||||
different tokens that can be represented by the `inputs_ids`
|
||||
passed when calling [`NemotronHModel`]
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be
|
||||
tied. Note that this is only relevant if the model has an output
|
||||
word embedding layer.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 21504):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 52):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
hybrid_override_pattern (`str`, *optional*, defaults to
|
||||
`"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
|
||||
The pattern of the hybrid model. The pattern is a string of
|
||||
characters where each character represents
|
||||
M: Mamba2, *: Attention, -: MLP
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the
|
||||
Transformer encoder.
|
||||
attention_head_dim (`int`, *optional*, defaults to 128):
|
||||
Dimension of each attention head.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to
|
||||
implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use
|
||||
Multi Head Attention (MHA), if `num_key_value_heads=1` the model
|
||||
will use Multi Query Attention (MQA) otherwise GQA is used.
|
||||
mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
|
||||
The non-linear activation function in the MLP layers.
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in attention layers.
|
||||
mlp_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in MLP layers.
|
||||
use_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in the model.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
|
||||
The epsilon used by the layer normalization layers.
|
||||
residual_in_fp32 (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not residuals should be in `float32`. If set to `False`
|
||||
residuals will keep the same `dtype` as the rest of the model.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values
|
||||
attentions (not used by all models). Only relevant if
|
||||
`config.is_decoder=True`.
|
||||
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
||||
Number of prompt logits to calculate during generation. If `None`,
|
||||
all logits will be calculated. If an integer value, only last
|
||||
`num_logits_to_keep` logits will be calculated.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the "end-of-sequence" token.
|
||||
sliding_window (`int`, *optional*, defaults to None):
|
||||
Sliding window attention window size.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 4096):
|
||||
The maximum sequence length that this model might ever be used
|
||||
with.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
hidden_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the hidden states.
|
||||
use_mamba_kernels (`bool`, *optional*, defaults to `True`):
|
||||
Flag indicating whether or not to use the fast mamba kernels.
|
||||
These are available only if `mamba-ssm` and `causal-conv1d`
|
||||
are installed, and the mamba modules are running on a CUDA device.
|
||||
ssm_state_size (`int`, *optional*, defaults to 128):
|
||||
The dimension of the mamba state space latents.
|
||||
mamba_num_heads (`int`, *optional*, defaults to 128):
|
||||
Number of heads in Mamba layers.
|
||||
mamba_n_groups (`int`, *optional*, defaults to 8):
|
||||
Number of groups in Mamba layers.
|
||||
mamba_head_dim (`int`, *optional*, defaults to 64):
|
||||
Dimension of each Mamba head.
|
||||
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||
The size of the mamba convolution kernel.
|
||||
mamba_expand (`int`, *optional*, defaults to 2):
|
||||
Expanding factor used to determine the mamba intermediate size.
|
||||
mamba_hidden_act (`str`, *optional*, defaults to "silu"):
|
||||
The non-linear activation function in the Mamba layers.
|
||||
mamba_dt_min (`float`, *optional*, defaults to 0.001):
|
||||
Minimum value for the time step in Mamba.
|
||||
mamba_dt_max (`float`, *optional*, defaults to 0.1):
|
||||
Maximum value for the time step in Mamba.
|
||||
mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
|
||||
Limits for the time step in Mamba.
|
||||
mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
|
||||
Floor value for time step initialization in Mamba.
|
||||
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use bias in the convolution layer of the mamba mixer
|
||||
block.
|
||||
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in the input and output projections of the
|
||||
mamba mixer block.
|
||||
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||
Size of chunks for Mamba processing.
|
||||
rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
|
||||
Whether to rescale the pre-normalization residual connections.
|
||||
"""
|
||||
|
||||
model_type = "nemotron_h"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=131072,
|
||||
tie_word_embeddings=False,
|
||||
hidden_size=4096,
|
||||
intermediate_size=21504,
|
||||
num_hidden_layers=52,
|
||||
hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
|
||||
num_attention_heads=32,
|
||||
head_dim=128,
|
||||
num_key_value_heads=8, # nemo: num_query_groups
|
||||
mlp_hidden_act="relu2",
|
||||
attention_bias=False,
|
||||
mlp_bias=False,
|
||||
use_bias=False,
|
||||
initializer_range=0.02, # nemo: init_method_std
|
||||
layer_norm_epsilon=1e-5, # nemo: layernorm_epsilon
|
||||
residual_in_fp32=False, # Megatron Core default value
|
||||
use_cache=True,
|
||||
num_logits_to_keep=1,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
sliding_window=None,
|
||||
max_position_embeddings=4096,
|
||||
attention_dropout=0.0,
|
||||
hidden_dropout=0.0, # * ADDED
|
||||
use_mamba_kernels=True,
|
||||
ssm_state_size=128, # mamba_state_size
|
||||
mamba_num_heads=128,
|
||||
mamba_n_groups=8, # nemo: mamba_ssm_ngroups = num_heads
|
||||
mamba_head_dim=64,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_hidden_act="silu",
|
||||
mamba_dt_min=0.001,
|
||||
mamba_dt_max=0.1,
|
||||
mamba_dt_limit=(0.0, float("inf")),
|
||||
mamba_dt_init_floor=1e-4,
|
||||
mamba_conv_bias=True,
|
||||
mamba_proj_bias=False,
|
||||
mamba_chunk_size=256,
|
||||
rescale_prenorm_residual=True,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.hybrid_override_pattern = hybrid_override_pattern
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.head_dim = head_dim
|
||||
self.sliding_window = sliding_window
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.attention_dropout = attention_dropout
|
||||
self.hidden_dropout = hidden_dropout
|
||||
|
||||
# Validate hybrid_override_pattern
|
||||
# M: Mamba2, *: Attention, -: MLP
|
||||
assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
|
||||
"hybrid_override_pattern must have same length as " "num_hidden_layers"
|
||||
)
|
||||
assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
|
||||
"hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
|
||||
)
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.mlp_hidden_act = mlp_hidden_act
|
||||
self.attention_bias = attention_bias
|
||||
self.mlp_bias = mlp_bias
|
||||
self.use_bias = use_bias
|
||||
self.initializer_range = initializer_range
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.residual_in_fp32 = residual_in_fp32
|
||||
|
||||
self.use_cache = use_cache
|
||||
self.num_logits_to_keep = num_logits_to_keep
|
||||
|
||||
self.use_mamba_kernels = use_mamba_kernels
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_head_dim = mamba_head_dim
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.mamba_num_heads = mamba_num_heads
|
||||
self.conv_kernel = mamba_d_conv
|
||||
self.expand = mamba_expand
|
||||
self.mamba_hidden_act = mamba_hidden_act
|
||||
self.time_step_min = mamba_dt_min
|
||||
self.time_step_max = mamba_dt_max
|
||||
self.time_step_limit = mamba_dt_limit
|
||||
self.time_step_floor = mamba_dt_init_floor
|
||||
self.use_conv_bias = mamba_conv_bias
|
||||
self.mamba_proj_bias = mamba_proj_bias
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
self.rescale_prenorm_residual = rescale_prenorm_residual
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_layer_ids(self):
|
||||
return [
|
||||
i
|
||||
for i in range(self.num_hidden_layers)
|
||||
if self.hybrid_override_pattern[i] == MAMBA
|
||||
]
|
||||
|
||||
@property
|
||||
def full_attention_layer_ids(self):
|
||||
return [
|
||||
i
|
||||
for i in range(self.num_hidden_layers)
|
||||
if self.hybrid_override_pattern[i] == ATTENTION
|
||||
]
|
||||
|
||||
@property
|
||||
def mamba2_cache_params(self) -> Mamba2CacheParams:
|
||||
shape = Mamba2StateShape.create(
|
||||
tp_world_size=get_attention_tp_size(),
|
||||
intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
|
||||
n_groups=self.n_groups,
|
||||
num_heads=self.mamba_num_heads,
|
||||
head_dim=self.mamba_head_dim,
|
||||
state_size=self.ssm_state_size,
|
||||
conv_kernel=self.conv_kernel,
|
||||
)
|
||||
|
||||
return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
|
||||
@@ -15,14 +15,12 @@
|
||||
"""Qwen3Hybrid model configuration"""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
|
||||
@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
|
||||
]
|
||||
|
||||
@property
|
||||
def hybrid_gdn_params(self):
|
||||
world_size = get_attention_tp_size()
|
||||
conv_dim = (
|
||||
self.linear_key_head_dim * self.linear_num_key_heads * 2
|
||||
+ self.linear_value_head_dim * self.linear_num_value_heads
|
||||
)
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.linear_conv_kernel_dim - 1,
|
||||
def mamba2_cache_params(self) -> Mamba2CacheParams:
|
||||
shape = Mamba2StateShape.create(
|
||||
tp_world_size=get_attention_tp_size(),
|
||||
intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
|
||||
n_groups=self.linear_num_key_heads,
|
||||
num_heads=self.linear_num_value_heads,
|
||||
head_dim=self.linear_value_head_dim,
|
||||
state_size=self.linear_key_head_dim,
|
||||
conv_kernel=self.linear_conv_kernel_dim,
|
||||
)
|
||||
|
||||
temporal_state_shape = (
|
||||
divide(self.linear_num_value_heads, world_size),
|
||||
self.linear_key_head_dim,
|
||||
self.linear_value_head_dim,
|
||||
)
|
||||
conv_dtype = torch.bfloat16
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
|
||||
mamba_layers = self.linear_layer_ids
|
||||
return (
|
||||
conv_state_shape,
|
||||
temporal_state_shape,
|
||||
conv_dtype,
|
||||
ssm_dtype,
|
||||
mamba_layers,
|
||||
)
|
||||
|
||||
@property
|
||||
def mamba_cache_per_req(self):
|
||||
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
|
||||
self.hybrid_gdn_params
|
||||
)
|
||||
mamba_layers_len = len(mamba_layers)
|
||||
|
||||
return (
|
||||
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
|
||||
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
|
||||
) * mamba_layers_len
|
||||
return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# evade circular imports
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
ATTENTION_BACKENDS = {}
|
||||
|
||||
|
||||
@@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
|
||||
return DualChunkFlashAttentionBackend(runner)
|
||||
|
||||
|
||||
def attn_backend_wrapper(runner, full_attn_backend):
|
||||
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
|
||||
"""
|
||||
Wrapper for special models like hybrid GDN, so we don't
|
||||
need to change the code of the original attention backend.
|
||||
"""
|
||||
assert not (
|
||||
runner.is_hybrid_gdn and runner.use_mla_backend
|
||||
runner.hybrid_gdn_config is not None and runner.use_mla_backend
|
||||
), "hybrid_gdn can only be used with non-MLA models."
|
||||
|
||||
# wrap for hybrid GDN models
|
||||
if runner.is_hybrid_gdn:
|
||||
if cfg := runner.mambaish_config:
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
GDNAttnBackend,
|
||||
HybridLinearAttnBackend,
|
||||
Mamba2AttnBackend,
|
||||
)
|
||||
from sglang.srt.utils import is_blackwell, is_npu
|
||||
|
||||
if is_blackwell():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "triton"
|
||||
or runner.server_args.attention_backend == "trtllm_mha"
|
||||
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
|
||||
if is_npu():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "ascend"
|
||||
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
|
||||
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
MambaAttnBackend,
|
||||
)
|
||||
|
||||
linear_attn_backend = MambaAttnBackend(runner)
|
||||
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
|
||||
if runner.hybrid_gdn_config is not None:
|
||||
if is_blackwell():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "triton"
|
||||
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
|
||||
if is_npu():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "ascend"
|
||||
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
|
||||
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
|
||||
linear_attn_backend = GDNAttnBackend(runner)
|
||||
elif runner.mamba2_config is not None:
|
||||
linear_attn_backend = Mamba2AttnBackend(runner)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected hybrid GDN or NemotronH models, but got unknown model."
|
||||
)
|
||||
full_attn_layers = cfg.full_attention_layer_ids
|
||||
return HybridLinearAttnBackend(
|
||||
full_attn_backend, linear_attn_backend, full_attn_layers
|
||||
)
|
||||
|
||||
@@ -181,6 +181,45 @@ def _layer_norm_fwd(
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(
|
||||
*,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
return rms_norm_gated(
|
||||
x=x,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
eps=eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(
|
||||
@@ -238,14 +261,6 @@ def layernorm_fn(
|
||||
)
|
||||
|
||||
|
||||
def rmsnorm_fn(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
is_rms_norm=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return rmsnorm_fn(
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
|
||||
@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
||||
PAD_SLOT_ID,
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
|
||||
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
|
||||
ForwardMetadata,
|
||||
Mamba2Metadata,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.models.qwen3_next import fused_gdn_gating
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import is_cuda, is_npu
|
||||
|
||||
@@ -47,18 +54,10 @@ elif is_npu():
|
||||
causal_conv1d_update = causal_conv1d_update_npu
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardMetadata:
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
mamba_cache_indices: torch.Tensor
|
||||
|
||||
|
||||
class MambaAttnBackend(AttentionBackend):
|
||||
"""Attention backend using Mamba kernel."""
|
||||
|
||||
class MambaAttnBackendBase(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__()
|
||||
self.pad_slot_id = -1 # Default pad slot id
|
||||
self.pad_slot_id = PAD_SLOT_ID
|
||||
self.device = model_runner.device
|
||||
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
|
||||
self.forward_metadata: ForwardMetadata = None
|
||||
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
|
||||
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
def _forward_metadata(self, forward_batch: ForwardBatch):
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
|
||||
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
|
||||
forward_batch.req_pool_indices
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
return ForwardMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
mamba_cache_indices=mamba_cache_indices,
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
self.forward_metadata = self._forward_metadata(forward_batch)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
self.forward_metadata = self._capture_metadata(
|
||||
bs, req_pool_indices, forward_mode
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
self.forward_metadata = self._replay_metadata(
|
||||
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
assert (
|
||||
max_num_tokens % max_bs == 0
|
||||
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInput],
|
||||
def _capture_metadata(
|
||||
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
|
||||
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
return ForwardMetadata(
|
||||
query_start_loc=self.query_start_loc_list[bs - 1],
|
||||
mamba_cache_indices=self.state_indices_list[bs - 1],
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
def _replay_metadata(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
return ForwardMetadata(
|
||||
query_start_loc=self.query_start_loc_list[bs - 1],
|
||||
mamba_cache_indices=self.state_indices_list[bs - 1],
|
||||
)
|
||||
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1 # Mamba attn does not use seq lens to index kv cache
|
||||
|
||||
|
||||
class GDNAttnBackend(MambaAttnBackendBase):
|
||||
"""Attention backend using Mamba kernel."""
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
|
||||
dt_bias = kwargs["dt_bias"]
|
||||
layer_id = kwargs["layer_id"]
|
||||
|
||||
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
|
||||
layer_id
|
||||
)
|
||||
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
||||
conv_states = layer_cache.conv
|
||||
ssm_states = layer_cache.temporal
|
||||
query_start_loc = self.forward_metadata.query_start_loc
|
||||
cache_indices = self.forward_metadata.mamba_cache_indices
|
||||
|
||||
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
|
||||
query_start_loc = self.forward_metadata.query_start_loc
|
||||
cache_indices = self.forward_metadata.mamba_cache_indices
|
||||
|
||||
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
||||
conv_states = mamba_cache_params.conv
|
||||
ssm_states = mamba_cache_params.temporal
|
||||
if is_target_verify:
|
||||
(
|
||||
conv_states,
|
||||
ssm_states,
|
||||
intermediate_state_cache,
|
||||
intermediate_conv_window_cache,
|
||||
) = self.req_to_token_pool.get_mamba_params(layer_id)
|
||||
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
|
||||
intermediate_state_cache = mamba_cache_params.intermediate_ssm
|
||||
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
|
||||
has_initial_states = torch.ones(
|
||||
seq_len // forward_batch.spec_info.draft_token_num,
|
||||
dtype=torch.bool,
|
||||
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
|
||||
)
|
||||
conv_states_to_use = conv_states.clone()
|
||||
else:
|
||||
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
|
||||
layer_id
|
||||
)
|
||||
has_initial_states = forward_batch.extend_prefix_lens > 0
|
||||
conv_states_to_use = conv_states
|
||||
|
||||
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
|
||||
return core_attn_out
|
||||
|
||||
|
||||
class Mamba2AttnBackend(MambaAttnBackendBase):
|
||||
"""Attention backend wrapper for Mamba2Mixer kernels."""
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__(model_runner)
|
||||
config = model_runner.mamba2_config
|
||||
assert config is not None
|
||||
self.mamba_chunk_size = config.mamba_chunk_size
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
metadata = self._forward_metadata(forward_batch)
|
||||
self.forward_metadata = Mamba2Metadata.prepare_mixed(
|
||||
metadata.query_start_loc,
|
||||
metadata.mamba_cache_indices,
|
||||
self.mamba_chunk_size,
|
||||
forward_batch,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
|
||||
self.forward_metadata = Mamba2Metadata.prepare_decode(
|
||||
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
metadata = self._replay_metadata(
|
||||
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
|
||||
)
|
||||
self.forward_metadata = Mamba2Metadata.prepare_decode(
|
||||
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
mixer: MambaMixer2,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_id: int,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
use_triton_causal_conv: bool = False,
|
||||
):
|
||||
assert isinstance(self.forward_metadata, Mamba2Metadata)
|
||||
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
||||
return mixer.forward(
|
||||
hidden_states=hidden_states,
|
||||
output=output,
|
||||
layer_cache=layer_cache,
|
||||
metadata=self.forward_metadata,
|
||||
mup_vector=mup_vector,
|
||||
use_triton_causal_conv=use_triton_causal_conv,
|
||||
)
|
||||
|
||||
def forward_decode(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
|
||||
)
|
||||
|
||||
def forward_extend(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
|
||||
)
|
||||
|
||||
|
||||
class HybridLinearAttnBackend(AttentionBackend):
|
||||
"""Support different backends for prefill and decode."""
|
||||
"""Manages a full and linear attention backend"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
full_attn_backend: AttentionBackend,
|
||||
linear_attn_backend: AttentionBackend,
|
||||
linear_attn_backend: MambaAttnBackendBase,
|
||||
full_attn_layers: list[int],
|
||||
):
|
||||
self.full_attn_layers = full_attn_layers
|
||||
self.full_attn_backend = full_attn_backend
|
||||
self.linear_attn_backend = linear_attn_backend
|
||||
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
|
||||
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
):
|
||||
layer_id = layer.layer_id if layer else kwargs["layer_id"]
|
||||
if layer_id in self.full_attn_layers:
|
||||
return self.attn_backend_list[0].forward_decode(
|
||||
return self.full_attn_backend.forward_decode(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
return self.attn_backend_list[1].forward_decode(
|
||||
return self.linear_attn_backend.forward_decode(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
|
||||
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
):
|
||||
layer_id = layer.layer_id if layer else kwargs["layer_id"]
|
||||
if layer_id in self.full_attn_layers:
|
||||
return self.attn_backend_list[0].forward_extend(
|
||||
return self.full_attn_backend.forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
return self.attn_backend_list[1].forward_extend(
|
||||
return self.linear_attn_backend.forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
|
||||
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
|
||||
request_number = accepted_length.shape[0]
|
||||
|
||||
state_indices_tensor = self.attn_backend_list[
|
||||
1
|
||||
].forward_metadata.mamba_cache_indices[:request_number]
|
||||
state_indices_tensor = (
|
||||
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
|
||||
:request_number
|
||||
]
|
||||
)
|
||||
|
||||
mamba_caches = self.attn_backend_list[
|
||||
1
|
||||
].req_to_token_pool.get_mamba_params_all_layers()
|
||||
mamba_caches = (
|
||||
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
|
||||
)
|
||||
|
||||
(
|
||||
conv_states,
|
||||
ssm_states,
|
||||
intermediate_state_cache,
|
||||
intermediate_conv_window_cache,
|
||||
) = mamba_caches
|
||||
conv_states = mamba_caches.conv
|
||||
ssm_states = mamba_caches.temporal
|
||||
intermediate_state_cache = mamba_caches.intermediate_ssm
|
||||
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
|
||||
|
||||
# SSM state updates (chunked to reduce peak memory)
|
||||
valid_mask = accepted_length > 0
|
||||
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
from sgl_kernel import causal_conv1d_fwd
|
||||
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
from .causal_conv1d_triton import PAD_SLOT_ID
|
||||
|
||||
|
||||
def causal_conv1d_fn(
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
@triton.jit()
|
||||
def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
|
||||
+ (conv_state_batch_coord * stride_conv_state_seq)
|
||||
+ conv_state_token_offset * stride_conv_state_tok
|
||||
+ (idx_feats * stride_conv_state_dim)[None, :]
|
||||
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
||||
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
|
||||
:, None
|
||||
]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (
|
||||
(conv_state_batch_coord < num_cache_lines)
|
||||
@@ -897,7 +899,10 @@ def causal_conv1d_update(
|
||||
stride_state_indices = (
|
||||
conv_state_indices.stride(0) if conv_state_indices is not None else 0
|
||||
)
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
if num_accepted_tokens is not None:
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
else:
|
||||
state_len = width - 1
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
def grid(META):
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.configs.mamba_utils import (
|
||||
Mamba2CacheParams,
|
||||
extra_groups_for_head_shards,
|
||||
)
|
||||
from sglang.srt.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
||||
causal_conv1d_fn as causal_conv1d_fn_triton,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
||||
causal_conv1d_update as causal_conv1d_update_triton,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
|
||||
from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
|
||||
from sglang.srt.layers.attention.mamba.ops import (
|
||||
mamba_chunk_scan_combined,
|
||||
selective_state_update,
|
||||
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.mem_cache.memory_pool import MambaPool
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
composed_weight_loader,
|
||||
sharded_weight_loader,
|
||||
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
|
||||
return loader
|
||||
|
||||
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.full_hidden_size = full_hidden_size
|
||||
self.group_size = full_hidden_size // full_n_groups
|
||||
self.per_rank_hidden_size = full_hidden_size // self.tp_size
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert (
|
||||
self.full_hidden_size % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide hidden size."
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
):
|
||||
# Three tensor-parallel cases:
|
||||
# 1. n_groups is 1
|
||||
# In this case we parallelize along the reduction dim.
|
||||
# Each rank computes a local sum of squares followed by AllReduce
|
||||
# 2. tp_size divides n_groups
|
||||
# Each rank only reduces within its local group(s).
|
||||
# No collective ops necessary.
|
||||
# 3. The general case can be pretty complicated so we AllGather
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x.to(input_dtype)
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
# Compute local sum and then reduce to obtain global sum
|
||||
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
else:
|
||||
redundant_tp: bool = self.n_groups % self.tp_size != 0
|
||||
if redundant_tp:
|
||||
# To handle the general case, redundantly apply the variance
|
||||
x = tensor_model_parallel_all_gather(x, -1)
|
||||
|
||||
*prefix_dims, hidden_dim = x.shape
|
||||
group_count = hidden_dim // self.group_size
|
||||
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
|
||||
variance = x_grouped.pow(2).mean(-1, keepdim=True)
|
||||
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x_grouped.view(*prefix_dims, hidden_dim)
|
||||
|
||||
if redundant_tp:
|
||||
start = self.per_rank_hidden_size * self.tp_rank
|
||||
end = start + self.per_rank_hidden_size
|
||||
x = x[..., start:end]
|
||||
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
|
||||
|
||||
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight.data,
|
||||
bias=None,
|
||||
z=gate,
|
||||
eps=self.variance_epsilon,
|
||||
norm_before_gate=False,
|
||||
)
|
||||
|
||||
|
||||
class MambaMixer2(torch.nn.Module):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
@@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_params: Mamba2CacheParams,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
chunk_size: int,
|
||||
layer_id: int,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
# cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
self.num_heads = num_heads = cache_params.shape.num_heads
|
||||
self.head_dim = cache_params.shape.head_dim
|
||||
|
||||
assert (
|
||||
num_heads % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide num heads."
|
||||
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
|
||||
"then num_groups must equal 1."
|
||||
)
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
assert (
|
||||
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
|
||||
), (
|
||||
"Tensor parallel currently supported for quantized models only "
|
||||
"if tensor parallel world size divides num groups."
|
||||
)
|
||||
|
||||
self.ssm_state_size = cache_params.shape.ssm_state_size
|
||||
self.activation = activation
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
conv_kernel_size = cache_params.shape.conv_kernel
|
||||
self.intermediate_size = intermediate_size = (
|
||||
cache_params.shape.intermediate_size
|
||||
)
|
||||
self.n_groups = n_groups
|
||||
if n_groups % self.tp_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size
|
||||
)
|
||||
groups = extra_groups_for_head_shards(n_groups, self.tp_size)
|
||||
self.n_groups = n_groups + groups
|
||||
|
||||
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
|
||||
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
|
||||
self.conv_dim = cache_params.shape.conv_dim
|
||||
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
if n_groups % self.tp_size == 0:
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
if n_groups % self.tp_size != 0:
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
else:
|
||||
# This is the n_groups == 1 case,
|
||||
# where we need to duplicate groups if TP>1.
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
# - use the custom weight loader mamba_v2_sharded_weight_loader
|
||||
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
|
||||
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
|
||||
)
|
||||
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_cache: MambaPool.State,
|
||||
metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
use_triton_causal_conv: bool = False,
|
||||
):
|
||||
# attn_backend_list[-1] gives access to MambaAttnBackend
|
||||
mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
|
||||
attn_metadata = mamba_backend.forward_metadata
|
||||
state_indices_tensor = attn_metadata.mamba_cache_indices
|
||||
chunk_size = self.chunk_size
|
||||
# metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
state_indices_tensor = metadata.mamba_cache_indices
|
||||
conv_state = layer_cache.conv
|
||||
ssm_state = layer_cache.temporal
|
||||
|
||||
conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
|
||||
self.layer_id
|
||||
)
|
||||
|
||||
assert (
|
||||
ssm_state.size(1) == self.ssm_state_size
|
||||
), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
|
||||
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
|
||||
chunk_size = self.chunk_size
|
||||
|
||||
# TODO: properly support this
|
||||
prep_initial_states = False
|
||||
query_start_loc = metadata.query_start_loc
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
num_prefills = metadata.num_prefills # request count
|
||||
num_decodes = metadata.num_decodes # token count (=request)
|
||||
num_prefill_tokens = metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decodes
|
||||
assert num_actual_tokens == projected_states.shape[0]
|
||||
|
||||
# NOTE: V0 put prefill before decode
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
dt_p, dt_d = torch.split(
|
||||
dt,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
|
||||
preallocated_ssm_out = torch.empty(
|
||||
[
|
||||
projected_states.shape[0],
|
||||
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
if has_prefill:
|
||||
mixed_metadata = metadata.mixed_metadata
|
||||
assert mixed_metadata is not None
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
num_prefill_tokens = forward_batch.extend_num_tokens or 0
|
||||
has_initial_states = forward_batch.extend_prefix_lens > 0
|
||||
cache_indices = attn_metadata.mamba_cache_indices
|
||||
|
||||
x = hidden_states_B_C.transpose(
|
||||
has_initial_states_p = mixed_metadata.has_initial_states
|
||||
prep_initial_states = mixed_metadata.prep_initial_states
|
||||
cache_indices = state_indices_tensor_p
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1
|
||||
) # this is the form that causal-conv see
|
||||
hidden_states_B_C = causal_conv1d_fn(
|
||||
ccfn = (
|
||||
causal_conv1d_fn
|
||||
if not use_triton_causal_conv
|
||||
else causal_conv1d_fn_triton
|
||||
)
|
||||
hidden_states_B_C_p = ccfn(
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=cache_indices,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
).transpose(0, 1)
|
||||
query_start_loc=query_start_loc_p,
|
||||
seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
|
||||
).transpose(0, 1)[:num_prefill_tokens]
|
||||
|
||||
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
|
||||
if has_initial_states is not None and prep_initial_states:
|
||||
if has_initial_states_p is not None and prep_initial_states:
|
||||
initial_states = torch.where(
|
||||
has_initial_states[:, None, None, None],
|
||||
ssm_state[state_indices_tensor],
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p],
|
||||
0,
|
||||
)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states.view(
|
||||
hidden_states_p.view(
|
||||
1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
|
||||
),
|
||||
dt.unsqueeze(0),
|
||||
dt_p.unsqueeze(0),
|
||||
self.A,
|
||||
B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=chunk_size,
|
||||
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=mixed_metadata.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
cu_seqlens=query_start_loc,
|
||||
seq_idx=mixed_metadata.seq_idx,
|
||||
chunk_indices=mixed_metadata.chunk_indices,
|
||||
chunk_offsets=mixed_metadata.chunk_offsets,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
initial_states=initial_states,
|
||||
return_varlen_states=True,
|
||||
return_final_states=False,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
|
||||
out=preallocated_ssm_out_p.view(
|
||||
1, num_prefill_tokens, -1, self.head_dim
|
||||
),
|
||||
state_dtype=ssm_state.dtype,
|
||||
)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
|
||||
elif forward_batch.forward_mode.is_decode():
|
||||
num_decodes = len(query_start_loc) - 1
|
||||
ssm_state[state_indices_tensor_p] = varlen_state
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C = causal_conv1d_update(
|
||||
hidden_states_B_C,
|
||||
ccu = (
|
||||
causal_conv1d_update
|
||||
if not use_triton_causal_conv
|
||||
else causal_conv1d_update_triton
|
||||
)
|
||||
hidden_states_B_C_d = ccu(
|
||||
hidden_states_B_C_d,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor,
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
)
|
||||
|
||||
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
A = (
|
||||
A_d = (
|
||||
self.A[:, None, ...][:, :, None]
|
||||
.expand(-1, self.head_dim, self.ssm_state_size)
|
||||
.to(dtype=torch.float32)
|
||||
)
|
||||
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||
D = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B = B.view(-1, n_groups, B.shape[1] // n_groups)
|
||||
C = C.view(-1, n_groups, C.shape[1] // n_groups)
|
||||
hidden_states = hidden_states.view(
|
||||
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
|
||||
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
|
||||
hidden_states_d = hidden_states_d.view(
|
||||
-1, self.num_heads // self.tp_size, self.head_dim
|
||||
)
|
||||
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# - layer_state.ssm_state's slots will be selected
|
||||
# using state_indices_tensor_d
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
selective_state_update(
|
||||
ssm_state.permute(0, 3, 2, 1),
|
||||
hidden_states,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
ssm_state,
|
||||
hidden_states_d,
|
||||
dt_d,
|
||||
A_d,
|
||||
B_d,
|
||||
C_d,
|
||||
D_d,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor,
|
||||
out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
)
|
||||
elif forward_batch.forward_mode.is_idle():
|
||||
preallocated_ssm_out = preallocated_ssm_out
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(preallocated_ssm_out, gate)
|
||||
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
|
||||
|
||||
# 5. Final linear projection
|
||||
output[:], _ = self.out_proj(hidden_states)
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
|
||||
211
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
Normal file
211
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class ForwardMetadata:
|
||||
query_start_loc: torch.Tensor
|
||||
mamba_cache_indices: torch.Tensor
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Mamba2Metadata(ForwardMetadata):
|
||||
"""stable metadata across all mamba2 layers in the forward pass"""
|
||||
|
||||
num_prefills: int
|
||||
num_prefill_tokens: int
|
||||
num_decodes: int
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class MixedMetadata:
|
||||
has_initial_states: torch.Tensor
|
||||
prep_initial_states: bool
|
||||
|
||||
chunk_size: int
|
||||
seq_idx: torch.Tensor
|
||||
chunk_indices: torch.Tensor
|
||||
chunk_offsets: torch.Tensor
|
||||
|
||||
extend_seq_lens_cpu: list[int]
|
||||
|
||||
mixed_metadata: MixedMetadata | None = None
|
||||
"""`mixed_metadata` is used for extend/mixed requests"""
|
||||
|
||||
@staticmethod
|
||||
def _query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
|
||||
lengths, shape (num_seqs + 1,).
|
||||
The first element should be 0. Each entry represents the starting
|
||||
index of a sequence in the flattened token array.
|
||||
chunk_size (int): The size of each physical mamba chunk
|
||||
(number of tokens per chunk).
|
||||
total_seqlens (int): The total number of tokens in the batch.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
|
||||
- chunk_indices (torch.Tensor): 1D tensor of indices
|
||||
indicating the physical chunk for each logical chunk.
|
||||
- chunk_offsets (torch.Tensor): 1D tensor of offsets
|
||||
indicating the starting index of each logical chunk within
|
||||
its physical chunk.
|
||||
|
||||
This function computes the chunk indices and offsets for the given
|
||||
query_start_loc and chunk_size. Both are tensors of integers with length N,
|
||||
where N is the number of logical (pseudo) chunks.
|
||||
A logical chunk is a sequence of tokens that are all part of the same
|
||||
sequence and are all in the same physical mamba chunk.
|
||||
In other words, a logical chunk changes every time we cross a sequence
|
||||
boundary or a physical mamba chunk boundary.
|
||||
Logical chunks are needed to handle batched requests with initial states
|
||||
(see _state_passing_fwd and _chunk_scan_fwd).
|
||||
The chunk_indices tensor contains the index of the physical chunk for each
|
||||
logical chunk.
|
||||
The chunk_offsets tensor contains the offset (AKA starting index) of the
|
||||
logical chunk in the physical chunk.
|
||||
|
||||
Example:
|
||||
query_start_loc = [0, 5, 10]
|
||||
chunk_size = 8
|
||||
total_seqlens = 10
|
||||
-> chunk_indices = [0, 0, 1]
|
||||
-> chunk_offsets = [0, 5, 0]
|
||||
|
||||
In this example, we have 2 sequences, each with 5 tokens. The physical
|
||||
chunk size is 8 tokens.
|
||||
We have three logical chunks:
|
||||
- the first logical chunk starts at token 0 in the first physical chunk
|
||||
and contains all 5 tokens from the first sequence
|
||||
- the second logical chunk starts at token 5 in the first physical chunk
|
||||
and contains first 3 tokens from the second sequence
|
||||
- the third logical chunk starts at token 0 in the second physical chunk
|
||||
and contains the remaining 2 tokens from the second sequence
|
||||
"""
|
||||
|
||||
cu_seqlens = query_start_loc[1:] # remove prepended 0
|
||||
|
||||
# outputs will have length expansion of chunks that do not divide
|
||||
# chunk_size
|
||||
N = (
|
||||
math.ceil(total_seqlens / chunk_size)
|
||||
+ (cu_seqlens[:-1] % chunk_size > 0).sum()
|
||||
)
|
||||
chunk_indices = torch.arange(N, dtype=torch.int, device=query_start_loc.device)
|
||||
chunk_offsets = torch.zeros(
|
||||
(N,), dtype=torch.int, device=query_start_loc.device
|
||||
)
|
||||
|
||||
p = 0 # num of insertions
|
||||
for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
|
||||
|
||||
# if does not divide chunk_size, then there is one chunk insertion
|
||||
p += s % chunk_size > 0
|
||||
|
||||
# get the dimensions
|
||||
# - the + 1 for _e is to shift the boundary by one chunk
|
||||
# - this shifting is not needed if chunk_size divides e
|
||||
_s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0)
|
||||
|
||||
# adjust indices and offsets
|
||||
chunk_indices[_s:_e] -= p
|
||||
chunk_offsets[_s] = s % chunk_size
|
||||
|
||||
return chunk_indices, chunk_offsets
|
||||
|
||||
@staticmethod
|
||||
def prepare_decode(
|
||||
query_start_loc: torch.Tensor,
|
||||
mamba_cache_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
) -> "Mamba2Metadata":
|
||||
"""This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
|
||||
return Mamba2Metadata(
|
||||
query_start_loc=query_start_loc,
|
||||
mamba_cache_indices=mamba_cache_indices,
|
||||
num_decodes=len(seq_lens),
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def prepare_mixed(
|
||||
cls,
|
||||
query_start_loc: torch.Tensor,
|
||||
mamba_cache_indices: torch.Tensor,
|
||||
chunk_size: int,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> "Mamba2Metadata":
|
||||
"""This path cannot run with CUDA graph, as it contains extend requests."""
|
||||
if forward_batch.extend_num_tokens is None:
|
||||
return cls.prepare_decode(
|
||||
query_start_loc, mamba_cache_indices, forward_batch.seq_lens
|
||||
)
|
||||
num_prefills = len(forward_batch.extend_seq_lens)
|
||||
num_prefill_tokens = forward_batch.extend_num_tokens
|
||||
num_decodes = len(forward_batch.seq_lens) - num_prefills
|
||||
context_lens_tensor = forward_batch.extend_prefix_lens
|
||||
assert context_lens_tensor is not None
|
||||
# precompute flag to avoid device syncs later
|
||||
has_initial_states = context_lens_tensor > 0
|
||||
prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()
|
||||
|
||||
query_start_loc = query_start_loc[: num_prefills + 1]
|
||||
seq_idx = torch.repeat_interleave(
|
||||
torch.arange(
|
||||
num_prefills, dtype=torch.int32, device=query_start_loc.device
|
||||
),
|
||||
query_start_loc.diff(),
|
||||
output_size=num_prefill_tokens,
|
||||
)
|
||||
seq_idx.unsqueeze_(0)
|
||||
|
||||
# We compute metadata for chunked prefill once at the top level model
|
||||
# forward and reuse them in mamba layers. If not needed, they will be
|
||||
# ignored inside mamba kernels.
|
||||
chunk_offsets, chunk_indices = None, None
|
||||
if prep_initial_states:
|
||||
chunk_indices, chunk_offsets = (
|
||||
cls._query_start_loc_to_chunk_indices_offsets(
|
||||
query_start_loc, chunk_size, num_prefill_tokens
|
||||
)
|
||||
)
|
||||
|
||||
return Mamba2Metadata(
|
||||
query_start_loc=query_start_loc,
|
||||
mamba_cache_indices=mamba_cache_indices,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
num_decodes=num_decodes,
|
||||
mixed_metadata=cls.MixedMetadata(
|
||||
has_initial_states=has_initial_states,
|
||||
prep_initial_states=prep_initial_states,
|
||||
chunk_size=chunk_size,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
),
|
||||
)
|
||||
@@ -1,81 +0,0 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
|
||||
from sglang.srt.distributed.utils import divide
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
|
||||
@classmethod
|
||||
def linear_attention_state_shape(
|
||||
cls,
|
||||
num_heads: int,
|
||||
tp_size: int,
|
||||
head_dim: int,
|
||||
) -> tuple[tuple[int, int, int], ...]:
|
||||
|
||||
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||
return (state_shape,)
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
|
||||
|
||||
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
n_groups: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
|
||||
# heads and n_groups are TP-ed
|
||||
conv_dim = intermediate_size + 2 * n_groups * state_size
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
||||
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
conv_dim = divide(intermediate_size, tp_world_size)
|
||||
conv_state_shape = (conv_kernel - 1, conv_dim)
|
||||
return (conv_state_shape,)
|
||||
|
||||
@classmethod
|
||||
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngoups % tp_size == 0, this will be zero
|
||||
if ngroups % tp_size == 0:
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
@@ -0,0 +1,120 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
|
||||
from sglang.srt.model_loader.weight_utils import sharded_weight_loader
|
||||
from sglang.srt.utils.common import set_weight_attrs
|
||||
|
||||
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
def __init__(
|
||||
self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.full_hidden_size = full_hidden_size
|
||||
self.group_size = full_hidden_size // full_n_groups
|
||||
self.per_rank_hidden_size = full_hidden_size // self.tp_size
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert (
|
||||
self.full_hidden_size % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide hidden size."
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
):
|
||||
# Three tensor-parallel cases:
|
||||
# 1. n_groups is 1
|
||||
# In this case we parallelize along the reduction dim.
|
||||
# Each rank computes a local sum of squares followed by AllReduce
|
||||
# 2. tp_size divides n_groups
|
||||
# Each rank only reduces within its local group(s).
|
||||
# No collective ops necessary.
|
||||
# 3. The general case can be pretty complicated so we AllGather
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * torch.nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x.to(input_dtype)
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
# Compute local sum and then reduce to obtain global sum
|
||||
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
else:
|
||||
redundant_tp: bool = self.n_groups % self.tp_size != 0
|
||||
if redundant_tp:
|
||||
# To handle the general case, redundantly apply the variance
|
||||
x = tensor_model_parallel_all_gather(x, -1)
|
||||
|
||||
*prefix_dims, hidden_dim = x.shape
|
||||
group_count = hidden_dim // self.group_size
|
||||
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
|
||||
variance = x_grouped.pow(2).mean(-1, keepdim=True)
|
||||
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x_grouped.view(*prefix_dims, hidden_dim)
|
||||
|
||||
if redundant_tp:
|
||||
start = self.per_rank_hidden_size * self.tp_rank
|
||||
end = start + self.per_rank_hidden_size
|
||||
x = x[..., start:end]
|
||||
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
return x * torch.nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
|
||||
|
||||
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
return rms_norm_gated(
|
||||
x=x,
|
||||
weight=self.weight.data,
|
||||
bias=None,
|
||||
z=gate,
|
||||
eps=self.variance_epsilon,
|
||||
norm_before_gate=False,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
@@ -15,56 +15,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["chunk_size", "K", "IS_CAUSAL"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -16,66 +16,6 @@ from packaging import version
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -17,17 +17,6 @@ import triton.language as tl
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_SIZE_H": 2}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 4}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 8}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 16}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 32}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 64}),
|
||||
# ],
|
||||
# key=["chunk_size", "nheads"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
|
||||
)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["hdim", "dstate", "chunk_size"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["hdim", "dstate", "chunk_size"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_state_varlen_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -13,17 +13,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_SIZE": 64}),
|
||||
# triton.Config({"BLOCK_SIZE": 128}),
|
||||
# triton.Config({"BLOCK_SIZE": 256}),
|
||||
# triton.Config({"BLOCK_SIZE": 512}),
|
||||
# triton.Config({"BLOCK_SIZE": 1024}),
|
||||
# triton.Config({"BLOCK_SIZE": 2048}),
|
||||
# ],
|
||||
# key=["dim"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
||||
get_attention_tp_size()
|
||||
)
|
||||
if model_runner.is_hybrid_gdn:
|
||||
if model_runner.hybrid_gdn_config is not None:
|
||||
# For hybrid linear models, layer_id = 0 may not be full attention
|
||||
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
|
||||
else:
|
||||
|
||||
@@ -1770,7 +1770,7 @@ class Scheduler(
|
||||
chunked_req_to_exclude.add(self.chunked_req)
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
if self.tp_worker.worker.model_runner.is_hybrid_gdn:
|
||||
if self.tp_worker.worker.model_runner.mambaish_config is not None:
|
||||
self.req_to_token_pool.free(
|
||||
self.chunked_req.req_pool_idx, free_mamba_cache=False
|
||||
)
|
||||
|
||||
@@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
|
||||
from sglang.srt.layers.attention.nsa import index_buf_accessor
|
||||
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
@@ -109,17 +112,38 @@ class ReqToTokenPool:
|
||||
|
||||
|
||||
class MambaPool:
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class State:
|
||||
conv: torch.Tensor
|
||||
temporal: torch.Tensor
|
||||
|
||||
def at_layer_idx(self, layer: int):
|
||||
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
|
||||
|
||||
def mem_usage_bytes(self):
|
||||
return sum(get_tensor_size_bytes(t) for t in vars(self).values())
|
||||
|
||||
@dataclass(frozen=True, kw_only=True)
|
||||
class SpeculativeState(State):
|
||||
intermediate_ssm: torch.Tensor
|
||||
intermediate_conv_window: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
size: int,
|
||||
conv_dtype: torch.dtype,
|
||||
ssm_dtype: torch.dtype,
|
||||
num_mamba_layers: int,
|
||||
conv_state_shape: Tuple[int, int],
|
||||
temporal_state_shape: Tuple[int, int],
|
||||
cache_params: "Mamba2CacheParams",
|
||||
device: str,
|
||||
speculative_num_draft_tokens: Optional[int] = None,
|
||||
):
|
||||
conv_state_shape = cache_params.shape.conv
|
||||
temporal_state_shape = cache_params.shape.temporal
|
||||
conv_dtype = cache_params.dtype.conv
|
||||
ssm_dtype = cache_params.dtype.temporal
|
||||
num_mamba_layers = len(cache_params.layers)
|
||||
|
||||
# assume conv_state = (dim, state_len)
|
||||
assert conv_state_shape[0] > conv_state_shape[1]
|
||||
conv_state = torch.zeros(
|
||||
size=(num_mamba_layers, size + 1) + conv_state_shape,
|
||||
dtype=conv_dtype,
|
||||
@@ -158,11 +182,11 @@ class MambaPool:
|
||||
dtype=conv_dtype,
|
||||
device="cuda",
|
||||
)
|
||||
self.mamba_cache = (
|
||||
conv_state,
|
||||
temporal_state,
|
||||
intermediate_ssm_state_cache,
|
||||
intermediate_conv_window_cache,
|
||||
self.mamba_cache = self.SpeculativeState(
|
||||
conv=conv_state,
|
||||
temporal=temporal_state,
|
||||
intermediate_ssm=intermediate_ssm_state_cache,
|
||||
intermediate_conv_window=intermediate_conv_window_cache,
|
||||
)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
@@ -172,7 +196,7 @@ class MambaPool:
|
||||
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
|
||||
)
|
||||
else:
|
||||
self.mamba_cache = (conv_state, temporal_state)
|
||||
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
|
||||
logger.info(
|
||||
f"Mamba Cache is allocated. "
|
||||
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
|
||||
@@ -180,16 +204,14 @@ class MambaPool:
|
||||
)
|
||||
self.size = size
|
||||
self.free_slots = list(range(size))
|
||||
self.mem_usage = self.get_mamba_size() / GB
|
||||
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
|
||||
|
||||
def get_mamba_params_all_layers(self):
|
||||
return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]
|
||||
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
|
||||
assert isinstance(self.mamba_cache, self.SpeculativeState)
|
||||
return self.mamba_cache
|
||||
|
||||
def get_mamba_params(self, layer_id: int):
|
||||
return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]
|
||||
|
||||
def get_mamba_size(self):
|
||||
return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)
|
||||
def mamba2_layer_cache(self, layer_id: int):
|
||||
return self.mamba_cache.at_layer_idx(layer_id)
|
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
@@ -208,7 +230,9 @@ class MambaPool:
|
||||
self.free_slots.append(free_index)
|
||||
else:
|
||||
self.free_slots.extend(free_index)
|
||||
self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0
|
||||
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
|
||||
:, free_index
|
||||
] = 0
|
||||
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size))
|
||||
@@ -219,16 +243,13 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
size: int,
|
||||
max_context_len: int,
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
conv_dtype: torch.dtype,
|
||||
ssm_dtype: torch.dtype,
|
||||
mamba_layers: List[int],
|
||||
conv_state_shape: Tuple[int, int],
|
||||
temporal_state_shape: Tuple[int, int],
|
||||
speculative_num_draft_tokens: int,
|
||||
cache_params: "Mamba2CacheParams",
|
||||
speculative_num_draft_tokens: int = None,
|
||||
):
|
||||
super().__init__(
|
||||
size=size,
|
||||
@@ -238,16 +259,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
)
|
||||
|
||||
self.mamba_pool = MambaPool(
|
||||
size,
|
||||
conv_dtype,
|
||||
ssm_dtype,
|
||||
len(mamba_layers),
|
||||
conv_state_shape,
|
||||
temporal_state_shape,
|
||||
device,
|
||||
speculative_num_draft_tokens,
|
||||
size=size,
|
||||
cache_params=cache_params,
|
||||
device=device,
|
||||
speculative_num_draft_tokens=speculative_num_draft_tokens,
|
||||
)
|
||||
self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}
|
||||
self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
|
||||
|
||||
self.device = device
|
||||
self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
|
||||
@@ -287,12 +304,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
|
||||
def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
|
||||
return self.req_index_to_mamba_index_mapping[req_indices]
|
||||
|
||||
def get_mamba_params(self, layer_id: int):
|
||||
def mamba2_layer_cache(self, layer_id: int):
|
||||
assert layer_id in self.mamba_map
|
||||
return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])
|
||||
return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])
|
||||
|
||||
def get_mamba_params_all_layers(self):
|
||||
return self.mamba_pool.get_mamba_params_all_layers()
|
||||
def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
|
||||
return self.mamba_pool.get_speculative_mamba2_params_all_layers()
|
||||
|
||||
# For chunk prefill, we can not free mamba cache, we need use it in the future
|
||||
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
|
||||
|
||||
@@ -29,6 +29,7 @@ from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
|
||||
from sglang.srt.configs.device_config import DeviceConfig
|
||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||
from sglang.srt.configs.model_config import (
|
||||
@@ -354,8 +355,9 @@ class ModelRunner:
|
||||
if architectures and not any("Llama4" in arch for arch in architectures):
|
||||
self.is_hybrid = self.model_config.is_hybrid = True
|
||||
|
||||
if self.is_hybrid_gdn:
|
||||
logger.warning("Hybrid GDN model detected, disable radix cache")
|
||||
if config := self.mambaish_config:
|
||||
class_name = config.__class__.__name__
|
||||
logger.warning(f"{class_name} model detected, disable radix cache")
|
||||
self.server_args.disable_radix_cache = True
|
||||
if self.server_args.max_mamba_cache_size is None:
|
||||
if self.server_args.max_running_requests is not None:
|
||||
@@ -364,6 +366,7 @@ class ModelRunner:
|
||||
)
|
||||
else:
|
||||
self.server_args.max_mamba_cache_size = 512
|
||||
if self.hybrid_gdn_config is not None:
|
||||
self.server_args.max_mamba_cache_size = (
|
||||
self.server_args.max_mamba_cache_size
|
||||
// (
|
||||
@@ -1267,8 +1270,8 @@ class ModelRunner:
|
||||
"num_nextn_predict_layers",
|
||||
self.num_effective_layers,
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
|
||||
elif config := self.mambaish_config:
|
||||
num_layers = len(config.full_attention_layer_ids)
|
||||
else:
|
||||
num_layers = self.num_effective_layers
|
||||
if self.use_mla_backend:
|
||||
@@ -1288,22 +1291,32 @@ class ModelRunner:
|
||||
rest_memory = available_gpu_memory - total_gpu_memory * (
|
||||
1 - self.mem_fraction_static
|
||||
)
|
||||
if self.is_hybrid_gdn:
|
||||
if config := self.mambaish_config:
|
||||
rest_memory -= (
|
||||
self.server_args.max_mamba_cache_size
|
||||
* self.model_config.hf_config.mamba_cache_per_req
|
||||
* config.mamba2_cache_params.mamba_cache_per_req
|
||||
/ (1 << 30)
|
||||
)
|
||||
max_num_token = int(rest_memory * (1 << 30) // cell_size)
|
||||
return max_num_token
|
||||
|
||||
@property
|
||||
def is_hybrid_gdn(self):
|
||||
return self.model_config.hf_config.architectures[0] in [
|
||||
"Qwen3NextForCausalLM",
|
||||
"Qwen3NextForCausalLMMTP",
|
||||
"FalconH1ForCausalLM",
|
||||
]
|
||||
def hybrid_gdn_config(self):
|
||||
config = self.model_config.hf_config
|
||||
if isinstance(config, Qwen3NextConfig):
|
||||
return config
|
||||
return None
|
||||
|
||||
@property
|
||||
def mamba2_config(self):
|
||||
config = self.model_config.hf_config
|
||||
if isinstance(config, FalconH1Config | NemotronHConfig):
|
||||
return config
|
||||
return None
|
||||
|
||||
@property
|
||||
def mambaish_config(self):
|
||||
return self.mamba2_config or self.hybrid_gdn_config
|
||||
|
||||
def set_num_token_hybrid(self):
|
||||
if (
|
||||
@@ -1438,7 +1451,7 @@ class ModelRunner:
|
||||
),
|
||||
4096,
|
||||
)
|
||||
if self.is_hybrid_gdn:
|
||||
if self.mambaish_config is not None:
|
||||
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
|
||||
|
||||
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
|
||||
@@ -1519,26 +1532,14 @@ class ModelRunner:
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
pre_alloc_size=pre_alloc_size,
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
config = self.model_config.hf_config
|
||||
(
|
||||
conv_state_shape,
|
||||
temporal_state_shape,
|
||||
conv_dtype,
|
||||
ssm_dtype,
|
||||
mamba_layers,
|
||||
) = config.hybrid_gdn_params
|
||||
elif config := self.mambaish_config:
|
||||
self.req_to_token_pool = HybridReqToTokenPool(
|
||||
size=max_num_reqs,
|
||||
max_context_len=self.model_config.context_len
|
||||
+ extra_max_context_len,
|
||||
device=self.device,
|
||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||
conv_state_shape=conv_state_shape,
|
||||
temporal_state_shape=temporal_state_shape,
|
||||
conv_dtype=conv_dtype,
|
||||
ssm_dtype=ssm_dtype,
|
||||
mamba_layers=mamba_layers,
|
||||
cache_params=config.mamba2_cache_params,
|
||||
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
||||
)
|
||||
else:
|
||||
@@ -1640,7 +1641,7 @@ class ModelRunner:
|
||||
enable_kvcache_transpose=False,
|
||||
device=self.device,
|
||||
)
|
||||
elif self.is_hybrid_gdn:
|
||||
elif config := self.mambaish_config:
|
||||
self.token_to_kv_pool = HybridLinearKVPool(
|
||||
page_size=self.page_size,
|
||||
size=self.max_total_num_tokens,
|
||||
@@ -1651,9 +1652,7 @@ class ModelRunner:
|
||||
head_dim=self.model_config.head_dim,
|
||||
# if draft worker, we only need 1 attention layer's kv pool
|
||||
full_attention_layer_ids=(
|
||||
[0]
|
||||
if self.is_draft_worker
|
||||
else self.model_config.hf_config.full_attention_layer_ids
|
||||
[0] if self.is_draft_worker else config.full_attention_layer_ids
|
||||
),
|
||||
enable_kvcache_transpose=False,
|
||||
device=self.device,
|
||||
@@ -1681,7 +1680,8 @@ class ModelRunner:
|
||||
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
||||
if self.token_to_kv_pool_allocator is None:
|
||||
if _is_npu and (
|
||||
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
|
||||
self.server_args.attention_backend == "ascend"
|
||||
or self.hybrid_gdn_config is not None
|
||||
):
|
||||
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
||||
self.max_total_num_tokens,
|
||||
|
||||
@@ -8,6 +8,10 @@ from torch import nn
|
||||
from sglang.srt.configs.falcon_h1 import FalconH1Config
|
||||
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
Mamba2AttnBackend,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
@@ -184,18 +188,12 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
self.mamba = MambaMixer2(
|
||||
cache_params=config.mamba2_cache_params,
|
||||
hidden_size=config.hidden_size,
|
||||
ssm_state_size=config.mamba_d_state,
|
||||
conv_kernel_size=config.mamba_d_conv,
|
||||
intermediate_size=self.d_ssm,
|
||||
use_conv_bias=config.mamba_conv_bias,
|
||||
use_bias=config.mamba_proj_bias,
|
||||
n_groups=config.mamba_n_groups,
|
||||
num_heads=config.mamba_n_heads,
|
||||
layer_id=layer_id,
|
||||
head_dim=config.mamba_d_head,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
chunk_size=config.mamba_chunk_size,
|
||||
activation=config.hidden_act,
|
||||
use_rms_norm=config.mamba_rms_norm,
|
||||
prefix=f"{prefix}.mixer",
|
||||
@@ -339,12 +337,16 @@ class FalconH1HybridAttentionDecoderLayer(nn.Module):
|
||||
)
|
||||
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
|
||||
|
||||
attn_backend = forward_batch.attn_backend
|
||||
assert isinstance(attn_backend, HybridLinearAttnBackend)
|
||||
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
|
||||
# Mamba block
|
||||
mamba_hidden_states = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
attn_backend.linear_attn_backend.forward(
|
||||
self.mamba,
|
||||
hidden_states * self.ssm_in_multiplier,
|
||||
mamba_hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
layer_id=self.layer_id,
|
||||
mup_vector=self.mup_vector,
|
||||
)
|
||||
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
|
||||
|
||||
514
python/sglang/srt/models/nemotron_h.py
Normal file
514
python/sglang/srt/models/nemotron_h.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# Copyright 2023-2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_h.py
|
||||
|
||||
"""Inference-only NemotronH model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs import NemotronHConfig
|
||||
from sglang.srt.configs.nemotron_h import ATTENTION, MAMBA, MLP
|
||||
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import ReLU2
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
Mamba2AttnBackend,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE,
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from sglang.srt.utils import add_prefix, make_layers_non_pp
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class NemotronHMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronHConfig,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
hybrid_override_pattern = config.hybrid_override_pattern
|
||||
mlp_index = hybrid_override_pattern[: layer_idx + 1].count("-") - 1
|
||||
if isinstance(config.intermediate_size, list):
|
||||
if len(config.intermediate_size) == 1:
|
||||
intermediate_size = config.intermediate_size[0]
|
||||
else:
|
||||
intermediate_size = config.intermediate_size[mlp_index]
|
||||
else:
|
||||
intermediate_size = config.intermediate_size
|
||||
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
input_size=config.hidden_size,
|
||||
output_size=intermediate_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.up_proj",
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
input_size=intermediate_size,
|
||||
output_size=config.hidden_size,
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
)
|
||||
self.act_fn = ReLU2()
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x, _ = self.up_proj(x)
|
||||
x = self.act_fn(x)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class NemotronHMLPDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronHConfig,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.mixer = NemotronHMLP(
|
||||
config,
|
||||
quant_config=quant_config,
|
||||
bias=config.mlp_bias,
|
||||
prefix=f"{prefix}.mixer",
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
forward_batch: ForwardBatch,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer.forward(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class NemotronHMambaDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronHConfig,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_id = layer_idx
|
||||
self.mixer = MambaMixer2(
|
||||
cache_params=config.mamba2_cache_params,
|
||||
hidden_size=config.hidden_size,
|
||||
use_conv_bias=config.use_conv_bias,
|
||||
use_bias=config.use_bias,
|
||||
n_groups=config.mamba_n_groups,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
activation=config.mamba_hidden_act,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
forward_batch: ForwardBatch,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
output = torch.empty_like(hidden_states)
|
||||
attn_backend = forward_batch.attn_backend
|
||||
assert isinstance(attn_backend, HybridLinearAttnBackend)
|
||||
assert isinstance(attn_backend.linear_attn_backend, Mamba2AttnBackend)
|
||||
attn_backend.linear_attn_backend.forward(
|
||||
mixer=self.mixer,
|
||||
layer_id=self.layer_id,
|
||||
hidden_states=hidden_states,
|
||||
output=output,
|
||||
use_triton_causal_conv=True, # TODO: investigate need of `use_triton_causal_conv`
|
||||
)
|
||||
return output, residual
|
||||
|
||||
|
||||
class NemotronHAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronHConfig,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = config.num_key_value_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
if hasattr(config, "head_dim") and config.head_dim is not None:
|
||||
self.head_dim = config.head_dim
|
||||
else:
|
||||
self.head_dim = config.hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.qkv_proj",
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_idx,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("attn", prefix),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
attn_output = self.attn.forward(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class NemotronHAttentionDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: NemotronHConfig,
|
||||
layer_idx: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.mixer = NemotronHAttention(
|
||||
config,
|
||||
layer_idx,
|
||||
quant_config,
|
||||
prefix=f"{prefix}.mixer",
|
||||
)
|
||||
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
forward_batch: ForwardBatch,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.norm(hidden_states, residual)
|
||||
|
||||
hidden_states = self.mixer.forward(
|
||||
hidden_states=hidden_states, forward_batch=forward_batch
|
||||
)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
Layers = (
|
||||
NemotronHAttentionDecoderLayer
|
||||
| NemotronHMLPDecoderLayer
|
||||
| NemotronHMambaDecoderLayer
|
||||
)
|
||||
ALL_DECODER_LAYER_TYPES: dict[str, type[Layers]] = {
|
||||
ATTENTION: NemotronHAttentionDecoderLayer,
|
||||
MLP: NemotronHMLPDecoderLayer,
|
||||
MAMBA: NemotronHMambaDecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
class NemotronHModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: NemotronHConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
lora_config = None
|
||||
self.config = config
|
||||
lora_vocab = (
|
||||
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
|
||||
if lora_config
|
||||
else 0
|
||||
)
|
||||
self.vocab_size = config.vocab_size + lora_vocab
|
||||
self.org_vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
self.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
)
|
||||
|
||||
def get_layer(idx: int, prefix: str):
|
||||
layer_class = ALL_DECODER_LAYER_TYPES[config.hybrid_override_pattern[idx]]
|
||||
return layer_class(config, idx, quant_config=quant_config, prefix=prefix)
|
||||
|
||||
self.layers = make_layers_non_pp(
|
||||
len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
|
||||
)
|
||||
self.norm_f = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, PPProxyTensors]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.get_input_embeddings(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert pp_proxy_tensors is not None
|
||||
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||
residual = pp_proxy_tensors["residual"]
|
||||
|
||||
residual = None
|
||||
for layer in self.layers:
|
||||
if not isinstance(layer, Layers):
|
||||
raise ValueError(f"Unknown layer type: {type(layer)}")
|
||||
hidden_states, residual = layer.forward(
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return PPProxyTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm_f(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class NemotronHForCausalLM(nn.Module):
|
||||
remap_prefix = {"backbone": "model"}
|
||||
remap_substr = {"A_log": "A", "embeddings": "embed_tokens"}
|
||||
|
||||
# LoRA specific attributes
|
||||
embedding_modules = {
|
||||
"embed_tokens": "input_embeddings",
|
||||
"lm_head": "output_embeddings",
|
||||
}
|
||||
embedding_padding_modules = ["lm_head"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
config: NemotronHConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
lora_config = None
|
||||
self.config = config
|
||||
self.model = self._init_model(
|
||||
config=config, quant_config=quant_config, prefix=prefix
|
||||
)
|
||||
if self.config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.unpadded_vocab_size = config.vocab_size
|
||||
if lora_config:
|
||||
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
padding_size=(
|
||||
DEFAULT_VOCAB_PADDING_SIZE
|
||||
# We need bigger padding if using lora for kernel
|
||||
# compatibility
|
||||
if not lora_config
|
||||
else lora_config.lora_vocab_padding_size
|
||||
),
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def _init_model(
|
||||
self,
|
||||
config: NemotronHConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
return NemotronHModel(config=config, quant_config=quant_config, prefix=prefix)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: Optional[torch.Tensor] = None,
|
||||
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||
):
|
||||
hidden_states = self.model.forward(
|
||||
input_ids, positions, forward_batch, pp_proxy_tensors, input_embeds
|
||||
)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
|
||||
return self.mamba_cache.copy_inputs_before_cuda_graphs(input_buffers, **kwargs)
|
||||
|
||||
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
|
||||
return self.mamba_cache.get_seqlen_agnostic_capture_inputs(batch_size)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
updated_weights = []
|
||||
for name, loaded_weight in weights:
|
||||
for prefix, new_key in self.remap_prefix.items():
|
||||
if name.startswith(prefix):
|
||||
name = name.replace(prefix, new_key)
|
||||
for substr, new_key in self.remap_substr.items():
|
||||
if substr in name:
|
||||
name = name.replace(substr, new_key)
|
||||
updated_weights.append((name, loaded_weight))
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in updated_weights:
|
||||
if "scale" in name:
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
if name in params_dict.keys():
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
else:
|
||||
logger.warning(f"Parameter {name} not found in params_dict")
|
||||
|
||||
|
||||
EntryClass = [NemotronHForCausalLM]
|
||||
@@ -866,7 +866,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]
|
||||
|
||||
# QQ: can be optimized
|
||||
if self.target_worker.model_runner.is_hybrid_gdn:
|
||||
if self.target_worker.model_runner.hybrid_gdn_config is not None:
|
||||
# res.draft_input.accept_length is on GPU but may be empty for last verify?
|
||||
accepted_length = (
|
||||
torch.tensor(
|
||||
|
||||
@@ -518,6 +518,24 @@ def make_layers(
|
||||
return modules, start_layer, end_layer
|
||||
|
||||
|
||||
def make_layers_non_pp(
|
||||
num_hidden_layers: int,
|
||||
layer_fn: LayerFn,
|
||||
prefix: str = "",
|
||||
) -> torch.nn.ModuleList:
|
||||
from sglang.srt.offloader import get_offloader
|
||||
|
||||
layers = torch.nn.ModuleList(
|
||||
get_offloader().wrap_modules(
|
||||
(
|
||||
layer_fn(idx=idx, prefix=add_prefix(idx, prefix))
|
||||
for idx in range(num_hidden_layers)
|
||||
)
|
||||
)
|
||||
)
|
||||
return layers
|
||||
|
||||
|
||||
cmo_stream = None
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ from sglang.srt.configs import (
|
||||
KimiVLConfig,
|
||||
LongcatFlashConfig,
|
||||
MultiModalityConfig,
|
||||
NemotronHConfig,
|
||||
Qwen3NextConfig,
|
||||
Step3VLConfig,
|
||||
)
|
||||
@@ -66,6 +67,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
FalconH1Config.model_type: FalconH1Config,
|
||||
DotsVLMConfig.model_type: DotsVLMConfig,
|
||||
DotsOCRConfig.model_type: DotsOCRConfig,
|
||||
NemotronHConfig.model_type: NemotronHConfig,
|
||||
}
|
||||
|
||||
for name, cls in _CONFIG_REGISTRY.items():
|
||||
|
||||
Reference in New Issue
Block a user