From 30c6e1f56967957615f4402b17e1ce6e15d63785 Mon Sep 17 00:00:00 2001 From: Yi Zhang <1109276519@qq.com> Date: Thu, 11 Sep 2025 19:11:49 +0800 Subject: [PATCH] Qwen3-Next support (#10233) Co-authored-by: cao1zhg <114661107+cao1zhg@users.noreply.github.com> Co-authored-by: ispobock Co-authored-by: Binyao Jiang Co-authored-by: hebiao064 Co-authored-by: Lifu Huang Co-authored-by: qingquansong Co-authored-by: Yaoyao Ding Co-authored-by: Ke Bao Co-authored-by: Minglei Zhu --- python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/model_config.py | 3 + python/sglang/srt/configs/qwen3_next.py | 326 +++++ python/sglang/srt/hf_transformers_utils.py | 2 + .../attention/hybrid_linear_attn_backend.py | 581 +++++++++ .../layers/attention/mamba/causal_conv1d.py | 128 ++ .../srt/layers/attention/mamba/mamba.py | 64 + ...128,device_name=NVIDIA_H100_80GB_HBM3.json | 146 +++ ...=64,device_name=NVIDIA_H100_80GB_HBM3.json | 146 +++ python/sglang/srt/managers/schedule_batch.py | 13 +- python/sglang/srt/managers/scheduler.py | 7 +- python/sglang/srt/mem_cache/memory_pool.py | 280 +++++ .../sglang/srt/model_executor/model_runner.py | 96 ++ .../sglang/srt/model_loader/weight_utils.py | 3 +- python/sglang/srt/models/qwen3_next.py | 1072 +++++++++++++++++ python/sglang/srt/models/qwen3_next_mtp.py | 117 ++ python/sglang/srt/server_args.py | 22 +- .../eagle_target_verify_cuda_graph_runner.py | 195 +++ python/sglang/srt/speculative/eagle_worker.py | 29 + 19 files changed, 3224 insertions(+), 8 deletions(-) create mode 100644 python/sglang/srt/configs/qwen3_next.py create mode 100644 python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py create mode 100644 python/sglang/srt/layers/attention/mamba/causal_conv1d.py create mode 100644 python/sglang/srt/layers/attention/mamba/mamba.py create mode 100644 python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 
python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 python/sglang/srt/models/qwen3_next.py create mode 100644 python/sglang/srt/models/qwen3_next_mtp.py create mode 100644 python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 24fba32b3..ef880c911 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -6,6 +6,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig from sglang.srt.configs.longcat_flash import LongcatFlashConfig +from sglang.srt.configs.qwen3_next import Qwen3NextConfig from sglang.srt.configs.step3_vl import ( Step3TextConfig, Step3VisionEncoderConfig, @@ -24,4 +25,5 @@ __all__ = [ "Step3VLConfig", "Step3TextConfig", "Step3VisionEncoderConfig", + "Qwen3NextConfig", ] diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index fb8c2501b..f16442e4d 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -147,6 +147,9 @@ class ModelConfig: ): self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP" + if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM": + self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP" + # Check model type self.is_generation = is_generation_model( self.hf_config.architectures, is_embedding diff --git a/python/sglang/srt/configs/qwen3_next.py b/python/sglang/srt/configs/qwen3_next.py new file mode 100644 index 000000000..099d14d41 --- /dev/null +++ b/python/sglang/srt/configs/qwen3_next.py @@ -0,0 +1,326 @@ +# coding=utf-8 +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Qwen3Hybrid model configuration""" + +import enum +import os + +import numpy as np +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +from sglang.srt.distributed.utils import divide +from sglang.srt.layers.dp_attention import get_attention_tp_size + +logger = logging.get_logger(__name__) + + +# NOTE: HybridLayerType +class HybridLayerType(enum.Enum): + full_attention = "attention" + swa_attention = "swa_attention" + linear_attention = "linear_attention" + mamba2 = "mamba" + + +class Qwen3NextConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a + Qwen3-Next model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of + Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct). + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 151936): + Vocabulary size of the model. 
Defines the number of different tokens that can be represented by the + `inputs_ids`. + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 5632): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 48): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 16): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 32768): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type + and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value + accordingly. + Expected contents: + `rope_type` (`str`): + The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope', + 'llama3'], with 'default' being the original RoPE implementation. + `factor` (`float`, *optional*): + Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In + most scaling types, a `factor` of x will enable the model to handle sequences of length x * + original maximum pre-trained length. + `original_max_position_embeddings` (`int`, *optional*): + Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during + pretraining. + `attention_factor` (`float`, *optional*): + Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention + computation. If unspecified, it defaults to value recommended by the implementation, using the + `factor` field to infer the suggested value. + `beta_fast` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear + ramp function. If unspecified, it defaults to 32. + `beta_slow` (`float`, *optional*): + Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear + ramp function. If unspecified, it defaults to 1. + `short_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to short contexts (< + `original_max_position_embeddings`). 
Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `long_factor` (`List[float]`, *optional*): + Only used with 'longrope'. The scaling factor to be applied to long contexts (< + `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden + size divided by the number of attention heads divided by 2 + `low_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE + `high_freq_factor` (`float`, *optional*): + Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + partial_rotary_factor (`float`, *optional*, defaults to 0.25): + Percentage of the query and keys which will have rotary embedding. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + head_dim (`int`, *optional*, defaults to 256): + Projection weights dimension in multi-head attention. + linear_conv_kernel_dim (`int`, *optional*, defaults to 4): + Kernel size of the convolution used in linear attention layers. + linear_key_head_dim (`int`, *optional*, defaults to 128): + Dimension of each key head in linear attention. + linear_value_head_dim (`int`, *optional*, defaults to 128): + Dimension of each value head in linear attention. + linear_num_key_heads (`int`, *optional*, defaults to 16): + Number of key heads used in linear attention layers. + linear_num_value_heads (`int`, *optional*, defaults to 32): + Number of value heads used in linear attention layers. + decoder_sparse_step (`int`, *optional*, defaults to 1): + The frequency of the MoE layer. + moe_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the routed expert. 
+ shared_expert_intermediate_size (`int`, *optional*, defaults to 512): + Intermediate size of the shared expert. + num_experts_per_tok (`int`, *optional*, defaults to 10): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 512): + Number of routed experts. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss, including load balancing loss and router z-loss. + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + mlp_only_layers (`list[int]`, *optional*, defaults to `[]`): + Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock + The list contains layer index, from 0 to num_layers-1 if we have num_layers layers + If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity. + layer_types (`list[str]`, *optional*, defaults to None): + Types of each layer (attention or linear). 
+ + ```python + >>> from transformers import Qwen3NextModel, Qwen3NextConfig + + >>> # Initializing a Qwen3Next style configuration + >>> configuration = Qwen3NextConfig() + + >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration + >>> model = Qwen3NextModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "qwen3_next" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=151936, + hidden_size=2048, + intermediate_size=5632, + num_hidden_layers=48, + num_attention_heads=16, + num_key_value_heads=2, + hidden_act="silu", + max_position_embeddings=32768, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=0.25, + attention_bias=False, + attention_dropout=0.0, + head_dim=256, + linear_conv_kernel_dim=4, + linear_key_head_dim=128, + linear_value_head_dim=128, + linear_num_key_heads=16, + linear_num_value_heads=32, + decoder_sparse_step=1, + moe_intermediate_size=512, + shared_expert_intermediate_size=512, + num_experts_per_tok=10, + num_experts=512, + norm_topk_prob=True, + output_router_logits=False, + router_aux_loss_coef=0.001, + mlp_only_layers=[], + layer_types=None, + **kwargs, + ): + super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + 
self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.head_dim = head_dim + rope_config_validation(self) + + # linear attention (gdn now part) + self.linear_conv_kernel_dim = linear_conv_kernel_dim + self.linear_key_head_dim = linear_key_head_dim + self.linear_value_head_dim = linear_value_head_dim + self.linear_num_key_heads = linear_num_key_heads + self.linear_num_value_heads = linear_num_value_heads + + # MoE arguments + self.decoder_sparse_step = decoder_sparse_step + self.moe_intermediate_size = moe_intermediate_size + self.shared_expert_intermediate_size = shared_expert_intermediate_size + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.norm_topk_prob = norm_topk_prob + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.mlp_only_layers = mlp_only_layers + + @property + def layers_block_type(self): + layer_type_list = [] + + for l in range(self.num_hidden_layers): + if (l + 1) % self.full_attention_interval == 0: + layer_type_list.append(HybridLayerType.full_attention.value) + else: + layer_type_list.append(HybridLayerType.linear_attention.value) + + return layer_type_list + + @property + def linear_layer_ids(self): + return [ + i + for i, type_value in enumerate(self.layers_block_type) + if type_value == HybridLayerType.linear_attention.value + ] + + @property + def full_attention_layer_ids(self): + return [ + i + for i, type_value in enumerate(self.layers_block_type) + if type_value == HybridLayerType.full_attention.value + ] + + @property + def hybrid_gdn_params(self): + world_size = get_attention_tp_size() + conv_dim = ( + self.linear_key_head_dim * self.linear_num_key_heads * 2 + + self.linear_value_head_dim * self.linear_num_value_heads + ) + conv_state_shape = ( + divide(conv_dim, world_size), + self.linear_conv_kernel_dim - 1, + ) + + temporal_state_shape = ( + divide(self.linear_num_value_heads, world_size), + 
self.linear_key_head_dim, + self.linear_value_head_dim, + ) + conv_dtype = torch.bfloat16 + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] + mamba_layers = self.linear_layer_ids + return ( + conv_state_shape, + temporal_state_shape, + conv_dtype, + ssm_dtype, + mamba_layers, + ) + + @property + def mamba_cache_per_req(self): + conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = ( + self.hybrid_gdn_params + ) + mamba_layers_len = len(mamba_layers) + + return ( + int(np.prod(conv_state_shape)) * conv_dtype.itemsize + + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize + ) * mamba_layers_len diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 2f500ae79..d7dcf8904 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -42,6 +42,7 @@ from sglang.srt.configs import ( KimiVLConfig, LongcatFlashConfig, MultiModalityConfig, + Qwen3NextConfig, Step3VLConfig, ) from sglang.srt.configs.internvl import InternVLChatConfig @@ -58,6 +59,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { InternVLChatConfig.model_type: InternVLChatConfig, Step3VLConfig.model_type: Step3VLConfig, LongcatFlashConfig.model_type: LongcatFlashConfig, + Qwen3NextConfig.model_type: Qwen3NextConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py new file mode 100644 index 000000000..9730df726 --- /dev/null +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -0,0 +1,581 @@ +from dataclasses import astuple, dataclass +from functools import lru_cache +from typing import Optional, Union + +import torch +import torch.nn.functional as F + +from sglang.srt.layers.attention.base_attn_backend import AttentionBackend 
+from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule +from sglang.srt.layers.attention.fla.fused_recurrent import ( + fused_recurrent_gated_delta_rule_update, +) +from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( + fused_sigmoid_gating_delta_rule_update, +) +from sglang.srt.layers.attention.mamba.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating +from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + + +@dataclass +class ForwardMetadata: + query_start_loc: Optional[torch.Tensor] + mamba_cache_indices: torch.Tensor + + +class MambaAttnBackend(AttentionBackend): + """Attention backend using Mamba kernel.""" + + def __init__(self, model_runner: ModelRunner): + super().__init__() + self.pad_slot_id = -1 # Default pad slot id + self.device = model_runner.device + self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool + self.forward_metadata: ForwardMetadata = None + self.state_indices_list = [] + self.query_start_loc_list = [] + + @classmethod + @lru_cache(maxsize=128) + def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor: + """Cache torch.arange tensors for common batch sizes to avoid repeated allocation.""" + device = torch.device(device_str) + return torch.arange(0, bs + 1, dtype=torch.int32, device=device) + + def init_forward_metadata(self, forward_batch: ForwardBatch): + bs = forward_batch.batch_size + if forward_batch.forward_mode.is_decode_or_idle(): + query_start_loc = self._get_cached_arange(bs, str(self.device)) + elif forward_batch.forward_mode.is_extend(): + if 
forward_batch.forward_mode.is_target_verify(): + query_start_loc = torch.arange( + 0, + forward_batch.input_ids.shape[0] + 1, + step=forward_batch.spec_info.draft_token_num, + dtype=torch.int32, + device=forward_batch.input_ids.device, + ) + else: + query_start_loc = torch.empty( + (bs + 1,), dtype=torch.int32, device=self.device + ) + query_start_loc[:bs] = forward_batch.extend_start_loc + query_start_loc[bs] = ( + forward_batch.extend_start_loc[-1] + + forward_batch.extend_seq_lens[-1] + ) + else: + raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}") + mamba_cache_indices = self.req_to_token_pool.get_mamba_indices( + forward_batch.req_pool_indices + ) + self.forward_metadata = ForwardMetadata( + query_start_loc=query_start_loc, + mamba_cache_indices=mamba_cache_indices, + ) + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + for i in range(max_bs): + self.state_indices_list.append( + torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda") + ) + self.query_start_loc_list.append( + torch.empty((i + 2,), dtype=torch.int32, device="cuda") + ) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + if forward_mode.is_decode_or_idle(): + self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda")) + elif forward_mode.is_target_verify(): + self.query_start_loc_list[bs - 1].copy_( + torch.arange( + 0, + bs * spec_info.draft_token_num + 1, + step=spec_info.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + ) + else: + raise ValueError(f"Invalid forward mode: {forward_mode=}") + mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) + self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) + self.forward_metadata = 
ForwardMetadata( + query_start_loc=self.query_start_loc_list[bs - 1], + mamba_cache_indices=self.state_indices_list[bs - 1], + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + num_padding = torch.count_nonzero( + seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value() + ) + # Make sure forward metadata is correctly handled for padding reqs + req_pool_indices[bs - num_padding :] = 0 + mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices) + mamba_indices[bs - num_padding :] = -1 + self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices) + if forward_mode.is_decode_or_idle(): + self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda")) + if num_padding > 0: + self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding + elif forward_mode.is_target_verify(): + self.query_start_loc_list[bs - 1].copy_( + torch.arange( + 0, + bs * spec_info.draft_token_num + 1, + step=spec_info.draft_token_num, + dtype=torch.int32, + device="cuda", + ) + ) + if num_padding > 0: + self.query_start_loc_list[bs - 1][bs - num_padding :] = ( + bs - num_padding + ) * spec_info.draft_token_num + else: + raise ValueError(f"Invalid forward mode: {forward_mode=}") + + self.forward_metadata = ForwardMetadata( + query_start_loc=self.query_start_loc_list[bs - 1], + mamba_cache_indices=self.state_indices_list[bs - 1], + ) + + def get_cuda_graph_seq_len_fill_value(self): + return 1 # Mamba attn does not use seq lens to index kv cache + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ): + mixed_qkv = kwargs["mixed_qkv"] 
+ conv_weights = kwargs["conv_weights"] + bias = kwargs["bias"] + activation = kwargs["activation"] + key_dim = kwargs["key_dim"] + value_dim = kwargs["value_dim"] + attn_tp_size = kwargs["attention_tp_size"] + head_k_dim = kwargs["head_k_dim"] + head_v_dim = kwargs["head_v_dim"] + a = kwargs["a"] + b = kwargs["b"] + A_log = kwargs["A_log"] + dt_bias = kwargs["dt_bias"] + layer_id = kwargs["layer_id"] + + conv_states, ssm_states = self.req_to_token_pool.get_mamba_params(layer_id) + query_start_loc = self.forward_metadata.query_start_loc + cache_indices = self.forward_metadata.mamba_cache_indices + + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + conv_weights, + bias, + activation, + conv_state_indices=cache_indices, + ) + + query, key, value = torch.split( + mixed_qkv, + [ + key_dim // attn_tp_size, + key_dim // attn_tp_size, + value_dim // attn_tp_size, + ], + dim=-1, + ) + # Reshape from [l, h*d] to [1, l, h, d] + seq_len = query.shape[0] + num_heads = query.shape[1] // head_k_dim + query = query.view(1, seq_len, num_heads, head_k_dim) + key = key.view(1, seq_len, num_heads, head_k_dim) + value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim) + + core_attn_out = fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + dt_bias=dt_bias, + q=query, + k=key, + v=value, + a=a, + b=b, + initial_state_source=ssm_states, + initial_state_indices=cache_indices, + cu_seqlens=query_start_loc, + use_qk_l2norm_in_kernel=True, + softplus_beta=1.0, + softplus_threshold=20.0, + ) + + return core_attn_out + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ): + mixed_qkv = kwargs["mixed_qkv"] + conv_weights = kwargs["conv_weights"] + bias = kwargs["bias"] + activation = kwargs["activation"] + key_dim = kwargs["key_dim"] + value_dim = kwargs["value_dim"] + attn_tp_size = kwargs["attention_tp_size"] + head_k_dim = 
kwargs["head_k_dim"] + head_v_dim = kwargs["head_v_dim"] + a = kwargs["a"] + b = kwargs["b"] + A_log = kwargs["A_log"] + dt_bias = kwargs["dt_bias"] + layer_id = kwargs["layer_id"] + seq_len = kwargs["seq_len"] + + is_target_verify = forward_batch.forward_mode.is_target_verify() + + query_start_loc = self.forward_metadata.query_start_loc + cache_indices = self.forward_metadata.mamba_cache_indices + + if is_target_verify: + ( + conv_states, + ssm_states, + mixed_qkv_cache, + intermediate_state_cache, + ) = self.req_to_token_pool.get_mamba_params(layer_id) + mixed_qkv_cache[cache_indices] = mixed_qkv.view( + (-1,) + mixed_qkv_cache.shape[1:] + ).clone() + has_initial_states = torch.ones( + seq_len // forward_batch.spec_info.draft_token_num, + dtype=torch.bool, + device=forward_batch.input_ids.device, + ) + conv_states_to_use = conv_states.clone() + else: + conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params( + layer_id + ) + has_initial_states = forward_batch.extend_prefix_lens > 0 + conv_states_to_use = conv_states + mixed_qkv = causal_conv1d_fn( + mixed_qkv.transpose(0, 1), + conv_weights, + bias, + activation=activation, + conv_states=conv_states_to_use, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + ).transpose(0, 1)[:seq_len] + + key_split_dim = key_dim // attn_tp_size + value_split_dim = value_dim // attn_tp_size + + query, key, value = torch.split( + mixed_qkv, + [key_split_dim, key_split_dim, value_split_dim], + dim=-1, + ) + + actual_seq_len = query.shape[0] + num_heads = query.shape[1] // head_k_dim + num_value_heads = value.shape[1] // head_v_dim + + query = query.view(1, actual_seq_len, num_heads, head_k_dim) + key = key.view(1, actual_seq_len, num_heads, head_k_dim) + value = value.view(1, actual_seq_len, num_value_heads, head_v_dim) + + beta = b.sigmoid() + g = fused_gdn_gating(A_log, a, dt_bias) + + g = g.unsqueeze(0) + beta = beta.unsqueeze(0) + + if is_target_verify: + 
core_attn_out = fused_recurrent_gated_delta_rule_update( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state_source=ssm_states, + initial_state_indices=cache_indices, + cu_seqlens=query_start_loc, + use_qk_l2norm_in_kernel=True, + disable_state_update=True, + intermediate_states_buffer=intermediate_state_cache, + cache_steps=forward_batch.spec_info.draft_token_num, + ) + else: + recurrent_state = ssm_states[cache_indices] + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=True, + cu_seqlens=query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False) + ssm_states[cache_indices] = last_recurrent_state + + return core_attn_out + + +class HybridLinearAttnBackend(AttentionBackend): + """Support different backends for prefill and decode.""" + + def __init__( + self, + full_attn_backend: AttentionBackend, + linear_attn_backend: AttentionBackend, + full_attn_layers: list[int], + ): + self.full_attn_layers = full_attn_layers + self.attn_backend_list = [full_attn_backend, linear_attn_backend] + + def init_forward_metadata(self, forward_batch: ForwardBatch): + for attn_backend in self.attn_backend_list: + attn_backend.init_forward_metadata(forward_batch) + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int): + for attn_backend in self.attn_backend_list: + attn_backend.init_cuda_graph_state(max_bs, max_num_tokens) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + ): + for attn_backend in self.attn_backend_list: + attn_backend.init_forward_metadata_capture_cuda_graph( + bs, + num_tokens, + req_pool_indices, + 
seq_lens, + encoder_lens, + forward_mode, + spec_info, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + seq_lens_cpu: Optional[torch.Tensor], + ): + for attn_backend in self.attn_backend_list: + attn_backend.init_forward_metadata_replay_cuda_graph( + bs, + req_pool_indices, + seq_lens, + seq_lens_sum, + encoder_lens, + forward_mode, + spec_info, + seq_lens_cpu, + ) + + def get_cuda_graph_seq_len_fill_value(self): + return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value() + + def forward_decode( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ): + layer_id = layer.layer_id if layer else kwargs["layer_id"] + if layer_id in self.full_attn_layers: + return self.attn_backend_list[0].forward_decode( + q, k, v, layer, forward_batch, save_kv_cache, **kwargs + ) + return self.attn_backend_list[1].forward_decode( + q, k, v, layer, forward_batch, save_kv_cache, **kwargs + ) + + def forward_extend( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ): + layer_id = layer.layer_id if layer else kwargs["layer_id"] + if layer_id in self.full_attn_layers: + return self.attn_backend_list[0].forward_extend( + q, k, v, layer, forward_batch, save_kv_cache, **kwargs + ) + return self.attn_backend_list[1].forward_extend( + q, k, v, layer, forward_batch, save_kv_cache, **kwargs + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer: RadixAttention, + forward_batch: ForwardBatch, + save_kv_cache: bool = True, + **kwargs, + ): + """Run forward on an attention layer.""" + if 
forward_batch.forward_mode.is_idle(): + if layer is None: + return torch.empty_like(kwargs["z"]) + return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim) + elif forward_batch.forward_mode.is_decode(): + return self.forward_decode( + q, + k, + v, + layer, + forward_batch, + save_kv_cache=save_kv_cache, + **kwargs, + ) + else: + return self.forward_extend( + q, + k, + v, + layer, + forward_batch, + save_kv_cache=save_kv_cache, + **kwargs, + ) + + def update_mamba_state_after_mtp_verify(self, accepted_length, model): + request_number = accepted_length.shape[0] + # QQ: step = spec num_draft token num + num_draft_tokens = ( + self.attn_backend_list[1] + .req_to_token_pool.mamba_pool.mamba_cache[2] + .shape[2] + ) + query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype) + query_start_loc = torch.cat( + [ + torch.zeros( + 1, + dtype=query_start_loc.dtype, + device=query_start_loc.device, + ), + query_start_loc, + ] + ) + mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze( + 0 + ) < accepted_length.unsqueeze(1) + + state_indices_tensor = self.attn_backend_list[ + 1 + ].forward_metadata.mamba_cache_indices[:request_number] + + mamba_caches = self.attn_backend_list[ + 1 + ].req_to_token_pool.get_mamba_params_all_layers() + + conv_states, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches + + mixed_qkvs = mix_qkv_cache[:, state_indices_tensor][:, mask] + + mamba_map = self.attn_backend_list[1].req_to_token_pool.mamba_map + + has_initial_states = torch.ones( + request_number, dtype=torch.bool, device=accepted_length.device + ) + + # Batch SSM state updates (outside the loop for efficiency) + valid_mask = accepted_length > 0 + if intermediate_state_cache is not None: + last_steps = (accepted_length - 1).to(torch.int64) + valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) + + ssm_states[:, valid_state_indices, :] = intermediate_state_cache[ + :, valid_state_indices, last_steps + 
].to(ssm_states.dtype) + + # For loop conv state updates (can be optimized) + for i in range(len(model.model.layers)): + layer = model.model.layers[i] + if isinstance(layer, Qwen3HybridLinearDecoderLayer): + conv_weights = layer.linear_attn.conv1d.weight.view( + layer.linear_attn.conv1d.weight.size(0), + layer.linear_attn.conv1d.weight.size(2), + ) + + layer_id = mamba_map[i] + conv_state = conv_states[layer_id] + mixed_qkv = mixed_qkvs[layer_id] + + _ = causal_conv1d_fn( + mixed_qkv.transpose(0, 1), + conv_weights, + layer.linear_attn.conv1d.bias, + activation=layer.linear_attn.activation, + conv_states=conv_state, + has_initial_state=has_initial_states, + cache_indices=state_indices_tensor, + query_start_loc=query_start_loc, + ) diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d.py new file mode 100644 index 000000000..d004337ff --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d.py @@ -0,0 +1,128 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2024, Tri Dao. 
+# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py + +from typing import Optional + +import torch +from sgl_kernel import causal_conv1d_fwd +from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel + +PAD_SLOT_ID = -1 + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + causal_conv1d_fwd( + x, + 
weight, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_state, + activation in ["silu", "swish"], + pad_slot_id, + ) + return x + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
+ pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError( + f"activation must be None, silu, or swish, actual: {activation}" + ) + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + causal_conv1d_update_kernel( + x, + conv_state, + weight, + bias, + activation_val, + cache_seqlens, + conv_state_indices, + pad_slot_id, + ) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py new file mode 100644 index 000000000..045a04048 --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -0,0 +1,64 @@ +from typing import Callable, List, Tuple + +import torch + +LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None] + + +def mamba_v2_sharded_weight_loader( + shard_spec: List[Tuple[int, int, float]], + tp_size: int, + tp_rank: int, +) -> LoaderFunction: + """Create a weight loader for mamba v2. This ensures that the projections + are correctly sharded so that they can be split into x, B, C. It also + ensures the the all the groups corresponding to a head shard is placed + together with it. + """ + + def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: + + # - track boundary of (sharded) param, and loaded_weight, respectively + boundary, loaded_boundary = 0, 0 + + # - iterate over the shard specs + for full_dim, extra, duplicate_groups in shard_spec: + # - full dim is the model dim (before TP). + # - extra > 0, means there is expected overall increase + # of dimensions. This is so because of replication. 
+ # - ratio is used map the tp_rank to the actual shard + # rank. This is useful when there is replication of + # groups to accompany head shards. + + # - size of the loaded shard + shard_size = full_dim // tp_size + + # - compute the rank into the loaded shard. + # - if there is replication, different TP shards will + # take from the same rank. + # NOTE: currently we only support duplication + # in the case where num_groups == 1 + rank = 0 if duplicate_groups else tp_rank + + # - leftmost boundary index into loaded weight. + loaded_skip = rank * shard_size + loaded_start_idx = loaded_boundary + loaded_skip + + # - take these many dims from the loaded weight. + take = min(shard_size, full_dim - extra - loaded_skip) + + # - always shard on dim 0 + # - the ignore is for a mundane mypy error as it does not + # seem to handle slices well. + # https://github.com/python/mypy/issues/2410 + param.data[ + boundary : (boundary + take), ... # type: ignore[misc] + ] = loaded_weight[ + loaded_start_idx : (loaded_start_idx + take) # type: ignore[misc] + ] # type: ignore[misc] + + # move indexing boundaries + boundary += shard_size + loaded_boundary += full_dim - extra + + return loader diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..b8f35b62e --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, 
+ "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 
8, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..64861b390 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index f519224df..c0c0917ac 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -38,7 +38,7 @@ import threading from enum import Enum, auto from http import HTTPStatus from itertools import chain -from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch @@ -59,7 +59,7 @@ from sglang.srt.mem_cache.allocator import ( from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache -from 
sglang.srt.mem_cache.memory_pool import ReqToTokenPool +from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache from sglang.srt.metrics.collector import TimeStats from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode @@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def is_empty(self): return len(self.reqs) == 0 - def alloc_req_slots(self, num_reqs: int): - req_pool_indices = self.req_to_token_pool.alloc(num_reqs) + def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None): + if isinstance(self.req_to_token_pool, HybridReqToTokenPool): + req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs) + else: + req_pool_indices = self.req_to_token_pool.alloc(num_reqs) if req_pool_indices is None: raise RuntimeError( "alloc_req_slots runs out of memory. " @@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Allocate req slots bs = len(self.reqs) - req_pool_indices = self.alloc_req_slots(bs) + req_pool_indices = self.alloc_req_slots(bs, self.reqs) # Init tensors reqs = self.reqs diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 9e3af2eaa..5b80afcc1 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1540,7 +1540,12 @@ class Scheduler( chunked_req_to_exclude.add(self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True) # chunked request keeps its rid but will get a new req_pool_idx - self.req_to_token_pool.free(self.chunked_req.req_pool_idx) + if self.tp_worker.worker.model_runner.is_hybrid_gdn: + self.req_to_token_pool.free( + self.chunked_req.req_pool_idx, free_mamba_cache=False + ) + else: + self.req_to_token_pool.free(self.chunked_req.req_pool_idx) if self.last_batch and self.last_batch.forward_mode.is_extend(): if 
self.last_batch.chunked_req is not None: # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req. diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 175440a3f..6cc66ba1a 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -102,6 +102,204 @@ class ReqToTokenPool: self.free_slots = list(range(self.size)) +class MambaPool: + def __init__( + self, + size: int, + conv_dtype: torch.dtype, + ssm_dtype: torch.dtype, + num_mamba_layers: int, + conv_state_shape: Tuple[int, int], + temporal_state_shape: Tuple[int, int], + device: str, + speculative_num_draft_tokens: Optional[int] = None, + ): + conv_state = torch.zeros( + size=(num_mamba_layers, size + 1) + conv_state_shape, + dtype=conv_dtype, + device=device, + ) + temporal_state = torch.zeros( + size=(num_mamba_layers, size + 1) + temporal_state_shape, + dtype=ssm_dtype, + device=device, + ) + if speculative_num_draft_tokens is not None: + mixed_qkv_cache = torch.empty( + size=( + num_mamba_layers, + size + 1, + speculative_num_draft_tokens, + conv_state_shape[0], + ), + dtype=conv_dtype, + device="cuda", + ) + # Cache intermediate SSM states per draft token during target verify + # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] + intermediate_ssm_state_cache = torch.empty( + size=( + num_mamba_layers, + size + 1, + speculative_num_draft_tokens, + temporal_state_shape[0], + temporal_state_shape[1], + temporal_state_shape[2], + ), + dtype=ssm_dtype, + device="cuda", + ) + self.mamba_cache = ( + conv_state, + temporal_state, + mixed_qkv_cache, + intermediate_ssm_state_cache, + ) + else: + self.mamba_cache = (conv_state, temporal_state) + self.size = size + self.free_slots = list(range(size)) + self.mem_usage = self.get_mamba_size() / GB + logger.info( + f"Mamba Cache is allocated. 
" + f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, " + f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB " + ) + + def get_mamba_params_all_layers(self): + return [self.mamba_cache[i] for i in range(len(self.mamba_cache))] + + def get_mamba_params(self, layer_id: int): + return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))] + + def get_mamba_size(self): + return ( + np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize + + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize + ) + + def available_size(self): + return len(self.free_slots) + + def alloc(self, need_size: int) -> Optional[List[int]]: + if need_size > len(self.free_slots): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + return select_index + + def free(self, free_index: Union[int, List[int]]): + if isinstance(free_index, (int,)): + self.free_slots.append(free_index) + else: + self.free_slots.extend(free_index) + self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0 + + def clear(self): + self.free_slots = list(range(self.size)) + + +class HybridReqToTokenPool(ReqToTokenPool): + """A memory pool that maps a request to its token locations.""" + + def __init__( + self, + size: int, + max_context_len: int, + device: str, + enable_memory_saver: bool, + conv_dtype: torch.dtype, + ssm_dtype: torch.dtype, + mamba_layers: List[int], + conv_state_shape: Tuple[int, int], + temporal_state_shape: Tuple[int, int], + speculative_num_draft_tokens: int, + ): + super().__init__( + size=size, + max_context_len=max_context_len, + device=device, + enable_memory_saver=enable_memory_saver, + ) + + self.mamba_pool = MambaPool( + size, + conv_dtype, + ssm_dtype, + len(mamba_layers), + conv_state_shape, + temporal_state_shape, + device, + speculative_num_draft_tokens, + ) + self.mamba_map = {layer_id: i for i, layer_id in 
enumerate(mamba_layers)} + + self.device = device + self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty( + size, dtype=torch.int32, device=self.device + ) + + self.rid_to_mamba_index_mapping: Dict[str, int] = {} + self.mamba_index_to_rid_mapping: Dict[int, str] = {} + + # For chunk prefill req, we do not need to allocate mamba cache, + # We could use allocated mamba cache instead. + def alloc( + self, need_size: int, reqs: Optional[List["Req"]] = None + ) -> Optional[List[int]]: + select_index = super().alloc(need_size) + if select_index == None: + return None + + mamba_index = [] + for req in reqs: + rid = req.rid + if rid in self.rid_to_mamba_index_mapping: + mid = self.rid_to_mamba_index_mapping[rid] + elif (mid := self.mamba_pool.alloc(1)) is not None: + mid = mid[0] + self.rid_to_mamba_index_mapping[rid] = mid + self.mamba_index_to_rid_mapping[mid] = rid + mamba_index.append(mid) + assert len(select_index) == len( + mamba_index + ), f"Not enough space for mamba cache, try to increase --max-mamba-cache-size." 
+ self.req_index_to_mamba_index_mapping[select_index] = torch.tensor( + mamba_index, dtype=torch.int32, device=self.device + ) + return select_index + + def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor: + return self.req_index_to_mamba_index_mapping[req_indices] + + def get_mamba_params(self, layer_id: int): + assert layer_id in self.mamba_map + return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id]) + + def get_mamba_params_all_layers(self): + return self.mamba_pool.get_mamba_params_all_layers() + + # For chunk prefill, we can not free mamba cache, we need use it in the future + def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True): + super().free(free_index) + if free_mamba_cache: + mamba_index = self.req_index_to_mamba_index_mapping[free_index] + mamba_index_list = mamba_index.tolist() + if isinstance(mamba_index_list, int): + mamba_index_list = [mamba_index_list] + self.mamba_pool.free(mamba_index_list) + for mid in mamba_index_list: + rid = self.mamba_index_to_rid_mapping[mid] + self.mamba_index_to_rid_mapping.pop(mid) + self.rid_to_mamba_index_mapping.pop(rid) + + def clear(self): + super().clear() + self.mamba_pool.clear() + + class KVCache(abc.ABC): @abc.abstractmethod def __init__( @@ -441,6 +639,88 @@ class MHATokenToKVPool(KVCache): ) +class HybridLinearKVPool(KVCache): + """KV cache with separate pools for full and linear attention layers.""" + + def __init__( + self, + size: int, + dtype: torch.dtype, + head_num: int, + head_dim: int, + full_attention_layer_ids: List[int], + enable_kvcache_transpose: bool, + device: str, + ): + self.size = size + self.dtype = dtype + self.device = device + self.full_layer_nums = len(full_attention_layer_ids) + self.page_size = 1 + # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True + assert not enable_kvcache_transpose + self.full_kv_pool = MHATokenToKVPool( + size=size, + page_size=self.page_size, + dtype=dtype, + head_num=head_num, + 
head_dim=head_dim, + layer_num=self.full_layer_nums, + device=device, + enable_memory_saver=False, + ) + self.full_attention_layer_id_mapping = { + id: i for i, id in enumerate(full_attention_layer_ids) + } + k_size, v_size = self.get_kv_size_bytes() + self.mem_usage = (k_size + v_size) / GB + + def get_kv_size_bytes(self): + return self.full_kv_pool.get_kv_size_bytes() + + def get_contiguous_buf_infos(self): + return self.full_kv_pool.get_contiguous_buf_infos() + + def _transfer_full_attention_id(self, layer_id: int): + if layer_id not in self.full_attention_layer_id_mapping: + raise ValueError( + f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}" + ) + return self.full_attention_layer_id_mapping[layer_id] + + def get_key_buffer(self, layer_id: int): + layer_id = self._transfer_full_attention_id(layer_id) + return self.full_kv_pool.get_key_buffer(layer_id) + + def get_value_buffer(self, layer_id: int): + layer_id = self._transfer_full_attention_id(layer_id) + return self.full_kv_pool.get_value_buffer(layer_id) + + def get_kv_buffer(self, layer_id: int): + layer_id = self._transfer_full_attention_id(layer_id) + return self.full_kv_pool.get_kv_buffer(layer_id) + + def set_kv_buffer( + self, + layer: RadixAttention, + loc: torch.Tensor, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + k_scale: float = 1.0, + v_scale: float = 1.0, + ): + layer_id = self._transfer_full_attention_id(layer.layer_id) + self.full_kv_pool.set_kv_buffer( + None, + loc, + cache_k, + cache_v, + k_scale, + v_scale, + layer_id_override=layer_id, + ) + + class SWAKVPool(KVCache): """KV cache with separate pools for full and SWA attention layers.""" diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 56cdee7a2..aa0e2e0e6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -85,6 +85,8 @@ from 
sglang.srt.mem_cache.memory_pool import ( AscendMLAPagedTokenToKVPool, AscendTokenToKVPool, DoubleSparseTokenToKVPool, + HybridLinearKVPool, + HybridReqToTokenPool, MHATokenToKVPool, MLATokenToKVPool, ReqToTokenPool, @@ -303,6 +305,26 @@ class ModelRunner: if architectures and not any("Llama4" in arch for arch in architectures): self.is_hybrid = self.model_config.is_hybrid = True + if self.is_hybrid_gdn: + logger.warning("Hybrid GDN model detected, disable radix cache") + self.server_args.disable_radix_cache = True + self.server_args.attention_backend = "hybrid_linear_attn" + if self.server_args.max_mamba_cache_size is None: + if self.server_args.max_running_requests is not None: + self.server_args.max_mamba_cache_size = ( + self.server_args.max_running_requests + ) + else: + self.server_args.max_mamba_cache_size = 512 + self.server_args.max_mamba_cache_size = ( + self.server_args.max_mamba_cache_size + // ( + self.server_args.dp_size + if self.server_args.enable_dp_attention + else 1 + ) + ) + # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to # determine the number of layers. 
@@ -1080,6 +1102,8 @@ class ModelRunner: "num_nextn_predict_layers", self.num_effective_layers, ) + elif self.is_hybrid_gdn: + num_layers = len(self.model_config.hf_config.full_attention_layer_ids) else: num_layers = self.num_effective_layers if self.use_mla_backend: @@ -1099,9 +1123,22 @@ class ModelRunner: rest_memory = available_gpu_memory - total_gpu_memory * ( 1 - self.mem_fraction_static ) + if self.is_hybrid_gdn: + rest_memory -= ( + self.server_args.max_mamba_cache_size + * self.model_config.hf_config.mamba_cache_per_req + / (1 << 30) + ) max_num_token = int(rest_memory * (1 << 30) // cell_size) return max_num_token + @property + def is_hybrid_gdn(self): + return self.model_config.hf_config.architectures[0] in [ + "Qwen3NextForCausalLM", + "Qwen3NextForCausalLMMTP", + ] + def set_num_token_hybrid(self): if ( "Llama4ForConditionalGeneration" @@ -1222,6 +1259,8 @@ class ModelRunner: ), 4096, ) + if self.is_hybrid_gdn: + max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size) if not self.spec_algorithm.is_none(): if self.is_draft_worker: @@ -1300,6 +1339,28 @@ class ModelRunner: enable_memory_saver=self.server_args.enable_memory_saver, pre_alloc_size=pre_alloc_size, ) + elif self.is_hybrid_gdn: + config = self.model_config.hf_config + ( + conv_state_shape, + temporal_state_shape, + conv_dtype, + ssm_dtype, + mamba_layers, + ) = config.hybrid_gdn_params + self.req_to_token_pool = HybridReqToTokenPool( + size=max_num_reqs, + max_context_len=self.model_config.context_len + + extra_max_context_len, + device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, + conv_state_shape=conv_state_shape, + temporal_state_shape=temporal_state_shape, + conv_dtype=conv_dtype, + ssm_dtype=ssm_dtype, + mamba_layers=mamba_layers, + speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, + ) else: self.req_to_token_pool = ReqToTokenPool( size=max_num_reqs, @@ -1382,6 +1443,23 @@ class ModelRunner: 
enable_kvcache_transpose=False, device=self.device, ) + elif self.is_hybrid_gdn: + self.token_to_kv_pool = HybridLinearKVPool( + size=self.max_total_num_tokens, + dtype=self.kv_cache_dtype, + head_num=self.model_config.get_num_kv_heads( + get_attention_tp_size() + ), + head_dim=self.model_config.head_dim, + # if draft worker, we only need 1 attention layer's kv pool + full_attention_layer_ids=( + [0] + if self.is_draft_worker + else self.model_config.hf_config.full_attention_layer_ids + ), + enable_kvcache_transpose=False, + device=self.device, + ) else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -1615,6 +1693,24 @@ class ModelRunner: ) return DualChunkFlashAttentionBackend(self) + elif backend_str == "hybrid_linear_attn": + assert ( + self.is_hybrid_gdn + ), "hybrid_linear_attn backend can only be used with hybrid GDN models." + from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, + ) + from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( + HybridLinearAttnBackend, + MambaAttnBackend, + ) + + full_attn_backend = FlashAttentionBackend(self) + linear_attn_backend = MambaAttnBackend(self) + full_attn_layers = self.model_config.hf_config.full_attention_layer_ids + return HybridLinearAttnBackend( + full_attn_backend, linear_attn_backend, full_attn_layers + ) else: raise ValueError(f"Invalid attention backend: {backend_str}") diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index a326e3f10..397d9e913 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -35,6 +35,7 @@ from tqdm.auto import tqdm from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.layers.dp_attention import get_attention_tp_rank from sglang.srt.layers.quantization 
import QuantizationConfig, get_quantization_config from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config from sglang.srt.utils import print_warning_once @@ -680,7 +681,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction: """Create a weight loader that shards the weights along the given axis""" def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: - tp_rank = get_tensor_model_parallel_rank() + tp_rank = get_attention_tp_rank() shard_size = param.data.shape[shard_axis] start_idx = tp_rank * shard_size diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py new file mode 100644 index 000000000..fd0d0e942 --- /dev/null +++ b/python/sglang/srt/models/qwen3_next.py @@ -0,0 +1,1072 @@ +import enum +import logging +from typing import Any, Dict, Iterable, Optional, Set, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from sglang.srt.configs.qwen3_next import Qwen3NextConfig +from sglang.srt.distributed import ( + divide, + get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated +from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.quantization.base_config import QuantizationConfig 
+from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + sharded_weight_loader, +) +from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock +from sglang.srt.utils import add_prefix, is_cuda, make_layers, set_weight_attrs + +logger = logging.getLogger(__name__) +_is_cuda = is_cuda() + +import triton +import triton.language as tl + + +@triton.jit +def fused_qkvzba_split_reshape_cat_kernel( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + NUM_HEADS_QK: tl.constexpr, + NUM_HEADS_V: tl.constexpr, + HEAD_QK: tl.constexpr, + HEAD_V: tl.constexpr, +): + i_bs, i_qk = tl.program_id(0), tl.program_id(1) + QKVZ_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V * 2 + BA_DIM_T: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK * 2 + QKV_DIM_T: tl.constexpr = HEAD_QK * 2 + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + q_end: tl.constexpr = HEAD_QK + blk_q_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(0, q_end) + ) + k_end: tl.constexpr = q_end + HEAD_QK + blk_k_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(q_end, k_end) + ) + v_end: tl.constexpr = k_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_v_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(k_end, v_end) + ) + z_end: tl.constexpr = v_end + NUM_HEADS_V // NUM_HEADS_QK * HEAD_V + blk_z_ptr = ( + mixed_qkvz + + i_bs * NUM_HEADS_QK * QKVZ_DIM_T + + i_qk * QKVZ_DIM_T + + tl.arange(v_end, z_end) + 
) + blk_q_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK) + ) + blk_k_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK + + i_qk * HEAD_QK + + tl.arange(0, HEAD_QK) + ) + blk_v_st_ptr = ( + mixed_qkv + + i_bs * NUM_HEADS_QK * QKV_DIM_T + + NUM_HEADS_QK * HEAD_QK * 2 + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) + ) + blk_z_st_ptr = ( + z + + i_bs * NUM_HEADS_V * HEAD_V + + i_qk * HEAD_V * NUM_HEADS_V // NUM_HEADS_QK + + tl.arange(0, HEAD_V * NUM_HEADS_V // NUM_HEADS_QK) + ) + tl.store(blk_q_st_ptr, tl.load(blk_q_ptr)) + tl.store(blk_k_st_ptr, tl.load(blk_k_ptr)) + tl.store(blk_v_st_ptr, tl.load(blk_v_ptr)) + tl.store(blk_z_st_ptr, tl.load(blk_z_ptr)) + b_end: tl.constexpr = NUM_HEADS_V // NUM_HEADS_QK + a_end: tl.constexpr = b_end + NUM_HEADS_V // NUM_HEADS_QK + for i in tl.static_range(b_end): + blk_b_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_b_st_ptr = b + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + i + tl.store(blk_b_st_ptr, tl.load(blk_b_ptr)) + for i in tl.static_range(b_end, a_end): + blk_a_ptr = mixed_ba + i_bs * NUM_HEADS_QK * BA_DIM_T + i_qk * BA_DIM_T + i + blk_a_st_ptr = ( + a + i_bs * NUM_HEADS_V + i_qk * NUM_HEADS_V // NUM_HEADS_QK + (i - b_end) + ) + tl.store(blk_a_st_ptr, tl.load(blk_a_ptr)) + + +def fused_qkvzba_split_reshape_cat( + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, +): + batch, seq_len = mixed_qkvz.shape[0], 1 + qkv_dim_t = num_heads_qk * head_qk * 2 + num_heads_v * head_v + mixed_qkv = torch.empty( + [batch * seq_len, qkv_dim_t], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + z = torch.empty( + [batch * seq_len, num_heads_v, head_v], + dtype=mixed_qkvz.dtype, + device=mixed_qkvz.device, + ) + b = torch.empty( + [batch * seq_len, num_heads_v], + dtype=mixed_ba.dtype, + device=mixed_ba.device, + ) + a 
= torch.empty_like(b) + grid = (batch * seq_len, num_heads_qk) + fused_qkvzba_split_reshape_cat_kernel[grid]( + mixed_qkv, + z, + b, + a, + mixed_qkvz, + mixed_ba, + num_heads_qk, + num_heads_v, + head_qk, + head_v, + num_warps=1, + num_stages=3, + ) + return mixed_qkv, z, b, a + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid]( + g, A_log, a, dt_bias, seq_len, num_heads, beta, threshold, 8, num_warps=1 + ) + return g + + +class Qwen3GatedDeltaNet(nn.Module): + def __init__( + self, + config: Qwen3NextConfig, + layer_id: int, + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.config = config + self.attn_tp_rank = get_attention_tp_rank() + self.attn_tp_size = get_attention_tp_size() + self.hidden_size = config.hidden_size + self.num_v_heads = 
config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.alt_stream = alt_stream + + self.conv_kernel_size = config.linear_conv_kernel_dim + self.layer_id = layer_id + self.activation = config.hidden_act + self.layer_norm_epsilon = config.rms_norm_eps + + # QKV + self.conv_dim = self.key_dim * 2 + self.value_dim + self.conv1d = ColumnParallelLinear( + input_size=self.conv_kernel_size, + output_size=self.conv_dim, + bias=False, + quant_config=None, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + # projection of the input hidden states + projection_size_qkvz = self.key_dim * 2 + self.value_dim * 2 + projection_size_ba = self.num_v_heads * 2 + + self.in_proj_qkvz = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=projection_size_qkvz, + bias=False, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + self.in_proj_ba = ColumnParallelLinear( + input_size=self.hidden_size, + output_size=projection_size_ba, + bias=False, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + + query_key_settings = (self.key_dim, 0, False) + value_settings = (self.value_dim, 0, False) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + query_key_settings, + query_key_settings, + value_settings, + ], + self.attn_tp_size, + self.attn_tp_rank, + ) + }, + ) + + # selective projection used to make dt, B and C input dependent + + # time step projection (discretization) + # instantiate once and copy inv_dt in init_weights of PretrainedModel + self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads // self.attn_tp_size)) + + A = torch.empty( + 
divide(self.num_v_heads, self.attn_tp_size), dtype=torch.float32 + ).uniform_(0, 16) + self.A_log = nn.Parameter(torch.log(A)) + self.A_log._no_weight_decay = True + + set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(0)}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.norm = RMSNormGated( + self.head_v_dim, + eps=self.layer_norm_epsilon, + group_size=None, + norm_before_gate=True, + device=torch.cuda.current_device(), + dtype=config.torch_dtype, + ) + + self.out_proj = RowParallelLinear( + self.value_dim, + self.hidden_size, + bias=False, + input_is_parallel=True, + reduce_results=False, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + + def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + new_tensor_shape_qkvz = mixed_qkvz.size()[:-1] + ( + self.num_k_heads // self.attn_tp_size, + ( + self.head_k_dim + + self.head_k_dim + + (self.head_v_dim + self.head_v_dim) + * self.num_v_heads + // self.num_k_heads + ), + ) + new_tensor_shape_ba = mixed_ba.size()[:-1] + ( + self.num_k_heads // self.attn_tp_size, + 2 * self.num_v_heads // self.num_k_heads, + ) + + mixed_qkvz = mixed_qkvz.view(*new_tensor_shape_qkvz) + mixed_ba = mixed_ba.view(*new_tensor_shape_ba) + + split_arg_list_qkvz = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + (self.num_v_heads // self.num_k_heads * self.head_v_dim), + ] + split_arg_list_ba = [ + self.num_v_heads // self.num_k_heads, + self.num_v_heads // self.num_k_heads, + ] + + # [b, sq, ng, (hn + hn + np/ng * hn + np/ng + np/ng)] + # --> [b, sq, ng, hn], [b, sq, ng, hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng * hn], [b, sq, ng, np/ng], [b, sq, ng, np/ng] + (query, key, value, z) = torch.split(mixed_qkvz, split_arg_list_qkvz, dim=2) + (b, a) = torch.split(mixed_ba, split_arg_list_ba, dim=2) + + # [b, sq, ng, np/ng * hn] -> [b, sq, 
np, hn] + value = value.reshape(value.size(0), -1, self.head_v_dim) + z = z.reshape(z.size(0), -1, self.head_v_dim) + b = b.reshape(b.size(0), self.num_v_heads // self.attn_tp_size) + a = a.reshape(a.size(0), self.num_v_heads // self.attn_tp_size) + + return query, key, value, z, b, a + + def _forward_input_proj(self, hidden_states: torch.Tensor): + DUAL_STREAM_TOKEN_THRESHOLD = 1024 + seq_len, _ = hidden_states.shape + if seq_len < DUAL_STREAM_TOKEN_THRESHOLD: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + with torch.cuda.stream(self.alt_stream): + projected_states_ba, _ = self.in_proj_ba(hidden_states) + current_stream.wait_stream(self.alt_stream) + else: + projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) + projected_states_ba, _ = self.in_proj_ba(hidden_states) + return projected_states_qkvz, projected_states_ba + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ): + seq_len, _ = hidden_states.shape + is_cuda_graph = forward_batch.forward_mode.is_cuda_graph() + + projected_states_qkvz, projected_states_ba = self._forward_input_proj( + hidden_states + ) + + if self.num_v_heads // self.num_k_heads in [1, 2, 4] and is_cuda_graph: + mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( + projected_states_qkvz, + projected_states_ba, + triton.cdiv(self.num_k_heads, self.attn_tp_size), + triton.cdiv(self.num_v_heads, self.attn_tp_size), + self.head_k_dim, + self.head_v_dim, + ) + else: + query, key, value, z, b, a = self.fix_query_key_value_ordering( + projected_states_qkvz, projected_states_ba + ) + query, key, value = map( + lambda x: x.reshape(x.shape[0], -1), (query, key, value) + ) + mixed_qkv = torch.cat((query, key, value), dim=-1) + # mixed_qkv = rearrange(mixed_qkv, "b l d -> b d l") + + # 2. 
Convolution sequence transformation + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + kwargs = { + "mixed_qkv": mixed_qkv, + "conv_weights": conv_weights, + "bias": self.conv1d.bias, + "activation": self.activation, + "key_dim": self.key_dim, + "value_dim": self.value_dim, + "attention_tp_size": self.attn_tp_size, + "head_k_dim": self.head_k_dim, + "head_v_dim": self.head_v_dim, + "a": a, + "b": b, + "A_log": self.A_log, + "dt_bias": self.dt_bias, + "layer_id": self.layer_id, + "seq_len": seq_len, + "z": z, + } + + core_attn_out = forward_batch.attn_backend.forward( + q=None, + k=None, + v=None, + layer=None, + forward_batch=forward_batch, + **kwargs, + ) + + z_shape_og = z.shape + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(z_shape_og) + core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1) + + output, _ = self.out_proj(core_attn_out) + return output + + +class Qwen3HybridLinearDecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream) + + # Qwen3Next all layers are sparse and have no nextn now + self.is_layer_sparse = True + is_previous_layer_sparse = True + self.layer_id = layer_id + + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + if self.is_layer_sparse: + self.mlp = Qwen2MoeSparseMoeBlock( + layer_id=layer_id, + config=config, + quant_config=quant_config, + 
alt_stream=alt_stream, + ) + else: + self.mlp = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + if getattr( + config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False) + ): + logger.warning_once( + "Using Gemma RMSNorm for input normalization and post attn normalization." + ) + self.input_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + **kwargs, + ): + forward_batch = kwargs.get("forward_batch", None) + + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + + if not forward_batch.forward_mode.is_idle(): + hidden_states = self.linear_attn( + hidden_states, + forward_batch, + ) + # Fully Connected + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + + use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter( + forward_batch + ) + hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +class Qwen3HybridAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen3NextConfig, + layer_id: int, + quant_config: 
Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.attn_tp_rank = get_attention_tp_rank() + self.attn_tp_size = get_attention_tp_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % self.attn_tp_size == 0 + self.num_heads = self.total_num_heads // self.attn_tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= self.attn_tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.attn_tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert self.attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = getattr(config, "rope_theta", 10000) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.partial_rotary_factor = config.partial_rotary_factor + self.layer_id = layer_id + + self.attn_output_gate = getattr(config, "attn_output_gate", True) + if self.attn_output_gate: + logger.warning_once("using attn output gate!") + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + rope_scaling=self.rope_scaling, + base=self.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = 
QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads * (1 + self.attn_output_gate), + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=False, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=f"{prefix}.attn", + ) + + # Qwen3Next all layers are sparse and have no nextn now + self.is_layer_sparse = True + is_previous_layer_sparse = True + + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + if self.is_layer_sparse: + self.mlp = Qwen2MoeSparseMoeBlock( + layer_id=layer_id, + config=config, + quant_config=quant_config, + alt_stream=alt_stream, + ) + else: + self.mlp = Qwen2MoeMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + if getattr( + config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False) + ): + logger.warning_once( + "Using Gemma RMSNorm for input normalization and post attn normalization." 
+ ) + self.input_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = GemmaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + else: + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.q_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, + ) + + self.alt_stream = alt_stream + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # overlap qk norm + if self.alt_stream is not None and get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + with torch.cuda.stream(self.alt_stream): + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + current_stream.wait_stream(self.alt_stream) + else: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.q_norm(q_by_head) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.k_norm(k_by_head) + q = q_by_head.view(q.shape) + k = k_by_head.view(k.shape) + return q, k + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + + if self.attn_output_gate: + q_gate, k, v = qkv.split( + [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 + ) + orig_shape = q_gate.shape[:-1] + q_gate = q_gate.view(*orig_shape, self.num_heads, -1) + q, gate = torch.chunk(q_gate, 2, dim=-1) + q = 
q.reshape(*orig_shape, -1) + gate = gate.reshape(*orig_shape, -1) + else: + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + q, k = self._apply_qk_norm(q, k) + + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v, forward_batch) + + if self.attn_output_gate: + gate = torch.sigmoid(gate) + attn_output = attn_output * gate + + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + **kwargs: Any, + ): + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + + if not forward_batch.forward_mode.is_idle(): + hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + ) + + # Fully Connected + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter( + forward_batch + ) + hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "attention": Qwen3HybridAttentionDecoderLayer, + "linear_attention": Qwen3HybridLinearDecoderLayer, +} + + +class Qwen3NextModel(nn.Module): + def __init__( + self, + config: Qwen3NextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + alt_stream = torch.cuda.Stream() if _is_cuda else None + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + enable_tp=not is_dp_attention_enabled(), + ) + + def get_layer(idx: int, prefix: str): + layer_class = 
ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]] + return layer_class( + config, + idx, + quant_config=quant_config, + prefix=prefix, + alt_stream=alt_stream, + ) + + self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + + if getattr( + config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False) + ): + logger.warning_once("Using Gemma RMSNorm for final normalization.") + self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.infer_count = 0 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + # mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + layer_id=i, + positions=positions, + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + if not forward_batch.forward_mode.is_idle(): + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + +class HybridLayerType(enum.Enum): + full_attention = "attention" + swa_attention = "swa_attention" + linear_attention = "linear_attention" + mamba2 = "mamba" + + +class Qwen3NextForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: Qwen3NextConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.pp_group = get_pp_group() + assert 
self.pp_group.is_first_rank and self.pp_group.is_last_rank + self.quant_config = quant_config + self.model = Qwen3NextModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + org_num_embeddings=config.vocab_size, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.lm_head = self.lm_head.float() + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) + + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False + ) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + 
for name, loaded_weight in weights: + + if is_mtp: + + if "mtp" not in name: + continue + + if name in [ + "mtp.fc.weight", + "mtp.pre_fc_norm_embedding.weight", + "mtp.pre_fc_norm_hidden.weight", + ]: + name = name.replace("mtp.", "") + else: + name = name.replace("mtp", "model") + + if not is_mtp and "mtp" in name: + continue + + if "rotary_emb.inv_freq" in name: + continue + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + # TODO(fix mtp loading) + if "mlp.experts" in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader") + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + # Skip loading extra bias for GPTQ models. + if ( + name.endswith(".bias") or name.endswith("_bias") + ) and name not in params_dict: + continue + param = params_dict[name] + + weight_loader = getattr(param, "weight_loader") + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + # if is_pp_missing_parameter(name, self): + # continue + + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.num_experts, + num_groups=None, + ) + + +EntryClass = Qwen3NextForCausalLM diff --git a/python/sglang/srt/models/qwen3_next_mtp.py b/python/sglang/srt/models/qwen3_next_mtp.py new file mode 100644 index 000000000..4630ea300 --- /dev/null +++ b/python/sglang/srt/models/qwen3_next_mtp.py @@ -0,0 +1,117 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================

"""Inference-only Qwen3Next MTP Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
    """Single-layer MTP (multi-token prediction) draft head for Qwen3-Next.

    Projects [normalized embedding ; normalized previous hidden state] through a
    2H -> H linear layer, then runs a one-layer Qwen3NextModel and the LM head.
    Inherits weight loading from Qwen3NextForCausalLM (with is_mtp=True).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # Deliberately bypass Qwen3NextForCausalLM.__init__ — the MTP head
        # builds its own (single-layer) submodules below.
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # if not set, model load will be broken in Qwen3NextForCausalLM load_weights()
        self.pp_group = get_pp_group()
        # self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")

        # Based on the provided checkpoint we:
        # (1) reuse the target model's embeddings, since dedicated MTP
        #     embeddings (use_dedicated_mtp_embeddings) are not in the ckpt;
        # (2) hardcode bias=False, since no bias tensor is provided.
        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
        if getattr(
            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
        ):
            logger.warning_once(
                "Using Gemma RMSNorm for input normalization and post attn normalization."
            )
            RMSNorm_cls = GemmaRMSNorm
        else:
            RMSNorm_cls = RMSNorm
        self.pre_fc_norm_embedding = RMSNorm_cls(
            config.hidden_size, config.rms_norm_eps
        )
        self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
        # Mutates the shared config BEFORE constructing the backbone so the MTP
        # model has exactly one layer. NOTE(review): full_attention_interval=1
        # presumably makes that single layer a full-attention layer — confirm
        # against Qwen3NextModel's layer-type logic.
        config.num_hidden_layers = 1
        config.full_attention_interval = 1
        self.model = Qwen3NextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("model.shared_head.head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Fuse token embeddings with the previous-step hidden states (from
        forward_batch.spec_info) and return logits for draft-token prediction."""
        if input_embeds is None:
            input_embeds = self.model.embed_tokens(input_ids)

        # Normalize both inputs independently, then project the concatenation
        # back down to hidden_size.
        input_embeds = self.pre_fc_norm_embedding(input_embeds)
        hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))

        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            hidden_states,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
    ):
        """Delegate to the target model's loader, forcing is_mtp=True so only
        "mtp"-prefixed checkpoint tensors are loaded (the is_mtp arg is
        accepted only for signature compatibility)."""
        super().load_weights(weights, is_mtp=True)


EntryClass = [Qwen3NextForCausalLMMTP]
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 5dfce426e..fefdd547b 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -95,6 +95,7 @@ ATTENTION_BACKEND_CHOICES = [
     "trtllm_mla",
     "trtllm_mha",
     "dual_chunk_flash_attn",
+    "hybrid_linear_attn",
     # AMD specific
     "aiter",
     "wave",
@@ -390,6 +391,10 @@ class ServerArgs:
     enable_pdmux: bool = False
sm_group_num: int = 3 + # Mamba cache + max_mamba_cache_size: Optional[int] = None + mamba_ssm_dtype: str = "float32" + # Deprecated arguments enable_ep_moe: bool = False enable_deepep_moe: bool = False @@ -835,6 +840,8 @@ class ServerArgs: os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = ( "1" if self.enable_torch_compile else "0" ) + os.environ["SGLANG_MAMBA_SSM_DTYPE"] = self.mamba_ssm_dtype + # Set env var before grammar backends init os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = ( "1" if self.disable_outlines_disk_cache else "0" @@ -1714,7 +1721,20 @@ class ServerArgs: default=ServerArgs.moe_dense_tp_size, help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.", ) - + # Mamba Cache + parser.add_argument( + "--max-mamba-cache-size", + type=int, + default=ServerArgs.max_mamba_cache_size, + help="It is used for mamba cache memory static allocation.", + ) + parser.add_argument( + "--mamba-ssm-dtype", + type=str, + default=ServerArgs.mamba_ssm_dtype, + choices=["float32", "bfloat16"], + help="It is used to tune mamba ssm dtype", + ) # Hierarchical cache parser.add_argument( "--enable-hierarchical-cache", diff --git a/python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py new file mode 100644 index 000000000..bf8d462aa --- /dev/null +++ b/python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py @@ -0,0 +1,195 @@ +import bisect +from typing import TYPE_CHECKING, Callable + +import torch +import torch.nn.functional as F + +from sglang.srt.layers.attention.fla.fused_recurrent import ( + fused_recurrent_gated_delta_rule_update, +) +from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn +from sglang.srt.model_executor.cuda_graph_runner import ( + CUDA_GRAPH_CAPTURE_FAILED_MSG, + CudaGraphRunner, + 
    get_batch_sizes_to_capture,
    get_global_graph_memory_pool,
    model_capture_mode,
    set_global_graph_memory_pool,
)
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_worker import EAGLEWorker


class MambaStateUpdateCudaGraphRunner:
    """CUDA-graph runner that refreshes Mamba conv states after MTP verify.

    After the target model verifies draft tokens, the causal-conv1d pass is
    replayed over only the ACCEPTED tokens of each request so the per-layer
    conv states match the accepted sequence. One graph is captured per batch
    size in `capture_bs`; `replay()` pads the live batch up to the nearest
    captured size.
    """

    def __init__(self, eagle_worker: "EAGLEWorker"):
        self.eagle_worker = eagle_worker
        model_runner = eagle_worker.target_worker.model_runner
        self.model_runner = model_runner
        # attn_backend_list[1] is taken as the linear-attention (Mamba) backend
        # of the hybrid backend — see hybrid_linear_attn_backend.
        self.attn_backend = model_runner.attn_backend.attn_backend_list[1]
        self.req_to_token_pool = self.attn_backend.req_to_token_pool

        self.graphs = {}
        self.output_buffers = {}
        self.graph_input_buffer = None
        self.stream = torch.cuda.Stream()
        self.model = model_runner.model

        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        self.max_bs = self.capture_bs[-1]

        self.init_cuda_graph_state()
        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

    def init_cuda_graph_state(self):
        """Allocate the static input buffers the captured graphs read from.

        NOTE(review): mamba_cache layout is assumed from the indexing below —
        [0] conv state, [2] per-draft-step mixed qkv — confirm against the
        mamba memory pool definition.
        """
        self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache
        # Draft length per request, taken from the mixed-qkv cache's token axis.
        self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2]
        num_mamba_layers = self.mamba_cache[0].shape[0]
        conv_dtype = torch.bfloat16
        conv_shape = self.mamba_cache[0].shape[2]
        total_token_number = self.max_accepted_tokens * self.max_bs
        # Static staging buffer for the accepted tokens' mixed qkv activations.
        self.mixed_qkv_cache = torch.empty(
            size=(
                num_mamba_layers,
                total_token_number,
                conv_shape,
            ),
            dtype=conv_dtype,
            device="cuda",
        )
        # Prefix-sum of accepted lengths (varlen layout for causal_conv1d_fn).
        self.query_start_loc = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        # Per-request slot in the Mamba state pool; unused tail is set to -1.
        self.state_indices = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.has_initial_states = torch.ones(
            self.max_bs, dtype=torch.bool, device="cuda"
        )

    def capture(self):
        # Reuse the generic capture loop (iterates capture_bs and calls
        # capture_one_batch_size) without inheriting from CudaGraphRunner.
        CudaGraphRunner.capture(self)

    def capture_one_batch_size(self, bs: int, forward: Callable):
        """Capture one CUDA graph that re-runs causal_conv1d_fn over every
        hybrid (linear-attention) layer for a fixed batch size `bs`."""
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
        total_token_number = bs * self.max_accepted_tokens
        mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number]

        query_start_loc = self.query_start_loc[: bs + 1]
        state_indices = self.state_indices[:bs]
        has_initial_states = self.has_initial_states[:bs]

        mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
        conv_states = mamba_caches[0]
        mamba_map = self.req_to_token_pool.mamba_map

        def run_once():
            # Replay the conv1d state update for each hybrid layer, writing the
            # refreshed conv state back into the shared pool in place.
            for i in range(len(self.model.model.layers)):
                layer = self.model.model.layers[i]
                if not isinstance(layer, Qwen3HybridLinearDecoderLayer):
                    continue
                conv_weights = layer.linear_attn.conv1d.weight.view(
                    layer.linear_attn.conv1d.weight.size(0),
                    layer.linear_attn.conv1d.weight.size(2),
                )
                # mamba_map translates the model layer index to the pool's
                # mamba-layer index.
                layer_id = mamba_map[i]

                causal_conv1d_fn(
                    mixed_qkvs[layer_id].transpose(0, 1),
                    conv_weights,
                    layer.linear_attn.conv1d.bias,
                    activation=layer.linear_attn.activation,
                    conv_states=conv_states[layer_id],
                    has_initial_state=has_initial_states,
                    cache_indices=state_indices,
                    query_start_loc=query_start_loc,
                )

            return None

        # Warm up twice outside the graph so lazy allocations/compilations do
        # not leak into the capture.
        for _ in range(2):
            torch.cuda.synchronize()
            self.model_runner.tp_group.barrier()

            run_once()

        with torch.cuda.graph(
            graph, pool=get_global_graph_memory_pool(), stream=stream
        ):
            out = run_once()

        set_global_graph_memory_pool(graph.pool())
        return graph, out

    def can_run(self, accepted_length):
        # Graphs exist only up to the largest captured batch size.
        bs = accepted_length.shape[0]
        return bs <= self.max_bs

    def replay_repare(self, accepted_length):
        """Stage graph inputs for replay from this step's accepted tokens.

        NOTE(review): "repare" looks like a typo for "prepare"; the name is only
        used inside this class, so a rename would be local.
        """
        request_number = accepted_length.shape[0]
        # QQ: step = spec num_draft token num
        num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2]
        query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
        query_start_loc = torch.cat(
            [
                torch.zeros(
                    1,
                    dtype=query_start_loc.dtype,
                    device=query_start_loc.device,
                ),
                query_start_loc,
            ]
        )
        # mask[r, t] is True for the first accepted_length[r] draft positions.
        mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
            0
        ) < accepted_length.unsqueeze(1)

        state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[
            :request_number
        ]
        mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()

        _, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
        # Gather accepted tokens' mixed qkv into the static staging buffer.
        mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask]
        self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs)
        self.query_start_loc[: request_number + 1] = query_start_loc
        # Pad the tail so the fixed-size captured graph sees empty sequences.
        self.query_start_loc[request_number + 1 :] = self.query_start_loc[
            request_number
        ]
        self.state_indices[:request_number] = state_indices_tensor
        self.state_indices[request_number:] = -1
        valid_mask = accepted_length > 0
        if intermediate_state_cache is not None:
            # Roll each request's SSM state forward to its last accepted step.
            # NOTE(review): `last_steps` is NOT filtered by `valid_mask`, so
            # this advanced index only lines up when every accepted_length > 0
            # (true when lengths are accept_length + 1) — confirm.
            last_steps = (accepted_length - 1).to(torch.int64)
            valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)

            ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
                :, valid_state_indices, last_steps
            ].to(ssm_states.dtype)

    def replay(self, accepted_length):
        # batch_size and num_seqs can be different in case there are finished examples
        # in the batch, which will not be counted as num_seqs
        raw_bs = accepted_length.shape[0]
        # Round up to the nearest captured batch size.
        index = bisect.bisect_left(self.capture_bs, raw_bs)

        bs = self.capture_bs[index]

        self.replay_repare(accepted_length)
        # Replay
        self.graphs[bs].replay()
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 3ca2f464e..3ec32a0a2 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -214,6 +214,7 @@ class EAGLEWorker(TpModelWorker):
            "triton": self._create_triton_decode_backend,
            "aiter":
self._create_aiter_decode_backend, "fa3": self._create_fa3_decode_backend, + "hybrid_linear_attn": self._create_fa3_decode_backend, "flashmla": self._create_flashmla_decode_backend, "trtllm_mha": self._create_trtllm_mha_decode_backend, "trtllm_mla": self._create_trtllm_mla_decode_backend, @@ -231,6 +232,7 @@ class EAGLEWorker(TpModelWorker): "triton": self._create_triton_prefill_backend, "aiter": self._create_aiter_prefill_backend, "fa3": self._create_fa3_prefill_backend, + "hybrid_linear_attn": self._create_fa3_prefill_backend, "trtllm_mha": self._create_trtllm_mha_prefill_backend, "trtllm_mla": self._create_trtllm_mla_prefill_backend, } @@ -405,6 +407,15 @@ class EAGLEWorker(TpModelWorker): f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) + if self.target_worker.model_runner.is_hybrid_gdn: + from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import ( + MambaStateUpdateCudaGraphRunner, + ) + + self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner( + self + ) + @property def draft_model_runner(self): return self.model_runner @@ -826,6 +837,24 @@ class EAGLEWorker(TpModelWorker): ] logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices] + # QQ: can be optimized + if self.target_worker.model_runner.is_hybrid_gdn: + # res.draft_input.accept_length is on GPU but may be empty for last verify? 
+ accepted_length = ( + torch.tensor( + res.accept_length_per_req_cpu, + device=logits_output.hidden_states.device, + dtype=torch.int32, + ) + + 1 + ) + if self.cuda_graph_runner_for_target_verify.can_run(accepted_length): + self.cuda_graph_runner_for_target_verify.replay(accepted_length) + else: + self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( + accepted_length, self.target_worker.model_runner.model + ) + if batch.return_logprob: self.add_logprob_values(batch, res, logits_output)