model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
python/sglang/srt/configs/__init__.py
@@ -9,6 +9,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
from sglang.srt.configs.longcat_flash import LongcatFlashConfig
from sglang.srt.configs.nemotron_h import NemotronHConfig
from sglang.srt.configs.qwen3_next import Qwen3NextConfig
from sglang.srt.configs.step3_vl import (
    Step3TextConfig,
@@ -32,4 +33,5 @@ __all__ = [
    "DotsVLMConfig",
    "DotsOCRConfig",
    "FalconH1Config",
    "NemotronHConfig",
]
python/sglang/srt/configs/falcon_h1.py
@@ -15,16 +15,12 @@
"""Falcon-H1 model configuration"""

import enum
import os

import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

from sglang.srt.distributed.utils import divide
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import (
    get_attention_tp_size,
    get_tensor_model_parallel_world_size,
@@ -214,7 +210,7 @@ class FalconH1Config(PretrainedConfig):
        self.rope_scaling = None
        self.rope_scaling = rope_scaling
        self.projectors_bias = projectors_bias
        mamba_intermediate = (
        self.mamba_intermediate = mamba_intermediate = (
            mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
        )

@@ -294,18 +290,6 @@ class FalconH1Config(PretrainedConfig):
    def layers_block_type(self):
        return ["falcon_h1" for i in range(self.num_hidden_layers)]

    @property
    def mamba_cache_per_req(self):
        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
            self.hybrid_gdn_params
        )
        mamba_layers_len = len(mamba_layers)

        return (
            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
        ) * mamba_layers_len

    @property
    def full_attention_layer_ids(self):
        # For Falcon-H1, we do have attention on all layers
@@ -317,44 +301,14 @@ class FalconH1Config(PretrainedConfig):
        return range(self.num_hidden_layers)

    @property
    def hybrid_gdn_params(self):
        world_size = get_tensor_model_parallel_world_size()

        n_groups = self.mamba_n_groups
        if self.mamba_n_groups % world_size != 0:
            # - for TP we shard conv_dim by sharding on n_groups,
            # - but if n_groups cannot divide tp_size, we need to
            #   extend some extra groups
            extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
                self.mamba_n_groups, world_size
            )
            n_groups += extra_groups

        conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state

        conv_state_shape = (
            divide(conv_dim, world_size),
            self.mamba_d_conv - 1,
        )

        # we TP-ize on the heads dimension
        temporal_state_shape = (
            self.mamba_d_state,
            self.mamba_d_head,
            divide(self.mamba_n_heads, world_size),
        )
        conv_dtype = torch.bfloat16
        dtype_map = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
        }
        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
        mamba_layers = self.linear_layer_ids

        return (
            conv_state_shape,
            temporal_state_shape,
            conv_dtype,
            ssm_dtype,
            mamba_layers,
    def mamba2_cache_params(self):
        shape = Mamba2StateShape.create(
            tp_world_size=get_tensor_model_parallel_world_size(),
            intermediate_size=self.mamba_intermediate,
            n_groups=self.mamba_n_groups,
            num_heads=self.mamba_n_heads,
            head_dim=self.mamba_d_head,
            state_size=self.mamba_d_state,
            conv_kernel=self.mamba_d_conv,
        )
        return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
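For orientation, the net effect of this FalconH1Config change is that the hand-rolled hybrid_gdn_params tuple and the mamba_cache_per_req arithmetic move into the shared Mamba2CacheParams object. A hypothetical consumer-side sketch (the call site and variable names are illustrative, not taken from this diff):

    # Assuming `cfg` is a loaded FalconH1Config; names below are illustrative.
    params = cfg.mamba2_cache_params            # Mamba2CacheParams(shape=..., dtype=..., layers=[...])
    bytes_per_req = params.mamba_cache_per_req  # conv + temporal state bytes, summed over Mamba layers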
python/sglang/srt/configs/mamba_utils.py (new file)
@@ -0,0 +1,117 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""

import os
from dataclasses import dataclass, field

import numpy as np
import torch

from sglang.srt.distributed.utils import divide


def extra_groups_for_head_shards(ngroups: int, tp_size: int):
    """Compute the increase in group numbers to account for
    replication in order to accompany the head shards."""

    # in the case ngroups % tp_size == 0, this will be zero
    if ngroups % tp_size == 0:
        return 0

    # for n_groups == 1, this is exactly tp_size - n_groups
    return tp_size - ngroups


@dataclass(kw_only=True, frozen=True)
class Mamba2StateShape:
    conv: tuple[int, int]
    temporal: tuple[int, int, int]

    intermediate_size: int
    conv_dim: int
    ssm_state_size: int
    num_heads: int
    head_dim: int
    state_size: int
    conv_kernel: int

    @staticmethod
    def create(
        *,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> "Mamba2StateShape":
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head are sharded along with it
        if n_groups % tp_world_size != 0:
            extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)
            n_groups += extra_groups
        # heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * n_groups * state_size

        # contiguous along 'dim' axis
        conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return Mamba2StateShape(
            conv=conv_state_shape,
            temporal=temporal_state_shape,
            intermediate_size=intermediate_size,
            conv_dim=conv_dim,
            ssm_state_size=state_size,
            num_heads=num_heads,
            head_dim=head_dim,
            state_size=state_size,
            conv_kernel=conv_kernel,
        )


@dataclass(kw_only=True, frozen=True)
class Mamba2StateDType:
    conv: torch.dtype
    temporal: torch.dtype


CONV_DTYPE = torch.bfloat16


def mamba2_state_dtype() -> Mamba2StateDType:
    dtype_map = {
        "float32": torch.float32,
        "bfloat16": torch.bfloat16,
    }
    ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
    return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)


@dataclass(kw_only=True, frozen=True)
class Mamba2CacheParams:
    shape: Mamba2StateShape
    dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
    layers: list[int]

    @property
    def mamba_cache_per_req(self) -> int:
        return (
            int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
            + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
        ) * len(self.layers)
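A minimal, self-contained sketch of how these helpers combine; the numeric values, TP degree, and layer indices below are illustrative only, not taken from any shipped config:

    import os

    from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape

    # The default dtype factory reads this env var; "bfloat16" is just an example value.
    os.environ.setdefault("SGLANG_MAMBA_SSM_DTYPE", "bfloat16")

    shape = Mamba2StateShape.create(
        tp_world_size=2,         # illustrative TP degree
        intermediate_size=8192,  # num_heads * head_dim in this sketch
        n_groups=8,
        num_heads=128,
        head_dim=64,
        state_size=128,
        conv_kernel=4,
    )
    # conv     = (divide(8192 + 2 * 8 * 128, 2), 4 - 1) = (5120, 3)
    # temporal = (divide(128, 2), 64, 128)              = (64, 64, 128)

    params = Mamba2CacheParams(shape=shape, layers=[0, 2, 4])  # three hypothetical Mamba layers
    # (5120 * 3 + 64 * 64 * 128) elements * 2 bytes (bf16) * 3 layers
    print(params.mamba_cache_per_req)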
python/sglang/srt/configs/nemotron_h.py (new file)
@@ -0,0 +1,286 @@
# Copyright 2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py

"""NemotronH model configuration"""

import regex as re
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.layers.dp_attention import get_attention_tp_size

logger = logging.get_logger(__name__)

MAMBA = "M"
ATTENTION = "*"
MLP = "-"


class NemotronHConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a
    [`NemotronHModel`]. It is used to instantiate a NemotronH model according
    to the specified arguments, defining the model architecture. Instantiating
    a configuration with the defaults will yield a similar configuration to
    that of the NemotronH-v0.1 model.

    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the NemotronH model. Defines the number of
            different tokens that can be represented by the `inputs_ids`
            passed when calling [`NemotronHModel`]
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be
            tied. Note that this is only relevant if the model has an output
            word embedding layer.
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 21504):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 52):
            Number of hidden layers in the Transformer encoder.
        hybrid_override_pattern (`str`, *optional*, defaults to
            `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
            The pattern of the hybrid model. The pattern is a string of
            characters where each character represents
            M: Mamba2, *: Attention, -: MLP
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the
            Transformer encoder.
        attention_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each attention head.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to
            implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use
            Multi Head Attention (MHA), if `num_key_value_heads=1` the model
            will use Multi Query Attention (MQA) otherwise GQA is used.
        mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
            The non-linear activation function in the MLP layers.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in attention layers.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in MLP layers.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for
            initializing all weight matrices.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        residual_in_fp32 (`bool`, *optional*, defaults to `False`):
            Whether or not residuals should be in `float32`. If set to `False`
            residuals will keep the same `dtype` as the rest of the model.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values
            attentions (not used by all models). Only relevant if
            `config.is_decoder=True`.
        num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`,
            all logits will be calculated. If an integer value, only last
            `num_logits_to_keep` logits will be calculated.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        sliding_window (`int`, *optional*, defaults to None):
            Sliding window attention window size.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used
            with.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the hidden states.
        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use the fast mamba kernels.
            These are available only if `mamba-ssm` and `causal-conv1d`
            are installed, and the mamba modules are running on a CUDA device.
        ssm_state_size (`int`, *optional*, defaults to 128):
            The dimension of the mamba state space latents.
        mamba_num_heads (`int`, *optional*, defaults to 128):
            Number of heads in Mamba layers.
        mamba_n_groups (`int`, *optional*, defaults to 8):
            Number of groups in Mamba layers.
        mamba_head_dim (`int`, *optional*, defaults to 64):
            Dimension of each Mamba head.
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel.
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor used to determine the mamba intermediate size.
        mamba_hidden_act (`str`, *optional*, defaults to "silu"):
            The non-linear activation function in the Mamba layers.
        mamba_dt_min (`float`, *optional*, defaults to 0.001):
            Minimum value for the time step in Mamba.
        mamba_dt_max (`float`, *optional*, defaults to 0.1):
            Maximum value for the time step in Mamba.
        mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
            Limits for the time step in Mamba.
        mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
            Floor value for time step initialization in Mamba.
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the convolution layer of the mamba mixer
            block.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the input and output projections of the
            mamba mixer block.
        mamba_chunk_size (`int`, *optional*, defaults to 256):
            Size of chunks for Mamba processing.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
            Whether to rescale the pre-normalization residual connections.
    """

    model_type = "nemotron_h"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=131072,
        tie_word_embeddings=False,
        hidden_size=4096,
        intermediate_size=21504,
        num_hidden_layers=52,
        hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
        num_attention_heads=32,
        head_dim=128,
        num_key_value_heads=8,  # nemo: num_query_groups
        mlp_hidden_act="relu2",
        attention_bias=False,
        mlp_bias=False,
        use_bias=False,
        initializer_range=0.02,  # nemo: init_method_std
        layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
        residual_in_fp32=False,  # Megatron Core default value
        use_cache=True,
        num_logits_to_keep=1,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        sliding_window=None,
        max_position_embeddings=4096,
        attention_dropout=0.0,
        hidden_dropout=0.0,  # * ADDED
        use_mamba_kernels=True,
        ssm_state_size=128,  # mamba_state_size
        mamba_num_heads=128,
        mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
        mamba_head_dim=64,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_hidden_act="silu",
        mamba_dt_min=0.001,
        mamba_dt_max=0.1,
        mamba_dt_limit=(0.0, float("inf")),
        mamba_dt_init_floor=1e-4,
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_chunk_size=256,
        rescale_prenorm_residual=True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.hybrid_override_pattern = hybrid_override_pattern
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout

        # Validate hybrid_override_pattern
        # M: Mamba2, *: Attention, -: MLP
        assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
            "hybrid_override_pattern must have same length as " "num_hidden_layers"
        )
        assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
            "hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
        )

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_hidden_act = mlp_hidden_act
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias
        self.use_bias = use_bias
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.residual_in_fp32 = residual_in_fp32

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep

        self.use_mamba_kernels = use_mamba_kernels
        self.mamba_n_groups = mamba_n_groups
        self.mamba_head_dim = mamba_head_dim
        self.ssm_state_size = ssm_state_size
        self.mamba_num_heads = mamba_num_heads
        self.conv_kernel = mamba_d_conv
        self.expand = mamba_expand
        self.mamba_hidden_act = mamba_hidden_act
        self.time_step_min = mamba_dt_min
        self.time_step_max = mamba_dt_max
        self.time_step_limit = mamba_dt_limit
        self.time_step_floor = mamba_dt_init_floor
        self.use_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.mamba_chunk_size = mamba_chunk_size
        self.rescale_prenorm_residual = rescale_prenorm_residual

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def mamba_layer_ids(self):
        return [
            i
            for i in range(self.num_hidden_layers)
            if self.hybrid_override_pattern[i] == MAMBA
        ]

    @property
    def full_attention_layer_ids(self):
        return [
            i
            for i in range(self.num_hidden_layers)
            if self.hybrid_override_pattern[i] == ATTENTION
        ]

    @property
    def mamba2_cache_params(self) -> Mamba2CacheParams:
        shape = Mamba2StateShape.create(
            tp_world_size=get_attention_tp_size(),
            intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
            n_groups=self.n_groups,
            num_heads=self.mamba_num_heads,
            head_dim=self.mamba_head_dim,
            state_size=self.ssm_state_size,
            conv_kernel=self.conv_kernel,
        )

        return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
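A quick illustration of how the hybrid_override_pattern drives the layer-id properties above; the 4-character pattern here is a toy example, not a real checkpoint layout:

    MAMBA, ATTENTION, MLP = "M", "*", "-"

    pattern = "M-M*"  # toy pattern: Mamba2, MLP, Mamba2, Attention
    mamba_layer_ids = [i for i, c in enumerate(pattern) if c == MAMBA]               # [0, 2]
    full_attention_layer_ids = [i for i, c in enumerate(pattern) if c == ATTENTION]  # [3]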
python/sglang/srt/configs/qwen3_next.py
@@ -15,14 +15,12 @@
"""Qwen3Hybrid model configuration"""

import enum
import os

import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size

@@ -282,45 +280,15 @@ class Qwen3NextConfig(PretrainedConfig):
        ]

    @property
    def hybrid_gdn_params(self):
        world_size = get_attention_tp_size()
        conv_dim = (
            self.linear_key_head_dim * self.linear_num_key_heads * 2
            + self.linear_value_head_dim * self.linear_num_value_heads
        )
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.linear_conv_kernel_dim - 1,
    def mamba2_cache_params(self) -> Mamba2CacheParams:
        shape = Mamba2StateShape.create(
            tp_world_size=get_attention_tp_size(),
            intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
            n_groups=self.linear_num_key_heads,
            num_heads=self.linear_num_value_heads,
            head_dim=self.linear_value_head_dim,
            state_size=self.linear_key_head_dim,
            conv_kernel=self.linear_conv_kernel_dim,
        )

        temporal_state_shape = (
            divide(self.linear_num_value_heads, world_size),
            self.linear_key_head_dim,
            self.linear_value_head_dim,
        )
        conv_dtype = torch.bfloat16
        dtype_map = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
        }
        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
        mamba_layers = self.linear_layer_ids
        return (
            conv_state_shape,
            temporal_state_shape,
            conv_dtype,
            ssm_dtype,
            mamba_layers,
        )

    @property
    def mamba_cache_per_req(self):
        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
            self.hybrid_gdn_params
        )
        mamba_layers_len = len(mamba_layers)

        return (
            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
        ) * mamba_layers_len
        return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
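A small sanity check that the refactored Qwen3NextConfig.mamba2_cache_params preserves the conv-state width computed by the removed hybrid_gdn_params (ignoring the group-padding path when n_groups does not divide the TP size); the head counts and dimensions are illustrative stand-ins, not read from a real Qwen3Next config:

    linear_key_head_dim, linear_num_key_heads = 128, 16
    linear_value_head_dim, linear_num_value_heads = 128, 32

    # Old formula from the removed hybrid_gdn_params:
    old_conv_dim = (
        linear_key_head_dim * linear_num_key_heads * 2
        + linear_value_head_dim * linear_num_value_heads
    )
    # New path: Mamba2StateShape.create computes intermediate_size + 2 * n_groups * state_size
    new_conv_dim = (
        linear_value_head_dim * linear_num_value_heads
        + 2 * linear_num_key_heads * linear_key_head_dim
    )
    assert old_conv_dim == new_conv_dim == 8192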