Qwen3-Next support (#10233)
Co-authored-by: cao1zhg <114661107+cao1zhg@users.noreply.github.com>
Co-authored-by: ispobock <ispobaoke@gmail.com>
Co-authored-by: Binyao Jiang <byjiang1996@gmail.com>
Co-authored-by: hebiao064 <hebiaobuaa@gmail.com>
Co-authored-by: Lifu Huang <lifu.hlf@gmail.com>
Co-authored-by: qingquansong <ustcsqq@gmail.com>
Co-authored-by: Yaoyao Ding <dingyaoyao.cs@gmail.com>
Co-authored-by: Ke Bao <ISPObaoke@163.com>
Co-authored-by: Minglei Zhu <mingleizhu1122@gmail.com>
@@ -6,6 +6,7 @@ from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
 from sglang.srt.configs.longcat_flash import LongcatFlashConfig
+from sglang.srt.configs.qwen3_next import Qwen3NextConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -24,4 +25,5 @@ __all__ = [
     "Step3VLConfig",
     "Step3TextConfig",
     "Step3VisionEncoderConfig",
+    "Qwen3NextConfig",
 ]
@@ -147,6 +147,9 @@ class ModelConfig:
         ):
             self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
 
+        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
+            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
python/sglang/srt/configs/qwen3_next.py (new file, 326 lines)
@@ -0,0 +1,326 @@
# coding=utf-8
# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Qwen3Hybrid model configuration"""

import enum
import os

import numpy as np
import torch
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_rope_utils import rope_config_validation
from transformers.utils import logging

from sglang.srt.distributed.utils import divide
from sglang.srt.layers.dp_attention import get_attention_tp_size

logger = logging.get_logger(__name__)


# NOTE: HybridLayerType
class HybridLayerType(enum.Enum):
    full_attention = "attention"
    swa_attention = "swa_attention"
    linear_attention = "linear_attention"
    mamba2 = "mamba"


class Qwen3NextConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
    Qwen3-Next model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of
    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 151936):
            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
            `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 5632):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 48):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 2):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details checkout [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
        hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 32768):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
            accordingly.
            Expected contents:
                `rope_type` (`str`):
                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
                    'llama3'], with 'default' being the original RoPE implementation.
                `factor` (`float`, *optional*):
                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
                    original maximum pre-trained length.
                `original_max_position_embeddings` (`int`, *optional*):
                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
                    pretraining.
                `attention_factor` (`float`, *optional*):
                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
                    computation. If unspecified, it defaults to value recommended by the implementation, using the
                    `factor` field to infer the suggested value.
                `beta_fast` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
                    ramp function. If unspecified, it defaults to 32.
                `beta_slow` (`float`, *optional*):
                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
                    ramp function. If unspecified, it defaults to 1.
                `short_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `long_factor` (`List[float]`, *optional*):
                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
                    size divided by the number of attention heads divided by 2
                `low_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
                `high_freq_factor` (`float`, *optional*):
                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
            Percentage of the query and keys which will have rotary embedding.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        head_dim (`int`, *optional*, defaults to 256):
            Projection weights dimension in multi-head attention.
        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
            Kernel size of the convolution used in linear attention layers.
        linear_key_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each key head in linear attention.
        linear_value_head_dim (`int`, *optional*, defaults to 128):
            Dimension of each value head in linear attention.
        linear_num_key_heads (`int`, *optional*, defaults to 16):
            Number of key heads used in linear attention layers.
        linear_num_value_heads (`int`, *optional*, defaults to 32):
            Number of value heads used in linear attention layers.
        decoder_sparse_step (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer.
        moe_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the routed expert.
        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
            Intermediate size of the shared expert.
        num_experts_per_tok (`int`, *optional*, defaults to 10):
            Number of selected experts.
        num_experts (`int`, *optional*, defaults to 512):
            Number of routed experts.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the topk probabilities.
        output_router_logits (`bool`, *optional*, defaults to `False`):
            Whether or not the router logits should be returned by the model. Enabling this will also
            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
            The aux loss factor for the total loss.
        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
        layer_types (`list[str]`, *optional*, defaults to None):
            Types of each layer (attention or linear).

    ```python
    >>> from transformers import Qwen3NextModel, Qwen3NextConfig

    >>> # Initializing a Qwen3Next style configuration
    >>> configuration = Qwen3NextConfig()

    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
    >>> model = Qwen3NextModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "qwen3_next"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=2048,
        intermediate_size=5632,
        num_hidden_layers=48,
        num_attention_heads=16,
        num_key_value_heads=2,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.25,
        attention_bias=False,
        attention_dropout=0.0,
        head_dim=256,
        linear_conv_kernel_dim=4,
        linear_key_head_dim=128,
        linear_value_head_dim=128,
        linear_num_key_heads=16,
        linear_num_value_heads=32,
        decoder_sparse_step=1,
        moe_intermediate_size=512,
        shared_expert_intermediate_size=512,
        num_experts_per_tok=10,
        num_experts=512,
        norm_topk_prob=True,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        mlp_only_layers=[],
        layer_types=None,
        **kwargs,
    ):
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.head_dim = head_dim
        rope_config_validation(self)

        # linear attention (gdn now part)
        self.linear_conv_kernel_dim = linear_conv_kernel_dim
        self.linear_key_head_dim = linear_key_head_dim
        self.linear_value_head_dim = linear_value_head_dim
        self.linear_num_key_heads = linear_num_key_heads
        self.linear_num_value_heads = linear_num_value_heads

        # MoE arguments
        self.decoder_sparse_step = decoder_sparse_step
        self.moe_intermediate_size = moe_intermediate_size
        self.shared_expert_intermediate_size = shared_expert_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.norm_topk_prob = norm_topk_prob
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.mlp_only_layers = mlp_only_layers

    @property
    def layers_block_type(self):
        layer_type_list = []

        for l in range(self.num_hidden_layers):
            if (l + 1) % self.full_attention_interval == 0:
                layer_type_list.append(HybridLayerType.full_attention.value)
            else:
                layer_type_list.append(HybridLayerType.linear_attention.value)

        return layer_type_list

    @property
    def linear_layer_ids(self):
        return [
            i
            for i, type_value in enumerate(self.layers_block_type)
            if type_value == HybridLayerType.linear_attention.value
        ]

    @property
    def full_attention_layer_ids(self):
        return [
            i
            for i, type_value in enumerate(self.layers_block_type)
            if type_value == HybridLayerType.full_attention.value
        ]

    @property
    def hybrid_gdn_params(self):
        world_size = get_attention_tp_size()
        conv_dim = (
            self.linear_key_head_dim * self.linear_num_key_heads * 2
            + self.linear_value_head_dim * self.linear_num_value_heads
        )
        conv_state_shape = (
            divide(conv_dim, world_size),
            self.linear_conv_kernel_dim - 1,
        )

        temporal_state_shape = (
            divide(self.linear_num_value_heads, world_size),
            self.linear_key_head_dim,
            self.linear_value_head_dim,
        )
        conv_dtype = torch.bfloat16
        dtype_map = {
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
        }
        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
        mamba_layers = self.linear_layer_ids
        return (
            conv_state_shape,
            temporal_state_shape,
            conv_dtype,
            ssm_dtype,
            mamba_layers,
        )

    @property
    def mamba_cache_per_req(self):
        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
            self.hybrid_gdn_params
        )
        mamba_layers_len = len(mamba_layers)

        return (
            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
        ) * mamba_layers_len
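For orientation, here is a rough, illustrative calculation of the per-request Mamba cache that `mamba_cache_per_req` implies for the default hyper-parameters above. It assumes attention TP size 1, `SGLANG_MAMBA_SSM_DTYPE=bfloat16`, and a hypothetical 36 of the 48 layers being linear-attention layers (the real split depends on `full_attention_interval`, which is not among the explicit defaults above); the numbers are not taken from this commit.

```python
# Back-of-the-envelope sketch (not part of the commit): per-request Mamba cache
# for the default Qwen3NextConfig values, assuming TP=1 and bfloat16 SSM state.
linear_key_head_dim = 128
linear_value_head_dim = 128
linear_num_key_heads = 16
linear_num_value_heads = 32
linear_conv_kernel_dim = 4
itemsize = 2  # bytes per bfloat16 element

conv_dim = (
    linear_key_head_dim * linear_num_key_heads * 2
    + linear_value_head_dim * linear_num_value_heads
)  # 8192
conv_state_elems = conv_dim * (linear_conv_kernel_dim - 1)  # 24,576
ssm_state_elems = linear_num_value_heads * linear_key_head_dim * linear_value_head_dim  # 524,288

per_layer_bytes = (conv_state_elems + ssm_state_elems) * itemsize  # ~1.05 MiB
num_linear_layers = 36  # hypothetical split of the 48 layers
print(f"~{per_layer_bytes * num_linear_layers / 2**20:.1f} MiB per request")
```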
@@ -42,6 +42,7 @@ from sglang.srt.configs import (
     KimiVLConfig,
     LongcatFlashConfig,
     MultiModalityConfig,
+    Qwen3NextConfig,
     Step3VLConfig,
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
@@ -58,6 +59,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     InternVLChatConfig.model_type: InternVLChatConfig,
     Step3VLConfig.model_type: Step3VLConfig,
     LongcatFlashConfig.model_type: LongcatFlashConfig,
+    Qwen3NextConfig.model_type: Qwen3NextConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py (new file, 581 lines)
@@ -0,0 +1,581 @@
from dataclasses import astuple, dataclass
from functools import lru_cache
from typing import Optional, Union

import torch
import torch.nn.functional as F

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.fla.chunk import chunk_gated_delta_rule
from sglang.srt.layers.attention.fla.fused_recurrent import (
    fused_recurrent_gated_delta_rule_update,
)
from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
    fused_sigmoid_gating_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d import (
    causal_conv1d_fn,
    causal_conv1d_update,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


@dataclass
class ForwardMetadata:
    query_start_loc: Optional[torch.Tensor]
    mamba_cache_indices: torch.Tensor


class MambaAttnBackend(AttentionBackend):
    """Attention backend using Mamba kernel."""

    def __init__(self, model_runner: ModelRunner):
        super().__init__()
        self.pad_slot_id = -1  # Default pad slot id
        self.device = model_runner.device
        self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
        self.forward_metadata: ForwardMetadata = None
        self.state_indices_list = []
        self.query_start_loc_list = []

    @classmethod
    @lru_cache(maxsize=128)
    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
        device = torch.device(device_str)
        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            query_start_loc = self._get_cached_arange(bs, str(self.device))
        elif forward_batch.forward_mode.is_extend():
            if forward_batch.forward_mode.is_target_verify():
                query_start_loc = torch.arange(
                    0,
                    forward_batch.input_ids.shape[0] + 1,
                    step=forward_batch.spec_info.draft_token_num,
                    dtype=torch.int32,
                    device=forward_batch.input_ids.device,
                )
            else:
                query_start_loc = torch.empty(
                    (bs + 1,), dtype=torch.int32, device=self.device
                )
                query_start_loc[:bs] = forward_batch.extend_start_loc
                query_start_loc[bs] = (
                    forward_batch.extend_start_loc[-1]
                    + forward_batch.extend_seq_lens[-1]
                )
        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode=}")
        mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
            forward_batch.req_pool_indices
        )
        self.forward_metadata = ForwardMetadata(
            query_start_loc=query_start_loc,
            mamba_cache_indices=mamba_cache_indices,
        )

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for i in range(max_bs):
            self.state_indices_list.append(
                torch.full((i + 1,), self.pad_slot_id, dtype=torch.int32, device="cuda")
            )
            self.query_start_loc_list.append(
                torch.empty((i + 2,), dtype=torch.int32, device="cuda")
            )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        if forward_mode.is_decode_or_idle():
            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
        elif forward_mode.is_target_verify():
            self.query_start_loc_list[bs - 1].copy_(
                torch.arange(
                    0,
                    bs * spec_info.draft_token_num + 1,
                    step=spec_info.draft_token_num,
                    dtype=torch.int32,
                    device="cuda",
                )
            )
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")
        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
        self.forward_metadata = ForwardMetadata(
            query_start_loc=self.query_start_loc_list[bs - 1],
            mamba_cache_indices=self.state_indices_list[bs - 1],
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        num_padding = torch.count_nonzero(
            seq_lens_cpu == self.get_cuda_graph_seq_len_fill_value()
        )
        # Make sure forward metadata is correctly handled for padding reqs
        req_pool_indices[bs - num_padding :] = 0
        mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
        mamba_indices[bs - num_padding :] = -1
        self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
        if forward_mode.is_decode_or_idle():
            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
            if num_padding > 0:
                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
        elif forward_mode.is_target_verify():
            self.query_start_loc_list[bs - 1].copy_(
                torch.arange(
                    0,
                    bs * spec_info.draft_token_num + 1,
                    step=spec_info.draft_token_num,
                    dtype=torch.int32,
                    device="cuda",
                )
            )
            if num_padding > 0:
                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
                    bs - num_padding
                ) * spec_info.draft_token_num
        else:
            raise ValueError(f"Invalid forward mode: {forward_mode=}")

        self.forward_metadata = ForwardMetadata(
            query_start_loc=self.query_start_loc_list[bs - 1],
            mamba_cache_indices=self.state_indices_list[bs - 1],
        )

    def get_cuda_graph_seq_len_fill_value(self):
        return 1  # Mamba attn does not use seq lens to index kv cache

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        mixed_qkv = kwargs["mixed_qkv"]
        conv_weights = kwargs["conv_weights"]
        bias = kwargs["bias"]
        activation = kwargs["activation"]
        key_dim = kwargs["key_dim"]
        value_dim = kwargs["value_dim"]
        attn_tp_size = kwargs["attention_tp_size"]
        head_k_dim = kwargs["head_k_dim"]
        head_v_dim = kwargs["head_v_dim"]
        a = kwargs["a"]
        b = kwargs["b"]
        A_log = kwargs["A_log"]
        dt_bias = kwargs["dt_bias"]
        layer_id = kwargs["layer_id"]

        conv_states, ssm_states = self.req_to_token_pool.get_mamba_params(layer_id)
        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices

        mixed_qkv = causal_conv1d_update(
            mixed_qkv,
            conv_states,
            conv_weights,
            bias,
            activation,
            conv_state_indices=cache_indices,
        )

        query, key, value = torch.split(
            mixed_qkv,
            [
                key_dim // attn_tp_size,
                key_dim // attn_tp_size,
                value_dim // attn_tp_size,
            ],
            dim=-1,
        )
        # Reshape from [l, h*d] to [1, l, h, d]
        seq_len = query.shape[0]
        num_heads = query.shape[1] // head_k_dim
        query = query.view(1, seq_len, num_heads, head_k_dim)
        key = key.view(1, seq_len, num_heads, head_k_dim)
        value = value.view(1, seq_len, value.shape[1] // head_v_dim, head_v_dim)

        core_attn_out = fused_sigmoid_gating_delta_rule_update(
            A_log=A_log,
            dt_bias=dt_bias,
            q=query,
            k=key,
            v=value,
            a=a,
            b=b,
            initial_state_source=ssm_states,
            initial_state_indices=cache_indices,
            cu_seqlens=query_start_loc,
            use_qk_l2norm_in_kernel=True,
            softplus_beta=1.0,
            softplus_threshold=20.0,
        )

        return core_attn_out

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        mixed_qkv = kwargs["mixed_qkv"]
        conv_weights = kwargs["conv_weights"]
        bias = kwargs["bias"]
        activation = kwargs["activation"]
        key_dim = kwargs["key_dim"]
        value_dim = kwargs["value_dim"]
        attn_tp_size = kwargs["attention_tp_size"]
        head_k_dim = kwargs["head_k_dim"]
        head_v_dim = kwargs["head_v_dim"]
        a = kwargs["a"]
        b = kwargs["b"]
        A_log = kwargs["A_log"]
        dt_bias = kwargs["dt_bias"]
        layer_id = kwargs["layer_id"]
        seq_len = kwargs["seq_len"]

        is_target_verify = forward_batch.forward_mode.is_target_verify()

        query_start_loc = self.forward_metadata.query_start_loc
        cache_indices = self.forward_metadata.mamba_cache_indices

        if is_target_verify:
            (
                conv_states,
                ssm_states,
                mixed_qkv_cache,
                intermediate_state_cache,
            ) = self.req_to_token_pool.get_mamba_params(layer_id)
            mixed_qkv_cache[cache_indices] = mixed_qkv.view(
                (-1,) + mixed_qkv_cache.shape[1:]
            ).clone()
            has_initial_states = torch.ones(
                seq_len // forward_batch.spec_info.draft_token_num,
                dtype=torch.bool,
                device=forward_batch.input_ids.device,
            )
            conv_states_to_use = conv_states.clone()
        else:
            conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
                layer_id
            )
            has_initial_states = forward_batch.extend_prefix_lens > 0
            conv_states_to_use = conv_states
        mixed_qkv = causal_conv1d_fn(
            mixed_qkv.transpose(0, 1),
            conv_weights,
            bias,
            activation=activation,
            conv_states=conv_states_to_use,
            has_initial_state=has_initial_states,
            cache_indices=cache_indices,
            query_start_loc=query_start_loc,
        ).transpose(0, 1)[:seq_len]

        key_split_dim = key_dim // attn_tp_size
        value_split_dim = value_dim // attn_tp_size

        query, key, value = torch.split(
            mixed_qkv,
            [key_split_dim, key_split_dim, value_split_dim],
            dim=-1,
        )

        actual_seq_len = query.shape[0]
        num_heads = query.shape[1] // head_k_dim
        num_value_heads = value.shape[1] // head_v_dim

        query = query.view(1, actual_seq_len, num_heads, head_k_dim)
        key = key.view(1, actual_seq_len, num_heads, head_k_dim)
        value = value.view(1, actual_seq_len, num_value_heads, head_v_dim)

        beta = b.sigmoid()
        g = fused_gdn_gating(A_log, a, dt_bias)

        g = g.unsqueeze(0)
        beta = beta.unsqueeze(0)

        if is_target_verify:
            core_attn_out = fused_recurrent_gated_delta_rule_update(
                q=query,
                k=key,
                v=value,
                g=g,
                beta=beta,
                initial_state_source=ssm_states,
                initial_state_indices=cache_indices,
                cu_seqlens=query_start_loc,
                use_qk_l2norm_in_kernel=True,
                disable_state_update=True,
                intermediate_states_buffer=intermediate_state_cache,
                cache_steps=forward_batch.spec_info.draft_token_num,
            )
        else:
            recurrent_state = ssm_states[cache_indices]
            core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
                q=query,
                k=key,
                v=value,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=True,
                cu_seqlens=query_start_loc,
                head_first=False,
                use_qk_l2norm_in_kernel=True,
            )
            last_recurrent_state = last_recurrent_state.to(ssm_states.dtype, copy=False)
            ssm_states[cache_indices] = last_recurrent_state

        return core_attn_out


class HybridLinearAttnBackend(AttentionBackend):
    """Support different backends for prefill and decode."""

    def __init__(
        self,
        full_attn_backend: AttentionBackend,
        linear_attn_backend: AttentionBackend,
        full_attn_layers: list[int],
    ):
        self.full_attn_layers = full_attn_layers
        self.attn_backend_list = [full_attn_backend, linear_attn_backend]

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        for attn_backend in self.attn_backend_list:
            attn_backend.init_forward_metadata(forward_batch)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
        for attn_backend in self.attn_backend_list:
            attn_backend.init_cuda_graph_state(max_bs, max_num_tokens)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        for attn_backend in self.attn_backend_list:
            attn_backend.init_forward_metadata_capture_cuda_graph(
                bs,
                num_tokens,
                req_pool_indices,
                seq_lens,
                encoder_lens,
                forward_mode,
                spec_info,
            )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        for attn_backend in self.attn_backend_list:
            attn_backend.init_forward_metadata_replay_cuda_graph(
                bs,
                req_pool_indices,
                seq_lens,
                seq_lens_sum,
                encoder_lens,
                forward_mode,
                spec_info,
                seq_lens_cpu,
            )

    def get_cuda_graph_seq_len_fill_value(self):
        return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()

    def forward_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        layer_id = layer.layer_id if layer else kwargs["layer_id"]
        if layer_id in self.full_attn_layers:
            return self.attn_backend_list[0].forward_decode(
                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
            )
        return self.attn_backend_list[1].forward_decode(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward_extend(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        layer_id = layer.layer_id if layer else kwargs["layer_id"]
        if layer_id in self.full_attn_layers:
            return self.attn_backend_list[0].forward_extend(
                q, k, v, layer, forward_batch, save_kv_cache, **kwargs
            )
        return self.attn_backend_list[1].forward_extend(
            q, k, v, layer, forward_batch, save_kv_cache, **kwargs
        )

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ):
        """Run forward on an attention layer."""
        if forward_batch.forward_mode.is_idle():
            if layer is None:
                return torch.empty_like(kwargs["z"])
            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
        elif forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )
        else:
            return self.forward_extend(
                q,
                k,
                v,
                layer,
                forward_batch,
                save_kv_cache=save_kv_cache,
                **kwargs,
            )

    def update_mamba_state_after_mtp_verify(self, accepted_length, model):
        request_number = accepted_length.shape[0]
        # QQ: step = spec num_draft token num
        num_draft_tokens = (
            self.attn_backend_list[1]
            .req_to_token_pool.mamba_pool.mamba_cache[2]
            .shape[2]
        )
        query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
        query_start_loc = torch.cat(
            [
                torch.zeros(
                    1,
                    dtype=query_start_loc.dtype,
                    device=query_start_loc.device,
                ),
                query_start_loc,
            ]
        )
        mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
            0
        ) < accepted_length.unsqueeze(1)

        state_indices_tensor = self.attn_backend_list[
            1
        ].forward_metadata.mamba_cache_indices[:request_number]

        mamba_caches = self.attn_backend_list[
            1
        ].req_to_token_pool.get_mamba_params_all_layers()

        conv_states, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches

        mixed_qkvs = mix_qkv_cache[:, state_indices_tensor][:, mask]

        mamba_map = self.attn_backend_list[1].req_to_token_pool.mamba_map

        has_initial_states = torch.ones(
            request_number, dtype=torch.bool, device=accepted_length.device
        )

        # Batch SSM state updates (outside the loop for efficiency)
        valid_mask = accepted_length > 0
        if intermediate_state_cache is not None:
            last_steps = (accepted_length - 1).to(torch.int64)
            valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)

            ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
                :, valid_state_indices, last_steps
            ].to(ssm_states.dtype)

        # For loop conv state updates (can be optimized)
        for i in range(len(model.model.layers)):
            layer = model.model.layers[i]
            if isinstance(layer, Qwen3HybridLinearDecoderLayer):
                conv_weights = layer.linear_attn.conv1d.weight.view(
                    layer.linear_attn.conv1d.weight.size(0),
                    layer.linear_attn.conv1d.weight.size(2),
                )

                layer_id = mamba_map[i]
                conv_state = conv_states[layer_id]
                mixed_qkv = mixed_qkvs[layer_id]

                _ = causal_conv1d_fn(
                    mixed_qkv.transpose(0, 1),
                    conv_weights,
                    layer.linear_attn.conv1d.bias,
                    activation=layer.linear_attn.activation,
                    conv_states=conv_state,
                    has_initial_state=has_initial_states,
                    cache_indices=state_indices_tensor,
                    query_start_loc=query_start_loc,
                )
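The dispatch rule in `HybridLinearAttnBackend` is simply membership of `layer_id` in `full_attn_layers`, which the model runner is expected to populate from `Qwen3NextConfig.full_attention_layer_ids`. A standalone toy that mirrors this routing (plain Python, no sglang imports; the interval value is made up for illustration):

```python
# Illustrative only: mirrors Qwen3NextConfig.layers_block_type and the
# HybridLinearAttnBackend layer routing; full_attention_interval is assumed.
num_hidden_layers = 48
full_attention_interval = 4  # hypothetical value

layer_types = [
    "attention" if (l + 1) % full_attention_interval == 0 else "linear_attention"
    for l in range(num_hidden_layers)
]
full_attn_layers = [i for i, t in enumerate(layer_types) if t == "attention"]

def pick_backend(layer_id: int) -> str:
    # attn_backend_list[0] handles full attention, [1] handles the Mamba/GDN path
    return "full_attn_backend" if layer_id in full_attn_layers else "linear_attn_backend"

print(full_attn_layers[:4])          # [3, 7, 11, 15]
print(pick_backend(3), pick_backend(4))
```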
python/sglang/srt/layers/attention/mamba/causal_conv1d.py (new file, 128 lines)
@@ -0,0 +1,128 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
# SPDX-License-Identifier: Apache-2.0

# Copyright (c) 2024, Tri Dao.
# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py

from typing import Optional

import torch
from sgl_kernel import causal_conv1d_fwd
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel

PAD_SLOT_ID = -1


def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim, seqlen) or (dim, cu_seq_len) for varlen
        sequences are concatenated from left to right for varlen
    weight: (dim, width)
    bias: (dim,)
    query_start_loc: (batch + 1) int32
        The cumulative sequence lengths of the sequences in
        the batch, used to index into sequence. prepended by 0.
        for example: query_start_loc = torch.Tensor([0,10,16,17]),
        x.shape=(dim,17)
    cache_indices: (batch) int32
        indicates the corresponding state index,
        like so: conv_state = conv_states[cache_indices[batch_id]]
    has_initial_state: (batch) bool
        indicates whether the kernel should take the current state as the
        initial state for the calculations
    conv_states: (..., dim, width - 1) itype
        updated inplace if provided
    activation: either None or "silu" or "swish"
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3

    out: (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError("activation must be None, silu, or swish")
    if x.stride(-1) != 1:
        x = x.contiguous()
    bias = bias.contiguous() if bias is not None else None

    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        activation in ["silu", "swish"],
        pad_slot_id,
    )
    return x


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = PAD_SLOT_ID,
):
    """
    x: (batch, dim) or (batch, dim, seqlen)
    conv_state: (batch, dim, state_len), where state_len >= width - 1
    weight: (dim, width)
    bias: (dim,)
    cache_seqlens: (batch,), dtype int32.
        If not None, the conv_state is treated as a circular buffer.
        The conv_state will be updated by copying x to the conv_state
        starting at the index
        @cache_seqlens % state_len.
    conv_state_indices: (batch,), dtype int32
        If not None, the conv_state is a larger tensor along the batch dim,
        and we are selecting the batch coords specified by conv_state_indices.
        Useful for a continuous batching scenario.
    pad_slot_id: int
        if cache_indices is passed, lets the kernel identify padded
        entries that will not be processed,
        for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
        in this case, the kernel will not process entries at
        indices 0 and 3
    out: (batch, dim) or (batch, dim, seqlen)
    """
    if activation not in [None, "silu", "swish"]:
        raise NotImplementedError(
            f"activation must be None, silu, or swish, actual: {activation}"
        )
    activation_val = activation in ["silu", "swish"]
    unsqueeze = x.dim() == 2
    if unsqueeze:
        x = x.unsqueeze(-1)
    causal_conv1d_update_kernel(
        x,
        conv_state,
        weight,
        bias,
        activation_val,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    if unsqueeze:
        x = x.squeeze(-1)
    return x
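A minimal shape sketch of the decode-path update, assuming a CUDA device and the `sgl_kernel` extension are available; the sizes (batch of 2 requests drawn from a pool of 8 cached conv states) are illustrative, not taken from the commit:

```python
import torch
# Illustrative shapes only; actually running this requires a CUDA build of sgl_kernel.
batch, dim, width, num_slots = 2, 8192, 4, 8
x = torch.randn(batch, dim, device="cuda", dtype=torch.bfloat16)
conv_state = torch.zeros(num_slots, dim, width - 1, device="cuda", dtype=torch.bfloat16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.bfloat16)
slot_ids = torch.tensor([5, 2], dtype=torch.int32, device="cuda")

out = causal_conv1d_update(
    x, conv_state, weight, bias=None, activation="silu",
    conv_state_indices=slot_ids,  # each request updates its own cached state in place
)
print(out.shape)  # (2, 8192)
```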
python/sglang/srt/layers/attention/mamba/mamba.py (new file, 64 lines)
@@ -0,0 +1,64 @@
from typing import Callable, List, Tuple

import torch

LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]


def mamba_v2_sharded_weight_loader(
    shard_spec: List[Tuple[int, int, float]],
    tp_size: int,
    tp_rank: int,
) -> LoaderFunction:
    """Create a weight loader for mamba v2. This ensures that the projections
    are correctly sharded so that they can be split into x, B, C. It also
    ensures that all the groups corresponding to a head shard are placed
    together with it.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:

        # - track boundary of (sharded) param, and loaded_weight, respectively
        boundary, loaded_boundary = 0, 0

        # - iterate over the shard specs
        for full_dim, extra, duplicate_groups in shard_spec:
            # - full dim is the model dim (before TP).
            # - extra > 0, means there is expected overall increase
            #   of dimensions. This is so because of replication.
            # - ratio is used to map the tp_rank to the actual shard
            #   rank. This is useful when there is replication of
            #   groups to accompany head shards.

            # - size of the loaded shard
            shard_size = full_dim // tp_size

            # - compute the rank into the loaded shard.
            # - if there is replication, different TP shards will
            #   take from the same rank.
            # NOTE: currently we only support duplication
            # in the case where num_groups == 1
            rank = 0 if duplicate_groups else tp_rank

            # - leftmost boundary index into loaded weight.
            loaded_skip = rank * shard_size
            loaded_start_idx = loaded_boundary + loaded_skip

            # - take these many dims from the loaded weight.
            take = min(shard_size, full_dim - extra - loaded_skip)

            # - always shard on dim 0
            # - the ignore is for a mundane mypy error as it does not
            #   seem to handle slices well.
            # https://github.com/python/mypy/issues/2410
            param.data[
                boundary : (boundary + take), ...  # type: ignore[misc]
            ] = loaded_weight[
                loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
            ]  # type: ignore[misc]

            # move indexing boundaries
            boundary += shard_size
            loaded_boundary += full_dim - extra

    return loader
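A small, self-contained illustration of the loader semantics (CPU tensors, made-up dimensions): a fused projection with an 8-row segment followed by a 4-row segment, sharded across `tp_size=2`, where rank 1 takes the second half of each segment.

```python
import torch

# Hypothetical shard spec: two segments, no replicated groups, no extra rows.
shard_spec = [(8, 0, False), (4, 0, False)]
tp_size, tp_rank = 2, 1

loader = mamba_v2_sharded_weight_loader(shard_spec, tp_size, tp_rank)

full_weight = torch.arange(12.0).unsqueeze(-1)  # rows 0..7 = segment A, rows 8..11 = segment B
param = torch.empty(6, 1)                        # this rank holds 4 + 2 rows
loader(param, full_weight)

# Rank 1 gets the second half of segment A and the second half of segment B.
print(param.squeeze(-1).tolist())  # [4.0, 5.0, 6.0, 7.0, 10.0, 11.0]
```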
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 2},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 3},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3}
}
@@ -0,0 +1,146 @@
{
  "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
  "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "48": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "96": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2},
  "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3},
  "1536": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 2},
  "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 2},
  "3072": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
  "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3}
}
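These two files look like Triton autotune tables keyed by a problem-size bucket (most likely a token or batch dimension of the GDN kernels); the exact lookup lives in the kernel code, which is not part of this excerpt. A generic way such a table is usually consumed is to pick the smallest configured key that is at least the runtime size, falling back to the largest entry; the helper below is a hypothetical illustration, not sglang's actual logic.

```python
# Hypothetical lookup helper over a tuning table like the JSON files above.
table = {
    "16": {"BLOCK_SIZE_M": 16},
    "512": {"BLOCK_SIZE_M": 32},
    "4096": {"BLOCK_SIZE_M": 128},
}

def pick_config(table: dict, n: int) -> dict:
    keys = sorted(int(k) for k in table)
    chosen = next((k for k in keys if n <= k), keys[-1])
    return table[str(chosen)]

print(pick_config(table, 300))  # falls into the "512" bucket
```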
@@ -38,7 +38,7 @@ import threading
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
-from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -59,7 +59,7 @@ from sglang.srt.mem_cache.allocator import (
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
@@ -962,8 +962,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def is_empty(self):
         return len(self.reqs) == 0
 
-    def alloc_req_slots(self, num_reqs: int):
-        req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
+    def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
+        if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
         if req_pool_indices is None:
             raise RuntimeError(
                 "alloc_req_slots runs out of memory. "
@@ -1138,7 +1141,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         # Allocate req slots
         bs = len(self.reqs)
-        req_pool_indices = self.alloc_req_slots(bs)
+        req_pool_indices = self.alloc_req_slots(bs, self.reqs)
 
         # Init tensors
         reqs = self.reqs
@@ -1540,7 +1540,12 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -102,6 +102,204 @@ class ReqToTokenPool:
        self.free_slots = list(range(self.size))


class MambaPool:
    def __init__(
        self,
        size: int,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        num_mamba_layers: int,
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        conv_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + conv_state_shape,
            dtype=conv_dtype,
            device=device,
        )
        temporal_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + temporal_state_shape,
            dtype=ssm_dtype,
            device=device,
        )
        if speculative_num_draft_tokens is not None:
            mixed_qkv_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    conv_state_shape[0],
                ),
                dtype=conv_dtype,
                device="cuda",
            )
            # Cache intermediate SSM states per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
            intermediate_ssm_state_cache = torch.empty(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    temporal_state_shape[0],
                    temporal_state_shape[1],
                    temporal_state_shape[2],
                ),
                dtype=ssm_dtype,
                device="cuda",
            )
            self.mamba_cache = (
                conv_state,
                temporal_state,
                mixed_qkv_cache,
                intermediate_ssm_state_cache,
            )
        else:
            self.mamba_cache = (conv_state, temporal_state)
        self.size = size
        self.free_slots = list(range(size))
        self.mem_usage = self.get_mamba_size() / GB
        logger.info(
            f"Mamba Cache is allocated. "
            f"conv_state size: {conv_state.numel() * conv_state.itemsize / GB:.2f}GB, "
            f"ssm_state size: {temporal_state.numel() * temporal_state.itemsize / GB:.2f}GB "
        )

    def get_mamba_params_all_layers(self):
        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]

    def get_mamba_params(self, layer_id: int):
        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]

    def get_mamba_size(self):
        return (
            np.prod(self.mamba_cache[0].shape) * self.mamba_cache[0].dtype.itemsize
            + np.prod(self.mamba_cache[1].shape) * self.mamba_cache[1].dtype.itemsize
        )

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)
        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0

    def clear(self):
        self.free_slots = list(range(self.size))
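The `conv_state` and `temporal_state` buffers above are allocated once for all mamba layers and all `size + 1` slots, so their footprint can be estimated up front. A small, self-contained sketch of that arithmetic; the shapes and dtypes below are placeholders, not Qwen3-Next's real dimensions:

import numpy as np

def mamba_cache_bytes(num_layers, size, conv_shape, ssm_shape,
                      conv_itemsize=2, ssm_itemsize=4):
    # Mirrors MambaPool.get_mamba_size(): one extra slot (size + 1) is reserved.
    conv = num_layers * (size + 1) * int(np.prod(conv_shape)) * conv_itemsize
    ssm = num_layers * (size + 1) * int(np.prod(ssm_shape)) * ssm_itemsize
    return conv + ssm

# Placeholder shapes: 24 linear-attention layers, 512 slots,
# conv state (dim, kernel-1) and SSM state (heads, head_dim, state_dim).
total = mamba_cache_bytes(24, 512, (4096, 3), (16, 64, 64))
print(f"{total / (1 << 30):.2f} GiB")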
class HybridReqToTokenPool(ReqToTokenPool):
    """A memory pool that maps a request to its token locations."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        mamba_layers: List[int],
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        speculative_num_draft_tokens: int,
    ):
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )

        self.mamba_pool = MambaPool(
            size,
            conv_dtype,
            ssm_dtype,
            len(mamba_layers),
            conv_state_shape,
            temporal_state_shape,
            device,
            speculative_num_draft_tokens,
        )
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}

        self.device = device
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.empty(
            size, dtype=torch.int32, device=self.device
        )

        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
        self.mamba_index_to_rid_mapping: Dict[int, str] = {}

    # For a chunked prefill request, we do not need to allocate a new mamba cache
    # slot; we can reuse the one that was already allocated for this rid.
    def alloc(
        self, need_size: int, reqs: Optional[List["Req"]] = None
    ) -> Optional[List[int]]:
        select_index = super().alloc(need_size)
        if select_index is None:
            return None

        mamba_index = []
        for req in reqs:
            rid = req.rid
            if rid in self.rid_to_mamba_index_mapping:
                mid = self.rid_to_mamba_index_mapping[rid]
            elif (mid := self.mamba_pool.alloc(1)) is not None:
                mid = mid[0]
                self.rid_to_mamba_index_mapping[rid] = mid
                self.mamba_index_to_rid_mapping[mid] = rid
            mamba_index.append(mid)
        assert len(select_index) == len(
            mamba_index
        ), "Not enough space for mamba cache; try increasing --max-mamba-cache-size."
        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index

    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        return self.req_index_to_mamba_index_mapping[req_indices]

    def get_mamba_params(self, layer_id: int):
        assert layer_id in self.mamba_map
        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])

    def get_mamba_params_all_layers(self):
        return self.mamba_pool.get_mamba_params_all_layers()

    # For chunked prefill, we cannot free the mamba cache here; it is still
    # needed by later chunks of the same request.
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            mamba_index_list = mamba_index.tolist()
            if isinstance(mamba_index_list, int):
                mamba_index_list = [mamba_index_list]
            self.mamba_pool.free(mamba_index_list)
            for mid in mamba_index_list:
                rid = self.mamba_index_to_rid_mapping[mid]
                self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)

    def clear(self):
        super().clear()
        self.mamba_pool.clear()
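In `alloc()` above, the mamba slot is keyed by `rid`, so a chunked-prefill request that was freed with `free_mamba_cache=False` gets the same slot back on its next chunk. The same mapping logic as a toy, standalone sketch (plain dictionary and free list, not the actual classes):

free_mamba_slots = [0, 1, 2, 3]
rid_to_slot = {}

def mamba_slot_for(rid: str) -> int:
    # Reuse the slot if this rid was seen before (e.g. an earlier prefill chunk);
    # otherwise take a fresh slot from the free list.
    if rid in rid_to_slot:
        return rid_to_slot[rid]
    slot = free_mamba_slots.pop(0)
    rid_to_slot[rid] = slot
    return slot

assert mamba_slot_for("req-A") == 0   # first chunk allocates
assert mamba_slot_for("req-A") == 0   # later chunks reuse the same slot
assert mamba_slot_for("req-B") == 1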
class KVCache(abc.ABC):
    @abc.abstractmethod
    def __init__(
@@ -441,6 +639,88 @@ class MHATokenToKVPool(KVCache):
        )


class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = 1
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
        self.full_kv_pool = MHATokenToKVPool(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        self.full_attention_layer_id_mapping = {
            id: i for i, id in enumerate(full_attention_layer_ids)
        }
        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        return self.full_kv_pool.get_contiguous_buf_infos()

    def _transfer_full_attention_id(self, layer_id: int):
        if layer_id not in self.full_attention_layer_id_mapping:
            raise ValueError(
                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
            )
        return self.full_attention_layer_id_mapping[layer_id]

    def get_key_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_key_buffer(layer_id)

    def get_value_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_kv_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        layer_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=layer_id,
        )
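`HybridLinearKVPool` only backs the full-attention layers, so `_transfer_full_attention_id` remaps a global layer id to a dense index into the smaller KV pool. The same idea in isolation, with hypothetical layer ids (assuming a full-attention layer every fourth layer):

# Hypothetical: layers 3, 7, 11, 15 are full attention; the rest are linear attention.
full_attention_layer_ids = [3, 7, 11, 15]
mapping = {layer_id: i for i, layer_id in enumerate(full_attention_layer_ids)}

def to_pool_index(layer_id: int) -> int:
    if layer_id not in mapping:
        raise ValueError(f"{layer_id=} is not a full attention layer")
    return mapping[layer_id]

assert to_pool_index(7) == 1   # global layer 7 -> second slot of the 4-layer KV pool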
class SWAKVPool(KVCache):
    """KV cache with separate pools for full and SWA attention layers."""


@@ -85,6 +85,8 @@ from sglang.srt.mem_cache.memory_pool import (
    AscendMLAPagedTokenToKVPool,
    AscendTokenToKVPool,
    DoubleSparseTokenToKVPool,
    HybridLinearKVPool,
    HybridReqToTokenPool,
    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
@@ -303,6 +305,26 @@ class ModelRunner:
        if architectures and not any("Llama4" in arch for arch in architectures):
            self.is_hybrid = self.model_config.is_hybrid = True

        if self.is_hybrid_gdn:
            logger.warning("Hybrid GDN model detected, disabling the radix cache")
            self.server_args.disable_radix_cache = True
            self.server_args.attention_backend = "hybrid_linear_attn"
            if self.server_args.max_mamba_cache_size is None:
                if self.server_args.max_running_requests is not None:
                    self.server_args.max_mamba_cache_size = (
                        self.server_args.max_running_requests
                    )
                else:
                    self.server_args.max_mamba_cache_size = 512
            self.server_args.max_mamba_cache_size = (
                self.server_args.max_mamba_cache_size
                // (
                    self.server_args.dp_size
                    if self.server_args.enable_dp_attention
                    else 1
                )
            )
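The block above first defaults `max_mamba_cache_size` (to `max_running_requests` when that is set, otherwise 512) and then splits it across data-parallel attention ranks. The same arithmetic as a standalone sketch:

def effective_mamba_cache_size(max_mamba_cache_size, max_running_requests,
                               dp_size, enable_dp_attention):
    # Default, then divide by the DP size when DP attention is enabled.
    if max_mamba_cache_size is None:
        max_mamba_cache_size = (
            max_running_requests if max_running_requests is not None else 512
        )
    return max_mamba_cache_size // (dp_size if enable_dp_attention else 1)

assert effective_mamba_cache_size(None, None, dp_size=4, enable_dp_attention=True) == 128
assert effective_mamba_cache_size(None, 256, dp_size=1, enable_dp_attention=False) == 256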
        # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
        # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
        # determine the number of layers.
@@ -1080,6 +1102,8 @@ class ModelRunner:
                "num_nextn_predict_layers",
                self.num_effective_layers,
            )
        elif self.is_hybrid_gdn:
            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
        else:
            num_layers = self.num_effective_layers
        if self.use_mla_backend:
@@ -1099,9 +1123,22 @@ class ModelRunner:
        rest_memory = available_gpu_memory - total_gpu_memory * (
            1 - self.mem_fraction_static
        )
        if self.is_hybrid_gdn:
            rest_memory -= (
                self.server_args.max_mamba_cache_size
                * self.model_config.hf_config.mamba_cache_per_req
                / (1 << 30)
            )
        max_num_token = int(rest_memory * (1 << 30) // cell_size)
        return max_num_token

    @property
    def is_hybrid_gdn(self):
        return self.model_config.hf_config.architectures[0] in [
            "Qwen3NextForCausalLM",
            "Qwen3NextForCausalLMMTP",
        ]
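For hybrid GDN models, the token budget above first carves the statically sized mamba cache out of the remaining memory and only then converts the rest into KV-cache tokens. A worked example with placeholder numbers (`cell_size` is the per-token KV byte cost; all values below are illustrative):

GB = 1 << 30

def max_num_kv_tokens(rest_memory_gb, max_mamba_cache_size,
                      mamba_cache_per_req_bytes, cell_size):
    # Reserve the mamba cache first, then spend the remainder on KV-cache tokens.
    rest_memory_gb -= max_mamba_cache_size * mamba_cache_per_req_bytes / GB
    return int(rest_memory_gb * GB // cell_size)

# e.g. 40 GB free, 512 mamba slots of 8 MiB each, 160 KiB of KV per token (placeholders)
print(max_num_kv_tokens(40.0, 512, 8 * (1 << 20), 160 * 1024))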
    def set_num_token_hybrid(self):
        if (
            "Llama4ForConditionalGeneration"
@@ -1222,6 +1259,8 @@ class ModelRunner:
            ),
            4096,
        )
        if self.is_hybrid_gdn:
            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
@@ -1300,6 +1339,28 @@ class ModelRunner:
                enable_memory_saver=self.server_args.enable_memory_saver,
                pre_alloc_size=pre_alloc_size,
            )
        elif self.is_hybrid_gdn:
            config = self.model_config.hf_config
            (
                conv_state_shape,
                temporal_state_shape,
                conv_dtype,
                ssm_dtype,
                mamba_layers,
            ) = config.hybrid_gdn_params
            self.req_to_token_pool = HybridReqToTokenPool(
                size=max_num_reqs,
                max_context_len=self.model_config.context_len
                + extra_max_context_len,
                device=self.device,
                enable_memory_saver=self.server_args.enable_memory_saver,
                conv_state_shape=conv_state_shape,
                temporal_state_shape=temporal_state_shape,
                conv_dtype=conv_dtype,
                ssm_dtype=ssm_dtype,
                mamba_layers=mamba_layers,
                speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
            )
        else:
            self.req_to_token_pool = ReqToTokenPool(
                size=max_num_reqs,
@@ -1382,6 +1443,23 @@ class ModelRunner:
                enable_kvcache_transpose=False,
                device=self.device,
            )
        elif self.is_hybrid_gdn:
            self.token_to_kv_pool = HybridLinearKVPool(
                size=self.max_total_num_tokens,
                dtype=self.kv_cache_dtype,
                head_num=self.model_config.get_num_kv_heads(
                    get_attention_tp_size()
                ),
                head_dim=self.model_config.head_dim,
                # For the draft worker, only one attention layer's KV pool is needed.
                full_attention_layer_ids=(
                    [0]
                    if self.is_draft_worker
                    else self.model_config.hf_config.full_attention_layer_ids
                ),
                enable_kvcache_transpose=False,
                device=self.device,
            )
        else:
            self.token_to_kv_pool = MHATokenToKVPool(
                self.max_total_num_tokens,
@@ -1615,6 +1693,24 @@ class ModelRunner:
            )

            return DualChunkFlashAttentionBackend(self)
        elif backend_str == "hybrid_linear_attn":
            assert (
                self.is_hybrid_gdn
            ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
            from sglang.srt.layers.attention.flashattention_backend import (
                FlashAttentionBackend,
            )
            from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
                HybridLinearAttnBackend,
                MambaAttnBackend,
            )

            full_attn_backend = FlashAttentionBackend(self)
            linear_attn_backend = MambaAttnBackend(self)
            full_attn_layers = self.model_config.hf_config.full_attention_layer_ids
            return HybridLinearAttnBackend(
                full_attn_backend, linear_attn_backend, full_attn_layers
            )
        else:
            raise ValueError(f"Invalid attention backend: {backend_str}")
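The `hybrid_linear_attn` branch above pairs a FlashAttention backend (for the full-attention layers) with a Mamba backend (for the linear-attention layers). A minimal sketch of how such a composite backend could route a layer to the right sub-backend; this illustrates the idea only and is not the actual `HybridLinearAttnBackend` implementation:

class ToyHybridBackend:
    # Illustrative stand-in: dispatches per layer between two sub-backends.
    def __init__(self, full_attn_backend, linear_attn_backend, full_attn_layer_ids):
        self.full_attn_backend = full_attn_backend
        self.linear_attn_backend = linear_attn_backend
        self.full_attn_layer_ids = set(full_attn_layer_ids)

    def forward(self, layer_id, *args, **kwargs):
        # Full-attention layers hit the KV-cache backend; all other layers
        # run through the linear-attention (mamba) backend.
        backend = (
            self.full_attn_backend
            if layer_id in self.full_attn_layer_ids
            else self.linear_attn_backend
        )
        return backend.forward(*args, **kwargs)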
@@ -35,6 +35,7 @@ from tqdm.auto import tqdm
from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.utils import print_warning_once
@@ -680,7 +681,7 @@ def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
    """Create a weight loader that shards the weights along the given axis"""

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        tp_rank = get_tensor_model_parallel_rank()
        tp_rank = get_attention_tp_rank()

        shard_size = param.data.shape[shard_axis]
        start_idx = tp_rank * shard_size
python/sglang/srt/models/qwen3_next.py (new file, 1072 lines)
File diff suppressed because it is too large
python/sglang/srt/models/qwen3_next_mtp.py (new file, 117 lines)
@@ -0,0 +1,117 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Inference-only Qwen3Next MTP Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.qwen3_moe import Qwen3MoeModel
from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
from sglang.srt.utils import add_prefix

logger = logging.getLogger(__name__)


class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # If not set, model loading breaks in Qwen3NextForCausalLM.load_weights().
        self.pp_group = get_pp_group()
        # self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")

        # Currently, based on the provided checkpoint, we:
        # (1) do not use dedicated MTP embeddings (none are provided in the checkpoint)
        #     and directly reuse the target model embeddings;
        # (2) hardcode bias=False since no bias is provided.
        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
        if getattr(
            config, "use_gemma_rms_norm", getattr(config, "apply_layernorm_1p", False)
        ):
            logger.warning_once(
                "Using Gemma RMSNorm for input normalization and post attn normalization."
            )
            RMSNorm_cls = GemmaRMSNorm
        else:
            RMSNorm_cls = RMSNorm
        self.pre_fc_norm_embedding = RMSNorm_cls(
            config.hidden_size, config.rms_norm_eps
        )
        self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
        config.num_hidden_layers = 1
        config.full_attention_interval = 1
        self.model = Qwen3NextModel(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("model.shared_head.head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        if input_embeds is None:
            input_embeds = self.model.embed_tokens(input_ids)

        input_embeds = self.pre_fc_norm_embedding(input_embeds)
        hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))

        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            hidden_states,
        )

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(
        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
    ):
        super().load_weights(weights, is_mtp=True)


EntryClass = [Qwen3NextForCausalLMMTP]
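The MTP forward above normalizes the token embedding and the target model's last hidden state separately, concatenates them, and projects back to `hidden_size` with `self.fc` before running the single MTP layer. A shape-level sketch of that combination step; the sizes are placeholders and plain `nn.RMSNorm`/`nn.Linear` (PyTorch 2.4+) stand in for the parallel layers used in the diff:

import torch
from torch import nn

hidden_size, batch = 64, 2  # placeholder sizes
norm_emb = nn.RMSNorm(hidden_size)   # stands in for pre_fc_norm_embedding
norm_hid = nn.RMSNorm(hidden_size)   # stands in for pre_fc_norm_hidden
fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

input_embeds = torch.randn(batch, hidden_size)    # embeddings of the draft tokens
target_hidden = torch.randn(batch, hidden_size)   # hidden states from the target model
fused = fc(torch.cat((norm_emb(input_embeds), norm_hid(target_hidden)), dim=-1))
assert fused.shape == (batch, hidden_size)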
@@ -95,6 +95,7 @@ ATTENTION_BACKEND_CHOICES = [
    "trtllm_mla",
    "trtllm_mha",
    "dual_chunk_flash_attn",
    "hybrid_linear_attn",
    # AMD specific
    "aiter",
    "wave",
@@ -390,6 +391,10 @@ class ServerArgs:
    enable_pdmux: bool = False
    sm_group_num: int = 3

    # Mamba cache
    max_mamba_cache_size: Optional[int] = None
    mamba_ssm_dtype: str = "float32"

    # Deprecated arguments
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
@@ -835,6 +840,8 @@ class ServerArgs:
        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
            "1" if self.enable_torch_compile else "0"
        )
        os.environ["SGLANG_MAMBA_SSM_DTYPE"] = self.mamba_ssm_dtype

        # Set env var before grammar backends init
        os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
            "1" if self.disable_outlines_disk_cache else "0"
@@ -1714,7 +1721,20 @@ class ServerArgs:
            default=ServerArgs.moe_dense_tp_size,
            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
        )

        # Mamba Cache
        parser.add_argument(
            "--max-mamba-cache-size",
            type=int,
            default=ServerArgs.max_mamba_cache_size,
            help="Maximum number of requests for which the mamba cache is statically allocated.",
        )
        parser.add_argument(
            "--mamba-ssm-dtype",
            type=str,
            default=ServerArgs.mamba_ssm_dtype,
            choices=["float32", "bfloat16"],
            help="The dtype used for the mamba SSM states.",
        )
        # Hierarchical cache
        parser.add_argument(
            "--enable-hierarchical-cache",
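With the two new flags, a hybrid GDN model can be served with an explicit mamba cache budget and SSM dtype. A sketch of a launch invocation; the model path is an assumption, and the attention backend is forced to `hybrid_linear_attn` automatically for these models, so it does not need to be passed explicitly:

import subprocess

# Model path is illustrative; the flag names come from the parser additions above.
subprocess.run([
    "python", "-m", "sglang.launch_server",
    "--model-path", "Qwen/Qwen3-Next-80B-A3B-Instruct",
    "--max-mamba-cache-size", "512",
    "--mamba-ssm-dtype", "bfloat16",
])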
@@ -0,0 +1,195 @@
import bisect
from typing import TYPE_CHECKING, Callable

import torch
import torch.nn.functional as F

from sglang.srt.layers.attention.fla.fused_recurrent import (
    fused_recurrent_gated_delta_rule_update,
)
from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn
from sglang.srt.model_executor.cuda_graph_runner import (
    CUDA_GRAPH_CAPTURE_FAILED_MSG,
    CudaGraphRunner,
    get_batch_sizes_to_capture,
    get_global_graph_memory_pool,
    model_capture_mode,
    set_global_graph_memory_pool,
)
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer

if TYPE_CHECKING:
    from sglang.srt.speculative.eagle_worker import EAGLEWorker


class MambaStateUpdateCudaGraphRunner:
    def __init__(self, eagle_worker: "EAGLEWorker"):
        self.eagle_worker = eagle_worker
        model_runner = eagle_worker.target_worker.model_runner
        self.model_runner = model_runner
        self.attn_backend = model_runner.attn_backend.attn_backend_list[1]
        self.req_to_token_pool = self.attn_backend.req_to_token_pool

        self.graphs = {}
        self.output_buffers = {}
        self.graph_input_buffer = None
        self.stream = torch.cuda.Stream()
        self.model = model_runner.model

        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        self.max_bs = self.capture_bs[-1]

        self.init_cuda_graph_state()
        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

    def init_cuda_graph_state(self):
        self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache
        self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2]
        num_mamba_layers = self.mamba_cache[0].shape[0]
        conv_dtype = torch.bfloat16
        conv_shape = self.mamba_cache[0].shape[2]
        total_token_number = self.max_accepted_tokens * self.max_bs
        self.mixed_qkv_cache = torch.empty(
            size=(
                num_mamba_layers,
                total_token_number,
                conv_shape,
            ),
            dtype=conv_dtype,
            device="cuda",
        )
        self.query_start_loc = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.state_indices = torch.zeros(
            (self.max_bs + 1,), dtype=torch.int32, device="cuda"
        )
        self.has_initial_states = torch.ones(
            self.max_bs, dtype=torch.bool, device="cuda"
        )

    def capture(self):
        CudaGraphRunner.capture(self)

    def capture_one_batch_size(self, bs: int, forward: Callable):
        """
        Capture CUDA Graph for a typical workload
        """
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
        total_token_number = bs * self.max_accepted_tokens
        mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number]

        query_start_loc = self.query_start_loc[: bs + 1]
        state_indices = self.state_indices[:bs]
        has_initial_states = self.has_initial_states[:bs]

        mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()
        conv_states = mamba_caches[0]
        mamba_map = self.req_to_token_pool.mamba_map

        def run_once():
            for i in range(len(self.model.model.layers)):
                layer = self.model.model.layers[i]
                if not isinstance(layer, Qwen3HybridLinearDecoderLayer):
                    continue
                conv_weights = layer.linear_attn.conv1d.weight.view(
                    layer.linear_attn.conv1d.weight.size(0),
                    layer.linear_attn.conv1d.weight.size(2),
                )
                layer_id = mamba_map[i]

                causal_conv1d_fn(
                    mixed_qkvs[layer_id].transpose(0, 1),
                    conv_weights,
                    layer.linear_attn.conv1d.bias,
                    activation=layer.linear_attn.activation,
                    conv_states=conv_states[layer_id],
                    has_initial_state=has_initial_states,
                    cache_indices=state_indices,
                    query_start_loc=query_start_loc,
                )

            return None

        for _ in range(2):
            torch.cuda.synchronize()
            self.model_runner.tp_group.barrier()

            run_once()

        with torch.cuda.graph(
            graph, pool=get_global_graph_memory_pool(), stream=stream
        ):
            out = run_once()

        set_global_graph_memory_pool(graph.pool())
        return graph, out

    def can_run(self, accepted_length):
        bs = accepted_length.shape[0]
        return bs <= self.max_bs

    def replay_repare(self, accepted_length):
        request_number = accepted_length.shape[0]
        # QQ: step = spec num_draft token num
        num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2]
        query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype)
        query_start_loc = torch.cat(
            [
                torch.zeros(
                    1,
                    dtype=query_start_loc.dtype,
                    device=query_start_loc.device,
                ),
                query_start_loc,
            ]
        )
        mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze(
            0
        ) < accepted_length.unsqueeze(1)

        state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[
            :request_number
        ]
        mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers()

        _, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches
        mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask]
        self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs)
        self.query_start_loc[: request_number + 1] = query_start_loc
        self.query_start_loc[request_number + 1 :] = self.query_start_loc[
            request_number
        ]
        self.state_indices[:request_number] = state_indices_tensor
        self.state_indices[request_number:] = -1
        valid_mask = accepted_length > 0
        if intermediate_state_cache is not None:
            last_steps = (accepted_length - 1).to(torch.int64)
            valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64)

            ssm_states[:, valid_state_indices, :] = intermediate_state_cache[
                :, valid_state_indices, last_steps
            ].to(ssm_states.dtype)

    def replay(self, accepted_length):
        # batch_size and num_seqs can be different in case there are finished examples
        # in the batch, which will not be counted as num_seqs
        raw_bs = accepted_length.shape[0]
        index = bisect.bisect_left(self.capture_bs, raw_bs)

        bs = self.capture_bs[index]

        self.replay_repare(accepted_length)
        # Replay
        self.graphs[bs].replay()
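`replay()` above pads the runtime batch size up to the nearest captured graph size with `bisect`. The lookup in isolation (the capture sizes are illustrative):

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]  # batch sizes with a captured CUDA graph

def padded_batch_size(raw_bs: int) -> int:
    # Smallest captured size >= raw_bs; replay() then runs that graph with
    # the unused slots pointing at dummy (-1) state indices.
    return capture_bs[bisect.bisect_left(capture_bs, raw_bs)]

assert padded_batch_size(3) == 4
assert padded_batch_size(16) == 16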
@@ -214,6 +214,7 @@ class EAGLEWorker(TpModelWorker):
            "triton": self._create_triton_decode_backend,
            "aiter": self._create_aiter_decode_backend,
            "fa3": self._create_fa3_decode_backend,
            "hybrid_linear_attn": self._create_fa3_decode_backend,
            "flashmla": self._create_flashmla_decode_backend,
            "trtllm_mha": self._create_trtllm_mha_decode_backend,
            "trtllm_mla": self._create_trtllm_mla_decode_backend,
@@ -231,6 +232,7 @@ class EAGLEWorker(TpModelWorker):
            "triton": self._create_triton_prefill_backend,
            "aiter": self._create_aiter_prefill_backend,
            "fa3": self._create_fa3_prefill_backend,
            "hybrid_linear_attn": self._create_fa3_prefill_backend,
            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
        }
@@ -405,6 +407,15 @@ class EAGLEWorker(TpModelWorker):
            f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )

        if self.target_worker.model_runner.is_hybrid_gdn:
            from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import (
                MambaStateUpdateCudaGraphRunner,
            )

            self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner(
                self
            )

    @property
    def draft_model_runner(self):
        return self.model_runner
@@ -826,6 +837,24 @@ class EAGLEWorker(TpModelWorker):
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        # QQ: can be optimized
        if self.target_worker.model_runner.is_hybrid_gdn:
            # res.draft_input.accept_length is on GPU but may be empty for last verify?
            accepted_length = (
                torch.tensor(
                    res.accept_length_per_req_cpu,
                    device=logits_output.hidden_states.device,
                    dtype=torch.int32,
                )
                + 1
            )
            if self.cuda_graph_runner_for_target_verify.can_run(accepted_length):
                self.cuda_graph_runner_for_target_verify.replay(accepted_length)
            else:
                self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
                    accepted_length, self.target_worker.model_runner.model
                )

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)
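For hybrid GDN targets, the accepted length per request (the `+ 1` accounts for the bonus token) decides how many of the cached per-draft-token states are folded back into the mamba state. The masking pattern used by `replay_repare` in isolation:

import torch

num_draft_tokens = 4
accept_length_per_req_cpu = [2, 0, 3]        # accepted drafts, before the bonus token
accepted_length = torch.tensor(accept_length_per_req_cpu, dtype=torch.int32) + 1

# Row r keeps the first accepted_length[r] draft positions.
mask = torch.arange(num_draft_tokens).unsqueeze(0) < accepted_length.unsqueeze(1)
print(mask.int())
# tensor([[1, 1, 1, 0],
#         [1, 0, 0, 0],
#         [1, 1, 1, 1]])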