From 083629c23564e1a64deaa052f1df5c5d914358d8 Mon Sep 17 00:00:00 2001 From: ilyasch2 <104485953+ilyasch2@users.noreply.github.com> Date: Thu, 2 Oct 2025 15:15:36 +0400 Subject: [PATCH] [model] Add mamba2 and Falcon-H1 support. (#10988) Co-authored-by: Younes Belkada Co-authored-by: Younes B <49240599+younesbelkada@users.noreply.github.com> --- python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/falcon_h1.py | 360 +++++++++ python/sglang/srt/hf_transformers_utils.py | 2 + .../attention/hybrid_linear_attn_backend.py | 1 + .../srt/layers/attention/mamba/mamba.py | 567 ++++++++++++- .../srt/layers/attention/mamba/mamba_utils.py | 81 ++ .../layers/attention/mamba/ops/__init__.py | 2 + .../attention/mamba/ops/layernorm_gated.py | 172 ++++ .../layers/attention/mamba/ops/mamba_ssm.py | 442 ++++++++++ .../srt/layers/attention/mamba/ops/ssd_bmm.py | 264 ++++++ .../attention/mamba/ops/ssd_chunk_scan.py | 622 ++++++++++++++ .../attention/mamba/ops/ssd_chunk_state.py | 757 ++++++++++++++++++ .../attention/mamba/ops/ssd_combined.py | 262 ++++++ .../attention/mamba/ops/ssd_state_passing.py | 275 +++++++ .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/models/falcon_h1.py | 576 +++++++++++++ test/srt/models/test_falcon_h1_models.py | 147 ++++ test/srt/run_suite.py | 1 + 18 files changed, 4533 insertions(+), 1 deletion(-) create mode 100644 python/sglang/srt/configs/falcon_h1.py create mode 100644 python/sglang/srt/layers/attention/mamba/mamba_utils.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/__init__.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py create mode 100644 python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py create mode 100644 python/sglang/srt/models/falcon_h1.py create mode 100644 test/srt/models/test_falcon_h1_models.py diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 7d285b3d3..8a8a3bdeb 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -4,6 +4,7 @@ from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config from sglang.srt.configs.dots_ocr import DotsOCRConfig from sglang.srt.configs.dots_vlm import DotsVLMConfig from sglang.srt.configs.exaone import ExaoneConfig +from sglang.srt.configs.falcon_h1 import FalconH1Config from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig @@ -30,4 +31,5 @@ __all__ = [ "Qwen3NextConfig", "DotsVLMConfig", "DotsOCRConfig", + "FalconH1Config", ] diff --git a/python/sglang/srt/configs/falcon_h1.py b/python/sglang/srt/configs/falcon_h1.py new file mode 100644 index 000000000..368404bd0 --- /dev/null +++ b/python/sglang/srt/configs/falcon_h1.py @@ -0,0 +1,360 @@ +# coding=utf-8 +# Copyright 2024 TII and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon-H1 model configuration""" + +import enum +import os + +import numpy as np +import torch +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +from sglang.srt.distributed.utils import divide +from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator +from sglang.srt.layers.dp_attention import ( + get_attention_tp_size, + get_tensor_model_parallel_world_size, +) + +logger = logging.get_logger(__name__) + + +class FalconH1Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a + FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration + with defaults taken from [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf). + The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU. + The checkpoints are jointly trained by IBM, Princeton, and UIUC. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 128000): + Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`FalconH1Model`] + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the + model has a output word embedding layer. + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 14336): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + num_logits_to_keep (`int` or `None`, *optional*, defaults to 1): + Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an + integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the + logits of the last prompt token are needed for generation. For long sequences, the logits for the entire + sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint + significantly. + pad_token_id (`int`, *optional*, defaults to 0): + The id of the padding token. + bos_token_id (`int`, *optional*, defaults to 1): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 2): + The id of the "end-of-sequence" token. + max_position_embeddings (`int`, *optional*, defaults to 8192): + Max cached sequence length for the model + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mamba_d_ssm (`int`, *optional*, defaults to 1024): + The dimension of the SSM state space latents. + mamba_n_heads (`int`, *optional*, defaults to 128): + The number of mamba heads used in the v2 implementation. + mamba_d_head (`int`, *optional*, defaults to `"auto"`): + Head embedding dimension size + mamba_n_groups (`int`, *optional*, defaults to 1): + The number of the mamba groups used in the v2 implementation. + mamba_d_state (`int`, *optional*, defaults to 256): + The dimension the mamba state space latents + mamba_d_conv (`int`, *optional*, defaults to 4): + The size of the mamba convolution kernel + mamba_expand (`int`, *optional*, defaults to 2): + Expanding factor (relative to hidden_size) used to determine the mamba intermediate size + mamba_chunk_size (`int`, *optional*, defaults to 256): + The chunks in which to break the sequence when doing prefill/training + mamba_conv_bias (`bool`, *optional*, defaults to `True`): + Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block. + mamba_proj_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block + mamba_norm_before_gate (`bool`, *optional*, defaults to `True`): + Whether to use RMSNorm before the gate in the Mamba block + mamba_rms_norm (`bool`, *optional*, defaults to `False`): + Whether to use RMSNorm instead of LayerNorm in the Mamba block + projectors_bias (`bool`, *optional*, defaults to `False`): + Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block + rope_theta (`float`, *optional*, defaults to 100000.0): + The theta value used for the RoPE embeddings. + rope_scaling (`float`, *optional*): + The scaling value used for the RoPE embeddings. If `None`, no scaling is applied. + lm_head_multiplier (`float`, *optional*, defaults to 1.0): + The multiplier for the LM head. This is used to scale the output of the LM head. + embedding_multiplier (`float`, *optional*, defaults to 1.0): + The multiplier for the embedding layer. This is used to scale the output of the embedding layer. + mlp_multipliers (`list[float]`, *optional*): + The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is + the multiplier of gate layer, the second value is the multiplier of the down_proj layer. + key_multiplier (`float`, *optional*): + The multiplier for the key layer. This is used to scale the output of the key layer. + attention_out_multiplier (`float`, *optional*): + The multiplier for the attention output layer. This is used to scale the output of the attention output + attention_in_multiplier (`float`, *optional*): + The multiplier for the attention input layer. This is used to scale the output of the attention input layer. + ssm_multipliers (`list[float]`, *optional*): + The multipliers for the SSM layers. This is used to scale the output of the SSM layers. + ssm_in_multiplier (`float`, *optional*): + The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer. + ssm_out_multiplier (`float`, *optional*): + The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer. + """ + + model_type = "falcon_h1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=128000, + tie_word_embeddings=False, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + num_logits_to_keep=1, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + max_position_embeddings=8192, + attention_dropout=0.0, + mamba_d_ssm=1024, + mamba_n_heads=128, + mamba_d_head="auto", + mamba_n_groups=1, + mamba_d_state=256, + mamba_d_conv=4, + mamba_expand=2, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=False, + mamba_norm_before_gate=True, + mamba_rms_norm=False, + projectors_bias=False, + rope_theta=100000.0, + rope_scaling=None, + lm_head_multiplier=1.0, + embedding_multiplier=1.0, + mlp_multipliers=None, + key_multiplier=None, + attention_out_multiplier=None, + attention_in_multiplier=None, + ssm_multipliers=None, + ssm_in_multiplier=None, + ssm_out_multiplier=None, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.attention_dropout = attention_dropout + self.attention_bias = False + self.mlp_bias = False + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + + self.use_cache = use_cache + self.num_logits_to_keep = num_logits_to_keep + + self.rope_theta = rope_theta + self.rope_scaling = None + self.rope_scaling = rope_scaling + self.projectors_bias = projectors_bias + mamba_intermediate = ( + mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm + ) + + if mamba_intermediate % mamba_n_heads != 0: + raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size") + + # for the mamba_v2, must satisfy the following + if mamba_d_head == "auto": + mamba_d_head = mamba_intermediate // mamba_n_heads + + if mamba_d_head * mamba_n_heads != mamba_intermediate: + raise ValueError( + "The dimensions for the Mamba head state do not match the model intermediate_size" + ) + + self.mamba_d_ssm = mamba_d_ssm + self.mamba_n_heads = mamba_n_heads + self.mamba_d_head = mamba_d_head + self.mamba_n_groups = mamba_n_groups + self.mamba_d_state = mamba_d_state + self.mamba_d_conv = mamba_d_conv + self.mamba_expand = mamba_expand + self.mamba_chunk_size = mamba_chunk_size + self.mamba_conv_bias = mamba_conv_bias + self.mamba_proj_bias = mamba_proj_bias + + self.mamba_norm_before_gate = mamba_norm_before_gate + self.mamba_rms_norm = mamba_rms_norm + + self.lm_head_multiplier = lm_head_multiplier + self.embedding_multiplier = embedding_multiplier + + if mlp_multipliers is not None: + self.mlp_multipliers = mlp_multipliers + else: + self.mlp_multipliers = [1.0, 1.0] + + if attention_out_multiplier is not None: + self.attention_out_multiplier = attention_out_multiplier + else: + self.attention_out_multiplier = 1.0 + + if attention_in_multiplier is not None: + self.attention_in_multiplier = attention_in_multiplier + else: + self.attention_in_multiplier = 1.0 + + if key_multiplier is not None: + self.key_multiplier = key_multiplier + else: + self.key_multiplier = 1.0 + + if ssm_multipliers is not None: + self.ssm_multipliers = ssm_multipliers + else: + self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0] + + if ssm_in_multiplier is not None: + self.ssm_in_multiplier = ssm_in_multiplier + else: + self.ssm_in_multiplier = 1.0 + + if ssm_out_multiplier is not None: + self.ssm_out_multiplier = ssm_out_multiplier + else: + self.ssm_out_multiplier = 1.0 + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + @property + def layers_block_type(self): + return ["falcon_h1" for i in range(self.num_hidden_layers)] + + @property + def mamba_cache_per_req(self): + conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = ( + self.hybrid_gdn_params + ) + mamba_layers_len = len(mamba_layers) + + return ( + int(np.prod(conv_state_shape)) * conv_dtype.itemsize + + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize + ) * mamba_layers_len + + @property + def full_attention_layer_ids(self): + # For Falcon-H1, we do have attention on all layers + return range(self.num_hidden_layers) + + @property + def linear_layer_ids(self): + # For Falcon-H1, we do have mamba on all layers + return range(self.num_hidden_layers) + + @property + def hybrid_gdn_params(self): + world_size = get_tensor_model_parallel_world_size() + + n_groups = self.mamba_n_groups + if self.mamba_n_groups % world_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards( + self.mamba_n_groups, world_size + ) + n_groups += extra_groups + + conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state + + conv_state_shape = ( + divide(conv_dim, world_size), + self.mamba_d_conv - 1, + ) + + # we TP-ize on the heads dimension + temporal_state_shape = ( + self.mamba_d_state, + self.mamba_d_head, + divide(self.mamba_n_heads, world_size), + ) + conv_dtype = torch.bfloat16 + dtype_map = { + "float32": torch.float32, + "bfloat16": torch.bfloat16, + } + ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]] + mamba_layers = self.linear_layer_ids + + return ( + conv_state_shape, + temporal_state_shape, + conv_dtype, + ssm_dtype, + mamba_layers, + ) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index b97af68a1..68b0c4534 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -41,6 +41,7 @@ from sglang.srt.configs import ( DotsOCRConfig, DotsVLMConfig, ExaoneConfig, + FalconH1Config, KimiVLConfig, LongcatFlashConfig, MultiModalityConfig, @@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { Step3VLConfig.model_type: Step3VLConfig, LongcatFlashConfig.model_type: LongcatFlashConfig, Qwen3NextConfig.model_type: Qwen3NextConfig, + FalconH1Config.model_type: FalconH1Config, DotsVLMConfig.model_type: DotsVLMConfig, DotsOCRConfig.model_type: DotsOCRConfig, } diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 435844f74..d405713b7 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -69,6 +69,7 @@ class MambaAttnBackend(AttentionBackend): def init_forward_metadata(self, forward_batch: ForwardBatch): bs = forward_batch.batch_size + if forward_batch.forward_mode.is_decode_or_idle(): query_start_loc = torch.arange( 0, bs + 1, dtype=torch.int32, device=self.device diff --git a/python/sglang/srt/layers/attention/mamba/mamba.py b/python/sglang/srt/layers/attention/mamba/mamba.py index 045a04048..b48ee694f 100644 --- a/python/sglang/srt/layers/attention/mamba/mamba.py +++ b/python/sglang/srt/layers/attention/mamba/mamba.py @@ -1,6 +1,39 @@ -from typing import Callable, List, Tuple +from typing import Callable, List, Optional, Tuple, Union import torch +import torch.nn as nn + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.custom_op import CustomOp +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.utils import divide +from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn +from sglang.srt.layers.attention.mamba.causal_conv1d import ( + causal_conv1d_fn, + causal_conv1d_update, +) +from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator +from sglang.srt.layers.attention.mamba.ops import ( + mamba_chunk_scan_combined, + selective_state_update, +) +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import ( + composed_weight_loader, + sharded_weight_loader, +) +from sglang.srt.utils import set_weight_attrs LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None] @@ -62,3 +95,535 @@ def mamba_v2_sharded_weight_loader( loaded_boundary += full_dim - extra return loader + + +class Mixer2RMSNormGated(CustomOp): + + def __init__( + self, + full_hidden_size: int, + full_n_groups: int, + use_rms_norm: bool = True, + eps: float = 1e-6, + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.full_hidden_size = full_hidden_size + self.group_size = full_hidden_size // full_n_groups + self.per_rank_hidden_size = full_hidden_size // self.tp_size + self.n_groups = full_hidden_size // self.group_size + + self.variance_epsilon = eps + self.use_rms_norm = use_rms_norm + if self.use_rms_norm: + # Register norm weight only if we're actually applying RMSNorm + self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) + set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)}) + else: + # Avoid checkpoint mismatch by skipping unused parameter + self.register_parameter("weight", None) + assert ( + self.full_hidden_size % self.tp_size == 0 + ), "Tensor parallel world size must divide hidden size." + + def forward_native( + self, + x: torch.Tensor, + gate: torch.Tensor, + ): + # Three tensor-parallel cases: + # 1. n_groups is 1 + # In this case we parallelize along the reduction dim. + # Each rank computes a local sum of squares followed by AllReduce + # 2. tp_size divides n_groups + # Each rank only reduces within its local group(s). + # No collective ops necessary. + # 3. The general case can be pretty complicated so we AllGather + # the input and then redundantly compute the RMSNorm. + input_dtype = x.dtype + x = x * nn.functional.silu(gate.to(torch.float32)) + if not self.use_rms_norm: + return x.to(input_dtype) + + if self.n_groups == 1: + if self.tp_size > 1: + # Compute local sum and then reduce to obtain global sum + local_sums = x.pow(2).sum(dim=-1, keepdim=True) + global_sums = tensor_model_parallel_all_reduce(local_sums) + # Calculate the variance + count = self.tp_size * x.shape[-1] + variance = global_sums / count + + else: + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + else: + redundant_tp: bool = self.n_groups % self.tp_size != 0 + if redundant_tp: + # To handle the general case, redundantly apply the variance + x = tensor_model_parallel_all_gather(x, -1) + + *prefix_dims, hidden_dim = x.shape + group_count = hidden_dim // self.group_size + x_grouped = x.view(*prefix_dims, group_count, self.group_size) + variance = x_grouped.pow(2).mean(-1, keepdim=True) + x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon) + x = x_grouped.view(*prefix_dims, hidden_dim) + + if redundant_tp: + start = self.per_rank_hidden_size * self.tp_rank + end = start + self.per_rank_hidden_size + x = x[..., start:end] + + return self.weight * x.to(input_dtype) + + def forward_cuda( + self, + x: torch.Tensor, + gate: torch.Tensor, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + input_dtype = x.dtype + if not self.use_rms_norm: + # Keep gate in float32 for numerical stability during silu + return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype) + + if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1: + return self.forward_native(x, gate) + + return layernorm_fn( + x, + self.weight.data, + bias=None, + z=gate, + eps=self.variance_epsilon, + norm_before_gate=False, + ) + + +class MambaMixer2(torch.nn.Module): + """ + Compute ∆, A, B, C, and D the state space parameters and compute + the `contextualized_states`. A, D are input independent + (see Mamba paper [1] Section 3.5.2 "Interpretation of A" + for why A isn't selective) ∆, B, C are input-dependent + (this is a key difference between Mamba and the linear time + invariant S4, and is why Mamba is called + **selective** state spaces) + """ + + def __init__( + self, + hidden_size: int, + ssm_state_size: int, + conv_kernel_size: int, + intermediate_size: int, + use_conv_bias: bool, + use_bias: bool, + chunk_size: int, + layer_id: int, + n_groups: int = 1, + num_heads: int = 128, + head_dim: int = 64, + rms_norm_eps: float = 1e-5, + activation: str = "silu", + use_rms_norm: bool = True, + model_config: Optional[ModelConfig] = None, + # cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + # For TP, the sharding plan is as follows: + # - for the conv modules, since + # conv_dim = intermediate_size * 2 * n_groups * ssm_state_size, + # we shard intermediate_size and n_groups + # - since intermediate_size = n_heads * head_dim, sharding on + # intermediate_size is achieved by sharding on n_heads. + # - IF, world_size divides groups, then sharding + # (n_groups / world_size, n_heads / world_size) + # also maintains the invariant n_heads % n_groups == 0 + # - HOWEVER IF, world_size DOES NOT divide groups, then we need + # to allocate extra space in the shard, such that groups + # may be replicated to follow the head shard. + # - NOTE: currently for the world size DOES NOT divide groups + # case, we only support the case when n_groups == 1 + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + assert ( + num_heads % self.tp_size == 0 + ), "Tensor parallel world size must divide num heads." + + assert (n_groups % self.tp_size) == 0 or n_groups == 1, ( + "If tensor parallel world size does not divide num_groups, " + "then num_groups must equal 1." + ) + + self.ssm_state_size = ssm_state_size + self.conv_kernel_size = conv_kernel_size + self.activation = activation + self.layer_id = layer_id + + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_heads = num_heads + self.chunk_size = chunk_size + + self.n_groups = n_groups + if n_groups % self.tp_size != 0: + # - for TP we shard conv_dim by sharding on n_groups, + # - but if n_groups cannot divide tp_size, we need to + # extend some extra groups + groups = MambaStateShapeCalculator.extra_groups_for_head_shards( + n_groups, self.tp_size + ) + self.n_groups = n_groups + groups + + self.groups_ssm_state_size = self.n_groups * self.ssm_state_size + self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size + + self.conv1d = MergedColumnParallelLinear( + input_size=conv_kernel_size, + output_sizes=[ + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + ], + bias=use_conv_bias, + quant_config=None, + prefix=f"{prefix}.conv1d", + ) + + self.in_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[ + intermediate_size, + intermediate_size, + self.groups_ssm_state_size, + self.groups_ssm_state_size, + self.num_heads, + ], + bias=use_bias, + prefix=f"{prefix}.in_proj", + ) + if n_groups % self.tp_size != 0: + # This is the n_groups == 1 case, + # where we need to duplicate groups if TP>1. + + # - because in_proj is a concatenation of 3 weights, we + # need to interleave them before sharding + # - use the custom weight loader mamba_v2_sharded_weight_loader + # for conv1d.bias, covn1d.weight and in_proj.weight + # - need to set these settings, to assign the groups + # to the head shards + group_shard_settings = ( + self.groups_ssm_state_size, # expected model size + (self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned + n_groups == 1, # if there was only one group + ) + intermediate_settings = (intermediate_size, 0, False) + head_settings = (self.num_heads, 0, False) + + # - the weight already has a "weight_loader" attribute + # which set_weight_attrs will raise if we do not + # delete before trying to override it + # - ditto for the other two weights below + delattr(self.conv1d.bias, "weight_loader") + set_weight_attrs( + self.conv1d.bias, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) + + delattr(self.conv1d.weight, "weight_loader") + set_weight_attrs( + self.conv1d.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, + group_shard_settings, + group_shard_settings, + ], + self.tp_size, + self.tp_rank, + ) + }, + ) + + if quant_config is None: + # - quant layers do not have a weight loader + delattr(self.in_proj.weight, "weight_loader") + set_weight_attrs( + self.in_proj.weight, + { + "weight_loader": mamba_v2_sharded_weight_loader( + [ + intermediate_settings, # for gate + intermediate_settings, + group_shard_settings, + group_shard_settings, + head_settings, # for dt + ], + self.tp_size, + self.tp_rank, + ) + }, + ) + + # unsqueeze to fit conv1d weights shape into the linear weights shape. + # Can't do this in `weight_loader` since it already exists in + # `ColumnParallelLinear` and `MergedColumnParallelLinear`, + # and `set_weight_attrs` doesn't allow to override it + self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) + + # - these are TPed by heads to reduce the size of the + # temporal shape + self.A = nn.Parameter( + torch.empty( + divide(num_heads, self.tp_size), + dtype=torch.float32, + ) + ) + self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) + self.use_rms_norm = use_rms_norm + + set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) + a_weight_loader = composed_weight_loader( + sharded_weight_loader(0), lambda x: -torch.exp(x.float()) + ) + set_weight_attrs(self.A, {"weight_loader": a_weight_loader}) + set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)}) + + self.out_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=use_bias, + input_is_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + reduce_results=False, + ) + + self.norm = Mixer2RMSNormGated( + intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps + ) + + # The tuple is (conv_state, ssm_state) + self.kv_cache = (torch.tensor([]), torch.tensor([])) + + self.model_config = model_config + self.prefix = prefix + + def forward_native( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + mup_vector: Optional[torch.Tensor] = None, + ): + pass + + def forward( + self, + hidden_states: torch.Tensor, + output: torch.Tensor, + forward_batch: ForwardBatch, + mup_vector: Optional[torch.Tensor] = None, + ): + # attn_backend_list[-1] gives access to MambaAttnBackend + mamba_backend = forward_batch.attn_backend.attn_backend_list[-1] + attn_metadata = mamba_backend.forward_metadata + state_indices_tensor = attn_metadata.mamba_cache_indices + chunk_size = self.chunk_size + + conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params( + self.layer_id + ) + + assert ( + ssm_state.size(1) == self.ssm_state_size + ), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}" + + query_start_loc = attn_metadata.query_start_loc + + chunk_size = self.chunk_size + + # TODO: properly support this + prep_initial_states = False + + # 1. Gated MLP's linear projection + projected_states, _ = self.in_proj(hidden_states) + + if mup_vector is not None: + projected_states = projected_states * mup_vector + + gate, hidden_states_B_C, dt = torch.split( + projected_states, + [ + self.intermediate_size // self.tp_size, + self.conv_dim // self.tp_size, + self.num_heads // self.tp_size, + ], + dim=-1, + ) + conv_weights = self.conv1d.weight.view( + self.conv1d.weight.size(0), self.conv1d.weight.size(2) + ) + + # - get hidden_states, B and C after depthwise convolution. + split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split( + hidden_states_B_C, + [ + self.intermediate_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + self.groups_ssm_state_size // self.tp_size, + ], + dim=-1, + ) + + preallocated_ssm_out = torch.empty( + [ + projected_states.shape[0], + (self.num_heads * self.head_dim) // self.tp_size, + ], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + + # Process prefill requests + if forward_batch.forward_mode.is_extend(): + # 2. Convolution sequence transformation + # - "cache_indices" updates the conv_state cache in positions + # pointed to by "state_indices_tensor" + num_prefill_tokens = forward_batch.extend_num_tokens or 0 + has_initial_states = forward_batch.extend_prefix_lens > 0 + cache_indices = attn_metadata.mamba_cache_indices + + x = hidden_states_B_C.transpose( + 0, 1 + ) # this is the form that causal-conv see + hidden_states_B_C = causal_conv1d_fn( + x, + conv_weights, + self.conv1d.bias, + activation=self.activation, + conv_states=conv_state, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + seq_lens_cpu=forward_batch.extend_seq_lens_cpu, + ).transpose(0, 1) + + hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C) + + # 3. State Space Model sequence transformation + initial_states = None + + if has_initial_states is not None and prep_initial_states: + initial_states = torch.where( + has_initial_states[:, None, None, None], + ssm_state[state_indices_tensor], + 0, + ) + + # NOTE: final output is an in-place update of out tensor + varlen_state = mamba_chunk_scan_combined( + hidden_states.view( + 1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim + ), + dt.unsqueeze(0), + self.A, + B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), + C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1), + chunk_size=chunk_size, + D=self.D, + z=None, + dt_bias=self.dt_bias, + cu_seqlens=query_start_loc, + initial_states=initial_states, + return_varlen_states=True, + return_final_states=False, + dt_softplus=True, + dt_limit=(0.0, float("inf")), + out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim), + state_dtype=ssm_state.dtype, + ) + + # update ssm states + # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor + ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1) + elif forward_batch.forward_mode.is_decode(): + num_decodes = len(query_start_loc) - 1 + # 2. Convolution sequence transformation + hidden_states_B_C = causal_conv1d_update( + hidden_states_B_C, + conv_state, + conv_weights, + self.conv1d.bias, + self.activation, + conv_state_indices=state_indices_tensor, + ) + + hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C) + + # 3. State Space Model sequence transformation + n_groups = self.n_groups // self.tp_size + A = ( + self.A[:, None, ...][:, :, None] + .expand(-1, self.head_dim, self.ssm_state_size) + .to(dtype=torch.float32) + ) + dt = dt[:, :, None].expand(-1, -1, self.head_dim) + dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim) + D = self.D[:, None, ...].expand(-1, self.head_dim) + B = B.view(-1, n_groups, B.shape[1] // n_groups) + C = C.view(-1, n_groups, C.shape[1] // n_groups) + hidden_states = hidden_states.view( + -1, self.num_heads // self.tp_size, self.head_dim + ) + + # - the hidden is reshaped into (bs, num_heads, head_dim) + # - mamba_cache_params.ssm_state's slots will be selected + # using state_indices_tensor_d + # NOTE: final output is an in-place update of out tensor + selective_state_update( + ssm_state.permute(0, 3, 2, 1), + hidden_states, + dt, + A, + B, + C, + D, + z=None, + dt_bias=dt_bias, + dt_softplus=True, + state_batch_indices=state_indices_tensor, + out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim), + ) + elif forward_batch.forward_mode.is_idle(): + preallocated_ssm_out = preallocated_ssm_out + + # 4. gated MLP + # GatedRMSNorm internally applying SiLU to the gate + # SiLU is applied internally before normalization, unlike standard + # norm usage + hidden_states = self.norm(preallocated_ssm_out, gate) + + # 5. Final linear projection + output[:], _ = self.out_proj(hidden_states) + + @property + def mamba_type(self) -> str: + return "mamba2" diff --git a/python/sglang/srt/layers/attention/mamba/mamba_utils.py b/python/sglang/srt/layers/attention/mamba/mamba_utils.py new file mode 100644 index 000000000..7672934be --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/mamba_utils.py @@ -0,0 +1,81 @@ +# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py +from sglang.srt.distributed.utils import divide + + +class MambaStateShapeCalculator: + + @classmethod + def linear_attention_state_shape( + cls, + num_heads: int, + tp_size: int, + head_dim: int, + ) -> tuple[tuple[int, int, int], ...]: + + state_shape = (num_heads // tp_size, head_dim, head_dim) + return (state_shape,) + + @classmethod + def mamba1_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + state_size: int, + conv_kernel: int, + ) -> tuple[tuple[int, int], tuple[int, int]]: + conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1) + + temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size) + + conv_state_shape = conv_state_shape[1], conv_state_shape[0] + + return conv_state_shape, temporal_state_shape + + @classmethod + def mamba2_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + n_groups: int, + num_heads: int, + head_dim: int, + state_size: int, + conv_kernel: int, + ) -> tuple[tuple[int, int], tuple[int, int, int]]: + # if n_groups is not divisible by world_size, need to extend the shards + # to ensure all groups needed by a head is sharded along with it + n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size) + # heads and n_groups are TP-ed + conv_dim = intermediate_size + 2 * n_groups * state_size + + # contiguous along 'dim' axis + conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size)) + + # These are not TP-ed as they depend on A, dt_bias, D + # - they are typically small + # e.g., (h_heads, head_dim, state_size) = (128, 64, 128) + temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size) + return conv_state_shape, temporal_state_shape + + @classmethod + def short_conv_state_shape( + cls, + tp_world_size: int, + intermediate_size: int, + conv_kernel: int, + ) -> tuple[tuple[int, int]]: + conv_dim = divide(intermediate_size, tp_world_size) + conv_state_shape = (conv_kernel - 1, conv_dim) + return (conv_state_shape,) + + @classmethod + def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int): + """Compute the increase in group numbers to account for + replication in order to accompany the head shards.""" + + # in the case ngoups % tp_size == 0, this will be zero + if ngroups % tp_size == 0: + return 0 + + # for n_groups == 1, this is exactly tp_size - n_groups + return tp_size - ngroups diff --git a/python/sglang/srt/layers/attention/mamba/ops/__init__.py b/python/sglang/srt/layers/attention/mamba/ops/__init__.py new file mode 100644 index 000000000..809ff36fb --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/__init__.py @@ -0,0 +1,2 @@ +from .mamba_ssm import selective_state_update +from .ssd_combined import mamba_chunk_scan_combined diff --git a/python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py b/python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py new file mode 100644 index 000000000..88b27eb5d --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py + +import torch +import triton +import triton.language as tl + + +@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None}) +@triton.jit +def _layer_norm_fwd_1pass_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + stride_x_row: tl.int64, + stride_y_row: tl.int64, + stride_z_row: tl.int64, + M: tl.int64, # number of rows in X + N: tl.int64, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_Z: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, + IS_RMS_NORM: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + if HAS_Z: + Z += row * stride_z_row + group * N + if not IS_RMS_NORM: + Mean += group * M + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute mean and variance + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if HAS_Z and not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + if not IS_RMS_NORM: + mean = tl.sum(x, axis=0) / N + tl.store(Mean + row, mean) + xbar = tl.where(cols < N, x - mean, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + else: + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if HAS_Z and NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _layer_norm_fwd( + x, + weight, + bias, + eps, + z=None, + out=None, + group_size=None, + norm_before_gate=True, + is_rms_norm=False, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + if z is not None: + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + mean = ( + torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + if not is_rms_norm + else None + ) + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + grid = (M, ngroups) + with torch.cuda.device(x.device.index): + _layer_norm_fwd_1pass_kernel[grid]( + x, + out, + weight, + bias, + z, + mean, + rstd, + x.stride(0), + out.stride(0), + z.stride(0) if z is not None else 0, + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + IS_RMS_NORM=is_rms_norm, + num_warps=num_warps, + ) + return out, mean, rstd + + +def rms_norm_gated( + x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True +): + x_shape_og = x.shape + # reshape input data into 2D tensor + x = x.reshape(-1, x.shape[-1]) + if x.stride(-1) != 1: + x = x.contiguous() + if z is not None: + assert z.shape == x_shape_og + z = z.reshape(-1, z.shape[-1]) + if z.stride(-1) != 1: + z = z.contiguous() + weight = weight.contiguous() + if bias is not None: + bias = bias.contiguous() + y, _, _ = _layer_norm_fwd( + x, + weight, + bias, + eps, + z=z, + group_size=group_size, + norm_before_gate=norm_before_gate, + is_rms_norm=True, + ) + + return y.reshape(x_shape_og) diff --git a/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py b/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py new file mode 100644 index 000000000..69a1ff9fb --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py @@ -0,0 +1,442 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py + +import torch +import triton +import triton.language as tl +from packaging import version + +from sglang.srt import _custom_ops as ops + +PAD_SLOT_ID = -1 + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt + +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + { + "HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] + is not None + } +) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])} +) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + state_batch_indices_ptr, + pad_slot_id, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_STATE_BATCH_INDICES: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + + # If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate + # is taken from the state_batch_indices_ptr Otherwise, the state coordinate + # is the same as the batch id. + if HAS_STATE_BATCH_INDICES: + state_batch_indices_ptr += pid_b + state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64) + state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head + else: + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + ( + offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate + ) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + ( + offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate + ) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + state = tl.load(state_ptrs, mask=mask, other=0.0) + + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load( + A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0 + ).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + + mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate) + if HAS_STATE_BATCH_INDICES: + mask &= state_batch_idx != pad_slot_id + tl.store(state_ptrs, state, mask=mask) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update( + state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False, + state_batch_indices=None, + pad_slot_id=PAD_SLOT_ID, + out=None, +): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: Preallocated ssm output tensor. Assume same shape as x. + In-place updated. + """ + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + if out.dim() == 2: + out = out.unsqueeze(1) + + _, nheads, dim, dstate = state.shape + batch = x.shape[0] + + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + if state_batch_indices is not None: + assert state_batch_indices.shape == (batch,) + assert out.shape == x.shape + + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads) + z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ( + (32, 4) + if dstate <= 16 + else ( + (16, 4) + if dstate <= 32 + else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8)))) + ) + ) + tie_hdim = ( + A.stride(-1) == 0 + and A.stride(-2) == 0 + and dt.stride(-1) == 0 + and dt_bias.stride(-1) == 0 + ) + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + state_batch_indices, + pad_slot_id, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + + +def selective_scan_fn( + u, + ssm_states, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + pad_slot_id=PAD_SLOT_ID, +) -> torch.Tensor: + """ + u: (dim, total_length) for varlen or (batch, dim, seqlen) + applies changes in place. + ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate) + applies changes in place. + delta: (dim, total_length) for varlen or (batch, dim, seqlen) + A: (dim, dstate) + B: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + C: (ngroups, dstate, total_length) for varlen or + (batch,ngroups,dstate,seqlen) + D: (dim,) + z: (dim, total_length) for varlen or (batch, dim, seqlen) + dt_bias: (dim,) or (dim) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended with 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + A tensor with each cell is a correspondent + input and output ssm_state index + has_initial_state: (batch) bool + A tensor populated with ones and zeros, + indicate if the ssm_state at the corresponding index should be + used as initial state. Not providing argument assumes + there's no initial state + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padding entries + that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at indices 0 and 3 + returns + output: (dim, total_length) for varlen or (batch, dim, seqlen) + supports inplace replacement + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3 and query_start_loc is None: + B = B.unsqueeze(1) + if B.dim() == 2 and query_start_loc is not None: + B = B.unsqueeze(0) + if C.dim() == 3 and query_start_loc is None: + C = C.unsqueeze(1) + if C.dim() == 2 and query_start_loc is not None: + C = C.unsqueeze(0) + + ops.selective_scan_fwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + pad_slot_id, + ) + + if z is None: + return delta # output written inplace to delta + else: + return z # output written inplace to z diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py new file mode 100644 index 000000000..e618920ce --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py @@ -0,0 +1,264 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_bmm.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py + +# ruff: noqa: E501,SIM102 + +import math + +import torch +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=2, +# ), +# ], +# key=["chunk_size", "K", "IS_CAUSAL"], +# ) +@triton.jit +def _bmm_chunk_fwd_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + out_ptr, + seq_idx_ptr, + # Matrix dimensions + seqlen, + chunk_size, + K, + ngroups, + stride_a_batch, + stride_a_seqlen, + stride_a_head, + stride_ak, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_bk, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_outm, + stride_outn, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + dot_dtype: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr = 16, + BLOCK_SIZE_N: tl.constexpr = 16, + BLOCK_SIZE_K: tl.constexpr = 16, +): + pid_b = tl.program_id(axis=1) + pid_ch = tl.program_id(axis=2).to(tl.int64) + pid_c = pid_ch // ngroups + pid_h = pid_ch - pid_c * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + if IS_CAUSAL: + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + return + a_ptr += ( + pid_b * stride_a_batch + + pid_c * chunk_size * stride_a_seqlen + + pid_h * stride_a_head + ) + b_ptr += ( + pid_b * stride_b_batch + + pid_c * chunk_size * stride_b_seqlen + + pid_h * stride_b_head + ) + if HAS_SEQ_IDX: + seq_idx_ptr += ( + pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + ) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0, + ).to(dot_dtype) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) + & (offs_n[None, :] < chunk_size_limit), + other=0.0, + ).to(dot_dtype) + acc += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + seq_idx_m = tl.load( + seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1, + ) + seq_idx_n = tl.load( + seq_idx_ptr + offs_n * stride_seq_idx_seqlen, + mask=offs_n < chunk_size_limit, + other=-2, + ) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + out = acc.to(out_ptr.dtype.element_ty) + + out_ptr += ( + pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head + ) + out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn) + tl.store( + out_ptrs, + out, + mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), + ) + + +def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None): + """ + Argument: + a: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + b: (batch, seqlen, k) or (batch, seqlen, ngroups, k) + seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out. + causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are + guaranteed to be correct. + Return: + out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size) + """ + # Check constraints. + has_groups = a.dim() == 4 + if not has_groups: + batch, seqlen, k = a.shape + else: + batch, seqlen, ngroups, k = a.shape + assert b.shape == a.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if a.stride(-1) != 1 and a.stride(1) != 1: + a = a.contiguous() + if b.stride(-1) != 1 and b.stride(1) != 1: + b = b.contiguous() + nchunks = math.ceil(seqlen / chunk_size) + # Allocates output. + out_dtype = a.dtype if output_dtype is None else output_dtype + out = torch.empty( + ( + (batch, nchunks, chunk_size, chunk_size) + if not has_groups + else (batch, nchunks, ngroups, chunk_size, chunk_size) + ), + device=a.device, + dtype=out_dtype, + ) + dot_dtype = ( + tl.bfloat16 + if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 + else ( + tl.float16 + if a.dtype == torch.float16 or b.dtype == torch.float16 + else tl.float32 + ) + ) + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]), + batch, + nchunks if not has_groups else nchunks * ngroups, + ) + with torch.cuda.device(a.device.index): + _bmm_chunk_fwd_kernel[grid]( + a, + b, + out, + seq_idx, + seqlen, + chunk_size, + k, + ngroups if has_groups else 1, + a.stride(0), + a.stride(1), + 0 if not has_groups else a.stride(2), + a.stride(-1), + b.stride(0), + b.stride(1), + 0 if not has_groups else b.stride(2), + b.stride(-1), + out.stride(0), + out.stride(1), + 0 if not has_groups else out.stride(2), + out.stride(-2), + out.stride(-1), + *( + (seq_idx.stride(0), seq_idx.stride(1)) + if seq_idx is not None + else (0, 0) + ), + causal, + dot_dtype, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py new file mode 100644 index 000000000..b44f12089 --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py @@ -0,0 +1,622 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py + +# ruff: noqa: E501,SIM102 + +import torch +import triton +import triton.language as tl +from packaging import version + +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") + + +# @triton.autotune( +# configs=[ +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=2, +# ), +# ], +# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"], +# ) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, + x_ptr, + z_ptr, + out_ptr, + out_x_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + C_ptr, + states_ptr, + D_ptr, + initstates_ptr, + chunk_indices_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + chunk_size, + hdim, + dstate, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_cb_batch, + stride_cb_chunk, + stride_cb_head, + stride_cb_csize_m, + stride_cb_csize_k, + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_z_batch, + stride_z_seqlen, + stride_z_head, + stride_z_hdim, + stride_out_batch, + stride_out_seqlen, + stride_out_head, + stride_out_hdim, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + stride_C_batch, + stride_C_seqlen, + stride_C_head, + stride_C_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, + HAS_INITSTATES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr = 16, + BLOCK_SIZE_N: tl.constexpr = 16, + BLOCK_SIZE_K: tl.constexpr = 16, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + if not HAS_INITSTATES: + c_idx = pid_c + c_off = 0 + else: + c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0) + c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0) + + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += ( + pid_b * stride_cb_batch + + c_idx * stride_cb_chunk + + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ) + x_ptr += ( + pid_b * stride_x_batch + + c_idx * chunk_size * stride_x_seqlen + + pid_h * stride_x_head + ) + dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += ( + pid_b * stride_dA_cs_batch + + c_idx * stride_dA_cs_chunk + + pid_h * stride_dA_cs_head + ) + C_ptr += ( + pid_b * stride_C_batch + + c_idx * chunk_size * stride_C_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_C_head + ) + + # M-block offsets and prev states + # - logic in next block may override these if there is an active offset + offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + prev_states_ptr = ( + states_ptr + + pid_b * stride_states_batch + + c_idx * stride_states_chunk + + pid_h * stride_states_head + ) + prev_states_hdim = stride_states_hdim + prev_states_dstate = stride_states_dstate + + chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size) + if HAS_SEQ_IDX: + seq_idx_ptr += ( + pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen + ) + + # - we only need seq_idx_prev to be aligned to chunk boundary + seq_idx_prev = tl.load( + seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0 + ) + + if HAS_INITSTATES: + # if there are init states, we only need seq_idx_m to point + # what is the current seq_idx + + # get current seq idx + if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit: + seq_idx_m = tl.load( + seq_idx_ptr + + (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen, + ) + + # - recall that in ssd_state_passing, for the case c_off == 0 + # i.e., the very first sequence, we made states_ptr hold its initial state + # so this edge case is taken care of + if ( + (c_off == 0) + and ( + seq_idx_prev != seq_idx_m + ) # if a seq is changed exactly on boundary + or (c_off > 0) # implies a new example (pseudo chunk) + ): + + # - replace prev_states_ptr with init_states + prev_states_ptr = ( + initstates_ptr + + seq_idx_m * stride_init_states_batch + + pid_h * stride_init_states_head + ) + prev_states_hdim = stride_init_states_hdim # override strides + prev_states_dstate = stride_init_states_dstate + + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load( + dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0 + ).to(tl.float32) + + # - handle chunk state limit + if HAS_INITSTATES: + + # have to split this if otherwise compilation will have problems + dA_cs_m_boundary = 0.0 + + # get the c_idx for the next (logica) chunk + c_idx_n = tl.load( + chunk_indices_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=-1, # to trigger different chunk + ) + + # - there are things to consider + # A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct + # contribution of past states + # B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to + # encroach into the next sequence, where c_off_n is the offset of the next + # (logical) chunk. + # An equivalent check for B is c_idx == c_idx_n, where there is repetition in + # (logical) chunk indices. + + if (c_idx == c_idx_n) or c_off > 0: + + # get the next offset + c_off_n = tl.load( + chunk_offsets_ptr + (pid_c + 1), + mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num, + other=chunk_size, + ) + + # in this case, adjust down the chunk_size_limit + if c_idx == c_idx_n: + chunk_size_limit = min(c_off_n, chunk_size_limit) + + # get the cs at the offset boundary + # - c_off == 0 is a passthrough + # - We need dA_cs at the boundary, defined by c_off - no need + # to increase pointer by pid_m (it is a constant offset, + # i.e. the same for all blocks) + dA_cs_m_boundary = tl.load( + dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, + mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)), + other=0.0, + ).to(tl.float32) + + if HAS_SEQ_IDX: + # - handle seq idx when HAS_INITSTATES==False + if not HAS_INITSTATES: + seq_idx_m = tl.load( + seq_idx_ptr + offs_m * stride_seq_idx_seqlen, + mask=offs_m < chunk_size_limit, + other=-1, + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or c_idx > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange( + 0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K + ) + C_ptrs = C_ptr + ( + offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate + ) + + prev_states_ptrs = prev_states_ptr + ( + offs_n[None, :] * prev_states_hdim + + offs_k_dstate[:, None] * prev_states_dstate + ) + if HAS_SEQ_IDX: + + if not HAS_INITSTATES: + # - this is for continuous batching where there is no init states + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + else: + # - if there is initstates, we will rely on prev_states, no zeroing + # required. + scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary) + else: + scale_m = tl.exp(dA_cs_m) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate), + other=0.0, + ) + + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load( + C_ptrs, + mask=(offs_m[:, None] < chunk_size_limit) + & (offs_k_dstate[None, :] < dstate - k), + other=0.0, + ) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load( + prev_states_ptrs, + mask=(offs_k_dstate[:, None] < dstate - k) + & (offs_n[None, :] < hdim), + other=0.0, + ) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off + cb_ptrs = cb_ptr + ( + offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k + ) + x_ptrs = x_ptr + ( + offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = ( + chunk_size_limit + if not IS_CAUSAL + else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + ) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load( + cb_ptrs, + mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to( + tl.float32 + ) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :]) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load( + x_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), + other=0.0, + ) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load( + D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0 + ).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load( + x_ptr + + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += ( + pid_b * stride_out_batch + + c_idx * chunk_size * stride_out_seqlen + + pid_h * stride_out_head + ) + out_x_ptrs = out_x_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] + ) + tl.store( + out_x_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + ) + + z_ptr += ( + pid_b * stride_z_batch + + c_idx * chunk_size * stride_z_seqlen + + pid_h * stride_z_head + ) + z_ptrs = z_ptr + ( + stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :] + ) + z = tl.load( + z_ptrs, + mask=(offs_out_m[:, None] < chunk_size_limit) + & (offs_out_n[None, :] < hdim), + other=0.0, + ).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += ( + pid_b * stride_out_batch + + c_idx * chunk_size * stride_out_seqlen + + pid_h * stride_out_head + ) + out_ptrs = out_ptr + ( + stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim + ) + tl.store( + out_ptrs, + acc, + mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), + ) + + +def _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + D=None, + z=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + initial_states=None, + out=None, +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + + if initial_states is not None: + # with initial states, we need to take care of how + # seq_idx crosses the boundaries + assert batch == 1, "chunk scan only supports initial states with batch 1" + assert ( + chunk_indices is not None and chunk_offsets is not None + ), "chunk_indices and chunk_offsets should have been set" + else: + chunk_indices, chunk_offsets = None, None + else: + chunk_indices, chunk_offsets = None, None + + assert out.shape == x.shape + + if z is not None: + out_x = torch.empty_like(x) + assert out_x.stride() == out.stride() + else: + out_x = None + + grid = lambda META: ( + triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) + * triton.cdiv(headdim, META["BLOCK_SIZE_N"]), + batch * nchunks if chunk_offsets is None else len(chunk_offsets), + nheads, + ) + z_strides = ( + (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None + else (0, 0, 0, 0) + ) + _chunk_scan_fwd_kernel[grid]( + cb, + x, + z, + out, + out_x, + dt, + dA_cumsum, + seq_idx, + C, + states, + D, + initial_states, + chunk_indices, + chunk_offsets, + len(chunk_indices) if chunk_indices is not None else 0, + chunk_size, + headdim, + dstate, + batch, + seqlen, + nheads // ngroups, + cb.stride(0), + cb.stride(1), + cb.stride(2), + cb.stride(3), + cb.stride(4), + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + z_strides[0], + z_strides[1], + z_strides[2], + z_strides[3], + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), + C.stride(1), + C.stride(2), + C.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + *( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + HAS_INITSTATES=initial_states is not None, + ) + return out_x diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py new file mode 100644 index 000000000..fc3946763 --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py @@ -0,0 +1,757 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py + +# ruff: noqa: E501 + +import math + +import torch +import triton +import triton.language as tl + +from .mamba_ssm import softplus + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_SIZE_H": 2}), +# triton.Config({"BLOCK_SIZE_H": 4}), +# triton.Config({"BLOCK_SIZE_H": 8}), +# triton.Config({"BLOCK_SIZE_H": 16}), +# triton.Config({"BLOCK_SIZE_H": 32}), +# triton.Config({"BLOCK_SIZE_H": 64}), +# ], +# key=["chunk_size", "nheads"], +# ) +@triton.jit +def _chunk_cumsum_fwd_kernel( + # Pointers to matrices + dt_ptr, + A_ptr, + dt_bias_ptr, + dt_out_ptr, + dA_cumsum_ptr, + # Matrix dimension + batch, + seqlen, + nheads, + chunk_size, + dt_min, + dt_max, + # Strides + stride_dt_batch, + stride_dt_seqlen, + stride_dt_head, + stride_A_head, + stride_dt_bias_head, + stride_dt_out_batch, + stride_dt_out_chunk, + stride_dt_out_head, + stride_dt_out_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + BLOCK_SIZE_CHUNK: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr = 16, +): + pid_b = tl.program_id(axis=0) + + # if dt is long, may cause problems, so use 64 bit + # https://github.com/triton-lang/triton/issues/1058 + pid_c = tl.program_id(axis=1).to(tl.int64) + pid_h = tl.program_id(axis=2) + dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen + dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + + offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H) + offs_c = tl.arange(0, BLOCK_SIZE_CHUNK) + dt_ptrs = dt_ptr + ( + offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen + ) + A_ptrs = A_ptr + offs_h * stride_A_head + dt_out_ptrs = dt_out_ptr + ( + offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize + ) + dA_cs_ptrs = dA_cumsum_ptr + ( + offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize + ) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + dt = tl.load( + dt_ptrs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), + other=0.0, + ).to(tl.float32) + if HAS_DT_BIAS: + dt_bias = tl.load( + dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0 + ).to(tl.float32) + dt += dt_bias[:, None] + if DT_SOFTPLUS: + dt = tl.where(dt <= 20.0, softplus(dt), dt) + # As of Triton 2.2.0, tl.clamp is not available yet + # dt = tl.clamp(dt, dt_min, dt_max) + dt = tl.minimum(tl.maximum(dt, dt_min), dt_max) + dt = tl.where( + (offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0 + ) + tl.store( + dt_out_ptrs, + dt, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32) + dA = dt * A[:, None] + dA_cs = tl.cumsum(dA, axis=1) + tl.store( + dA_cs_ptrs, + dA_cs, + mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size), + ) + + +# @triton.autotune( +# configs=[ +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=2, +# ), +# ], +# key=["hdim", "dstate", "chunk_size"], +# ) +@triton.jit +def _chunk_state_fwd_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + states_ptr, + dt_ptr, + dA_cumsum_ptr, + seq_idx_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + batch, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_batch, + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_batch, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_dt_batch, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr = 16, + BLOCK_SIZE_N: tl.constexpr = 16, + BLOCK_SIZE_K: tl.constexpr = 16, +): + pid_bc = tl.program_id(axis=1).to(tl.int64) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + b_ptr += ( + pid_b * stride_b_batch + + pid_c * chunk_size * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += ( + pid_b * stride_x_batch + + pid_c * chunk_size * stride_x_seqlen + + pid_h * stride_x_head + ) + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += ( + pid_b * stride_dA_cs_batch + + pid_c * stride_dA_cs_chunk + + pid_h * stride_dA_cs_head + ) + if HAS_SEQ_IDX: + seq_idx_ptr += ( + pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + ) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to( + tl.float32 + ) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_last = tl.load( + seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen + ) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_k = tl.load( + seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1 + ) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k + else: + scale = tl.where( + seq_idx_k == seq_idx_last, tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0 + ) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += ( + pid_b * stride_states_batch + + pid_c * stride_states_chunk + + pid_h * stride_states_head + ) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +# @triton.autotune( +# configs=[ +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, +# num_stages=3, +# num_warps=8, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=4, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=5, +# num_warps=2, +# ), +# triton.Config( +# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, +# num_stages=4, +# num_warps=2, +# ), +# ], +# key=["hdim", "dstate", "chunk_size"], +# ) +@triton.jit +def _chunk_state_varlen_kernel( + # Pointers to matrices + x_ptr, + b_ptr, + dt_ptr, + dA_cumsum_ptr, + chunk_states_ptr, + cu_seqlens_ptr, + states_ptr, + initstates_ptr, + # Matrix dimensions + hdim, + dstate, + chunk_size, + seqlen, + nheads_ngroups_ratio, + # Strides + stride_x_seqlen, + stride_x_head, + stride_x_hdim, + stride_b_seqlen, + stride_b_head, + stride_b_dstate, + stride_dt_chunk, + stride_dt_head, + stride_dt_csize, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_chunk_states_chunk, + stride_chunk_states_head, + stride_chunk_states_hdim, + stride_chunk_states_dstate, + stride_states_batch, + stride_states_head, + stride_states_hdim, + stride_states_dstate, + stride_init_states_batch, + stride_init_states_head, + stride_init_states_hdim, + stride_init_states_dstate, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr = 16, + BLOCK_SIZE_N: tl.constexpr = 16, + BLOCK_SIZE_K: tl.constexpr = 16, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + end_idx = tl.load(cu_seqlens_ptr + pid_b + 1) + pid_c = (end_idx - 1) // chunk_size + b_ptr += ( + pid_c * chunk_size * stride_b_seqlen + + (pid_h // nheads_ngroups_ratio) * stride_b_head + ) + x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + chunk_states_ptr += ( + pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head + ) + + if HAS_INITSTATES: + # if there are init states provided, we differentiate between states (which + # are boundary conditions at a chunk boundary) and initstates (which are boundary + # conditions when a new example in a cont batch starts) + initstates_ptr += pid_h * stride_init_states_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + x_ptrs = x_ptr + ( + offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen + ) + b_ptrs = b_ptr + ( + offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen + ) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cs_last = tl.load( + dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = end_idx - pid_c * chunk_size + start_idx = tl.load(cu_seqlens_ptr + pid_b) + start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + x = tl.load( + x_ptrs, + mask=(offs_m[:, None] < hdim) + & (offs_k[None, :] < chunk_size_limit - k) + & (offs_k[None, :] >= start_idx_cur - k), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < chunk_size_limit - k) + & (offs_n[None, :] < dstate) + & (offs_k[:, None] >= start_idx_cur - k), + other=0.0, + ).to(tl.float32) + dA_cs_k = tl.load( + dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0 + ).to(tl.float32) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to( + tl.float32 + ) + scale = tl.where( + (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k), + tl.exp(dA_cs_last - dA_cs_k) * dt_k, + 0.0, + ) + b *= scale[:, None] + b = b.to(x_ptr.dtype.element_ty) + acc += tl.dot(x, b) + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + b_ptrs += BLOCK_SIZE_K * stride_b_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + # If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk + # If HAS_INITSTATES==True need to consider two possibilities + # - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs + # - if state_idx >= pid * chunk_size, then we need to insert initstates + if (start_idx < pid_c * chunk_size) or (HAS_INITSTATES): # first chunk + + dA_cs_boundary = 0.0 # default + + if not HAS_INITSTATES: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) + else: + + # - this seems repetitive, buts its to help the compiler + if start_idx < pid_c * chunk_size: + past_states_ptrs = chunk_states_ptr + ( + offs_m[:, None] * stride_chunk_states_hdim + + offs_n[None, :] * stride_chunk_states_dstate + ) + else: + past_states_ptrs = initstates_ptr + ( + pid_b * stride_init_states_batch + + offs_m[:, None] * stride_init_states_hdim + + offs_n[None, :] * stride_init_states_dstate + ) + + # need to adjust the boundary + if start_idx > pid_c * chunk_size: + dA_cs_boundary = tl.load( + dA_cumsum_ptr + + (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize + ).to(tl.float32) + + past_states = tl.load( + past_states_ptrs, + mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), + other=0.0, + ).to(tl.float32) + + scale = tl.exp(dA_cs_last - dA_cs_boundary) + acc += past_states * scale + + states = acc.to(states_ptr.dtype.element_ty) + + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + states_ptrs = states_ptr + ( + offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate + ) + c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate) + tl.store(states_ptrs, states, mask=c_mask) + + +def _chunk_cumsum_fwd( + dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf")) +): + batch, seqlen, nheads = dt.shape + assert A.shape == (nheads,) + if dt_bias is not None: + assert dt_bias.shape == (nheads,) + nchunks = math.ceil(seqlen / chunk_size) + dt_out = torch.empty( + batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + dA_cumsum = torch.empty( + batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32 + ) + grid_chunk_cs = lambda META: ( + batch, + nchunks, + triton.cdiv(nheads, META["BLOCK_SIZE_H"]), + ) + with torch.cuda.device(dt.device.index): + _chunk_cumsum_fwd_kernel[grid_chunk_cs]( + dt, + A, + dt_bias, + dt_out, + dA_cumsum, + batch, + seqlen, + nheads, + chunk_size, + dt_limit[0], + dt_limit[1], + dt.stride(0), + dt.stride(1), + dt.stride(2), + A.stride(0), + dt_bias.stride(0) if dt_bias is not None else 0, + dt_out.stride(0), + dt_out.stride(2), + dt_out.stride(1), + dt_out.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + dt_softplus, + HAS_DT_BIAS=dt_bias is not None, + BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size), + ) + return dA_cumsum, dt_out + + +def _chunk_state_fwd( + B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True +): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if states is not None: + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + else: + states_dtype = torch.float32 if states_in_fp32 else B.dtype + states = torch.empty( + (batch, nchunks, nheads, headdim, dstate), + device=x.device, + dtype=states_dtype, + ) + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch * nchunks, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_fwd_kernel[grid]( + x, + B, + states, + dt, + dA_cumsum, + seq_idx, + headdim, + dstate, + chunk_size, + batch, + seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + x.stride(3), + B.stride(0), + B.stride(1), + B.stride(2), + B.stride(-1), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + states.stride(4), + dt.stride(0), + dt.stride(2), + dt.stride(1), + dt.stride(3), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *( + (seq_idx.stride(0), seq_idx.stride(1)) + if seq_idx is not None + else (0, 0) + ), + HAS_SEQ_IDX=seq_idx is not None, + ) + return states + + +def chunk_state_varlen( + B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None +): + total_seqlen, nheads, headdim = x.shape + _, nchunks, chunk_size = dt.shape + _, ngroups, dstate = B.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + assert nheads % ngroups == 0 + assert B.shape == (total_seqlen, ngroups, dstate) + assert dt.shape == (nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert chunk_states.shape == (nchunks, nheads, headdim, dstate) + + if initial_states is not None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + + states = torch.empty( + batch, + nheads, + headdim, + dstate, + dtype=chunk_states.dtype, + device=chunk_states.device, + ) + grid = lambda META: ( + triton.cdiv(headdim, META["BLOCK_SIZE_M"]) + * triton.cdiv(dstate, META["BLOCK_SIZE_N"]), + batch, + nheads, + ) + with torch.cuda.device(x.device.index): + _chunk_state_varlen_kernel[grid]( + x, + B, + dt, + dA_cumsum, + chunk_states, + cu_seqlens, + states, + initial_states, + headdim, + dstate, + chunk_size, + total_seqlen, + nheads // ngroups, + x.stride(0), + x.stride(1), + x.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + dt.stride(1), + dt.stride(0), + dt.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + chunk_states.stride(0), + chunk_states.stride(1), + chunk_states.stride(2), + chunk_states.stride(3), + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + *( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + initial_states.stride(3), + ) + if initial_states is not None + else (0, 0, 0, 0) + ), + HAS_INITSTATES=initial_states is not None, + ) + return states diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py new file mode 100644 index 000000000..d27fc562e --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py @@ -0,0 +1,262 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py + +# ruff: noqa: E501 + +import torch +import triton +import triton.language as tl +from einops import rearrange +from packaging import version + +from .ssd_bmm import _bmm_chunk_fwd +from .ssd_chunk_scan import _chunk_scan_fwd +from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen +from .ssd_state_passing import _state_passing_fwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0") + + +def is_int_pow_2(n): + return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0 + + +def _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + state_dtype=None, + out=None, +): + assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2" + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert nheads % ngroups == 0 + assert B.shape == (batch, seqlen, ngroups, dstate) + assert dt.shape == (batch, seqlen, nheads) + assert A.shape == (nheads,) + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if ( + x.stride(-1) != 1 and x.stride(1) != 1 + ): # Either M or K dimension should be contiguous + x = x.contiguous() + if ( + z is not None and z.stride(-1) != 1 and z.stride(1) != 1 + ): # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + if initial_states is not None: + if cu_seqlens is None: + assert initial_states.shape == (batch, nheads, headdim, dstate) + else: + assert initial_states.shape == ( + len(cu_seqlens) - 1, + nheads, + headdim, + dstate, + ) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd( + dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) + # - this will ensure that states will be updated with the rightmost flushed seq_idx + # of the previous chunk. This implies that the first chunk of states is either 0 + # or equal to init_states of the first example. + states, final_states = _state_passing_fwd( + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum, + initial_states=( + rearrange(initial_states, "... p n -> ... (p n)") + if initial_states is not None + else None + ), + seq_idx=seq_idx, + chunk_size=chunk_size, + out_dtype=state_dtype if state_dtype is not None else C.dtype, + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets, + ) + states, final_states = ( + rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] + ) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. + out_x = _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + D=D, + z=z, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + initial_states=initial_states, + out=out, + ) + if cu_seqlens is None: + return out_x, dt, dA_cumsum, states, final_states + else: + assert ( + batch == 1 + ), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1" + varlen_states = chunk_state_varlen( + B.squeeze(0), + x.squeeze(0), + dt.squeeze(0), + dA_cumsum.squeeze(0), + cu_seqlens, + states.squeeze(0), + initial_states=initial_states, + ) + return out_x, dt, dA_cumsum, states, final_states, varlen_states + + +def mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + chunk_size, + D=None, + z=None, + dt_bias=None, + initial_states=None, + seq_idx=None, + chunk_indices=None, + chunk_offsets=None, + cu_seqlens=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + out=None, + return_final_states=False, + return_varlen_states=False, + state_dtype=None, +): + """ + Argument: + x: (batch, seqlen, nheads, headdim) + dt: (batch, seqlen, nheads) + A: (nheads) + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + chunk_size: int + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + seq_idx: (batch, seqlen) + cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True + dt_softplus: Whether to apply softplus to dt + out: Preallocated output tensor + state_dtype: The data type of the ssm state + """ + + if not return_varlen_states: + cu_seqlens = None + else: + assert ( + cu_seqlens is not None + ), "cu_seqlens must be provided if return_varlen_states is True" + out_x, dt_out, dA_cumsum, states, final_states, *rest = ( + _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + seq_idx=seq_idx, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + cu_seqlens=cu_seqlens, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + out=out, + state_dtype=state_dtype, + ) + ) + if not return_varlen_states: + if not return_final_states: + return + else: + return final_states + else: + varlen_states = rest[0] + return ( + (varlen_states) + if not return_final_states + else (final_states, varlen_states) + ) diff --git a/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py b/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py new file mode 100644 index 000000000..f0a8e0f6b --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py @@ -0,0 +1,275 @@ +# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py + +# ruff: noqa: E501 + +import torch +import triton +import triton.language as tl + + +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_SIZE": 64}), +# triton.Config({"BLOCK_SIZE": 128}), +# triton.Config({"BLOCK_SIZE": 256}), +# triton.Config({"BLOCK_SIZE": 512}), +# triton.Config({"BLOCK_SIZE": 1024}), +# triton.Config({"BLOCK_SIZE": 2048}), +# ], +# key=["dim"], +# ) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + final_states_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, + # Matrix dimensions + dim, + nchunks, + seqlen, + chunk_size, + # Strides + stride_states_batch, + stride_states_chunk, + stride_states_head, + stride_states_dim, + stride_out_batch, + stride_out_chunk, + stride_out_head, + stride_out_dim, + stride_final_states_batch, + stride_final_states_head, + stride_final_states_dim, + stride_dA_cs_batch, + stride_dA_cs_chunk, + stride_dA_cs_head, + stride_dA_cs_csize, + stride_initstates_batch, + stride_initstates_head, + stride_initstates_dim, + stride_seq_idx_batch, + stride_seq_idx_seqlen, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + IS_CONT_BATCHED: tl.constexpr, + BLOCK_SIZE: tl.constexpr = 16, +): + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head + dA_cs_ptr += ( + pid_b * stride_dA_cs_batch + + pid_h * stride_dA_cs_head + + (chunk_size - 1) * stride_dA_cs_csize + ) + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + final_states_ptr += ( + pid_b * stride_final_states_batch + pid_h * stride_final_states_head + ) + if HAS_INITSTATES: + initstates_ptr += pid_h * stride_initstates_head + if not IS_CONT_BATCHED: + initstates_ptr += pid_b * stride_initstates_batch + + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim + + # - states will be the past state of the sequence that continues on the current check + if not HAS_INITSTATES: + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + else: + initstates_ptr += offs_m * stride_initstates_dim + initstates_ptrs = initstates_ptr + # - for cont batches, for the first chunk mean it will be the first batch's + # init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + tl.store(out_ptrs, states, mask=offs_m < dim) + out_ptrs += stride_out_chunk + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + scale_mask = True + if HAS_SEQ_IDX: + # - the seq to pass forward is the one that is flushed to the right + # boundary. + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load( + seq_idx_ptr + + (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen + ) + if HAS_INITSTATES: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: + # this means in the current chunk the rightmost flushed seq + # has changed. + # - so we do not propagate the state from previous chunk + # - but rather we load that sequence's init state + initstates_ptrs = ( + initstates_ptr + seq_idx_chunk_end * stride_initstates_batch + ) + + # - update state with seq_idx_new's init state + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to( + tl.float32 + ) + + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load( + seq_idx_ptr + + min(c * chunk_size, seqlen) * stride_seq_idx_seqlen + ) + logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start + # - load the chunk offset: + c_off = tl.load( + chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0, + ) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr + - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0, + ) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 + else: + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end + + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) + states = scale * states + new_states + if c < nchunks - 1: + tl.store(out_ptrs, states, mask=offs_m < dim) + else: + tl.store(final_states_ptrs, states, mask=offs_m < dim) + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, + dA_cumsum, + initial_states=None, + seq_idx=None, + chunk_size=None, + out_dtype=None, + is_cont_batched=False, + chunk_offsets=None, +): + batch, nchunks, nheads, dim = states.shape + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if initial_states is not None: + if is_cont_batched: + # - if cu_seqlens is provided, then the initial states + # are used for continuous batching. In which case we + # require seq_idx to be provided + assert ( + seq_idx is not None + ), "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert ( + chunk_offsets is not None + ), "chunk_offsets must be provided for continuous batching" + else: + # - this is the regular batching case, where initial + # states are used are for each example of the batch. + assert initial_states.shape == (batch, nheads, dim) + + if seq_idx is not None: + seqlen = seq_idx.shape[-1] + assert seq_idx.shape == (batch, seqlen) + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty( + (batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype + ) + final_states = torch.empty( + (batch, nheads, dim), device=states.device, dtype=torch.float32 + ) + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states, + out, + final_states, + dA_cumsum, + initial_states, + seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, + dim, + nchunks, + seqlen if seq_idx is not None else 0, + chunk_size, + states.stride(0), + states.stride(1), + states.stride(2), + states.stride(3), + out.stride(0), + out.stride(1), + out.stride(2), + out.stride(3), + final_states.stride(0), + final_states.stride(1), + final_states.stride(2), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), + *( + ( + initial_states.stride(0), + initial_states.stride(1), + initial_states.stride(2), + ) + if initial_states is not None + else (0, 0, 0) + ), + *( + (seq_idx.stride(0), seq_idx.stride(1)) + if seq_idx is not None + else (0, 0) + ), + HAS_INITSTATES=initial_states is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_CONT_BATCHED=is_cont_batched, + ) + return out, final_states diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index fb1305c31..0126cd180 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1297,6 +1297,7 @@ class ModelRunner: return self.model_config.hf_config.architectures[0] in [ "Qwen3NextForCausalLM", "Qwen3NextForCausalLMMTP", + "FalconH1ForCausalLM", ] def set_num_token_hybrid(self): diff --git a/python/sglang/srt/models/falcon_h1.py b/python/sglang/srt/models/falcon_h1.py new file mode 100644 index 000000000..a035e0291 --- /dev/null +++ b/python/sglang/srt/models/falcon_h1.py @@ -0,0 +1,576 @@ +import enum +import logging +from typing import Any, Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn + +from sglang.srt.configs.falcon_h1 import FalconH1Config +from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.attention.mamba.mamba import MambaMixer2 +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.utils import add_prefix, is_cuda, make_layers + +logger = logging.getLogger(__name__) +_is_cuda = is_cuda() + + +class FalconH1MLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + layer_id: int, + mlp_multipliers: List[float], + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + reduce_results=reduce_results, + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + self.layer_id = layer_id + + self.intermediate_size = intermediate_size + self.tp_size = get_tensor_model_parallel_world_size() + + self.gate_multiplier, self.down_multiplier = mlp_multipliers + + def forward( + self, + x, + forward_batch=None, + use_reduce_scatter: bool = False, + ): + gate_up, _ = self.gate_up_proj(x) + gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier + + x = self.act_fn(gate_up) + x, _ = self.down_proj( + x, + skip_all_reduce=use_reduce_scatter, + ) + x = x * self.down_multiplier + return x + + +class FalconH1HybridAttentionDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconH1Config, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.attn_tp_rank = get_attention_tp_rank() + self.attn_tp_size = get_attention_tp_size() + self.tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % self.attn_tp_size == 0 + self.num_heads = self.total_num_heads // self.attn_tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= self.attn_tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % self.attn_tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert self.attn_tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size) + self.head_dim = config.head_dim or (self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = getattr(config, "rope_theta", 10000) + self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.rope_scaling = getattr(config, "rope_scaling", None) + self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1) + self.layer_id = layer_id + + self.rotary_emb = get_rope( + head_size=self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + rope_scaling=self.rope_scaling, + base=self.rope_theta, + partial_rotary_factor=self.partial_rotary_factor, + is_neox_style=True, + dtype=torch.get_default_dtype(), # see impl of get_rope + ) + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=False, + tp_rank=self.attn_tp_rank, + tp_size=self.attn_tp_size, + ) + + self.attn = RadixAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + prefix=f"{prefix}.attn", + ) + + self.d_ssm = ( + int(config.mamba_expand * config.hidden_size) + if config.mamba_d_ssm is None + else config.mamba_d_ssm + ) + + self.mamba = MambaMixer2( + hidden_size=config.hidden_size, + ssm_state_size=config.mamba_d_state, + conv_kernel_size=config.mamba_d_conv, + intermediate_size=self.d_ssm, + use_conv_bias=config.mamba_conv_bias, + use_bias=config.mamba_proj_bias, + n_groups=config.mamba_n_groups, + num_heads=config.mamba_n_heads, + layer_id=layer_id, + head_dim=config.mamba_d_head, + rms_norm_eps=config.rms_norm_eps, + chunk_size=config.mamba_chunk_size, + activation=config.hidden_act, + use_rms_norm=config.mamba_rms_norm, + prefix=f"{prefix}.mixer", + ) + + # FalconH1 all layers are sparse and have no nextn now + self.is_layer_sparse = False + is_previous_layer_sparse = False + + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + self.feed_forward = FalconH1MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + layer_id=layer_id, + mlp_multipliers=config.mlp_multipliers, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.pre_ff_layernorm, + allow_reduce_scatter=True, + ) + + self.alt_stream = alt_stream + self.key_multiplier = config.key_multiplier + + self.ssm_out_multiplier = config.ssm_out_multiplier + self.ssm_in_multiplier = config.ssm_in_multiplier + + self.attention_in_multiplier = config.attention_in_multiplier + self.attn_out_multiplier = config.attention_out_multiplier + + self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state + self.zxbcdt_multipliers = config.ssm_multipliers + self._init_mup_vector() + + def _init_mup_vector(self): + """ + Non learnable per-block scaling vector composed of element-wise + multipliersapplied to each separate contiguous block of the output + of the linear projection (in_proj) before further processing + (gating, convolution, SSM): + + - Z block: [0 : d_ssm] → zxbcdt_multipliers[0] + - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1] + - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2] + - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S] + → zxbcdt_multipliers[3] + - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4] + + where: + - d_ssm: Dimension of state-space model latent + - G: Number of groups (n_groups) + - S: SSM state size per group + - All indices are divided by tp_size to support tensor parallelism + """ + vector_shape = ( + 2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads + ) // self.tp_size + mup_vector = torch.ones(1, vector_shape) + # Z vector 0 -> d_ssm + mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0] + # X vector d_ssm -> 2 * d_ssm + mup_vector[ + :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size) + ] *= self.zxbcdt_multipliers[1] + # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm) + // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size) + // self.tp_size, + ] *= self.zxbcdt_multipliers[2] + # C vector 2 * d_ssm + (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + mup_vector[ + :, + (2 * self.d_ssm + self.groups_time_state_size) + // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size) + // self.tp_size, + ] *= self.zxbcdt_multipliers[3] + # dt vector 2 * d_ssm + 2 * (n_group * d_state) + # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads + mup_vector[ + :, + (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :, + ] *= self.zxbcdt_multipliers[4] + + self.register_buffer("mup_vector", mup_vector, persistent=False) + + def self_attention( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + k = k * self.key_multiplier + q, k = self.rotary_emb(positions, q, k) + + attn_output = self.attn(q, k, v, forward_batch) + + output, _ = self.o_proj(attn_output) + return output + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: ForwardBatch, + **kwargs: Any, + ): + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + + if not forward_batch.forward_mode.is_idle(): + # Attention block + attention_hidden_states = self.self_attention( + positions=positions, + hidden_states=hidden_states * self.attention_in_multiplier, + forward_batch=forward_batch, + ) + attention_hidden_states = attention_hidden_states * self.attn_out_multiplier + + # Mamba block + mamba_hidden_states = torch.empty_like(hidden_states) + self.mamba( + hidden_states * self.ssm_in_multiplier, + mamba_hidden_states, + forward_batch=forward_batch, + mup_vector=self.mup_vector, + ) + mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier + + hidden_states = attention_hidden_states + mamba_hidden_states + + # Fully Connected + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter( + forward_batch + ) + hidden_states = self.feed_forward( + hidden_states, forward_batch, use_reduce_scatter + ) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +ALL_DECODER_LAYER_TYPES = { + "falcon_h1": FalconH1HybridAttentionDecoderLayer, +} + + +class FalconH1Model(nn.Module): + def __init__( + self, + config: FalconH1Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + + alt_stream = torch.cuda.Stream() if _is_cuda else None + self.embedding_multiplier = config.embedding_multiplier + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + enable_tp=not is_dp_attention_enabled(), + ) + + def get_layer(idx: int, prefix: str): + layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]] + return layer_class( + config, + idx, + quant_config=quant_config, + prefix=prefix, + alt_stream=alt_stream, + ) + + self.layers = make_layers( + config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers" + ) + + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.infer_count = 0 + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + # mamba_cache_params: MambaCacheParams, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # pass a sequence index tensor, that is required for + # proper continuous batching computation including + # chunked prefill + if inputs_embeds is not None: + hidden_states = inputs_embeds * self.embedding_multiplier + else: + hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + layer_id=i, + positions=positions, + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + if not forward_batch.forward_mode.is_idle(): + if residual is None: + hidden_states = self.final_layernorm(hidden_states) + else: + hidden_states, _ = self.final_layernorm(hidden_states, residual) + + return hidden_states + + +class HybridLayerType(enum.Enum): + full_attention = "attention" + swa_attention = "swa_attention" + linear_attention = "linear_attention" + mamba2 = "mamba" + + +class FalconH1ForCausalLM(nn.Module): + fall_back_to_pt_during_load = False + + def __init__( + self, + config: FalconH1Config, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.pp_group = get_pp_group() + assert self.pp_group.is_first_rank and self.pp_group.is_last_rank + self.quant_config = quant_config + self.model = FalconH1Model( + config, quant_config, prefix=add_prefix("model", prefix) + ) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + org_num_embeddings=config.vocab_size, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.lm_head = self.lm_head.float() + self.lm_head_multiplier = config.lm_head_multiplier + self.logits_processor = LogitsProcessor( + config, logit_scale=self.lm_head_multiplier + ) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, + ): + hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) + + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_weights( + self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False + ) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + + if "rotary_emb.inv_freq" in name: + continue + + if ".self_attn." in name: + name = name.replace(".self_attn", "") + + if "A_log" in name: + name = name.replace("A_log", "A") + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + # if is_pp_missing_parameter(name, self): + # continue + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader") + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # if is_pp_missing_parameter(name, self): + # continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + + weight_loader(param, loaded_weight) + + loaded_params.add(name) + return loaded_params + + +EntryClass = FalconH1ForCausalLM diff --git a/test/srt/models/test_falcon_h1_models.py b/test/srt/models/test_falcon_h1_models.py new file mode 100644 index 000000000..cb32a7ef1 --- /dev/null +++ b/test/srt/models/test_falcon_h1_models.py @@ -0,0 +1,147 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestFalconH1(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "tiiuae/Falcon-H1-0.5B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tensor-parallel-size", + "1", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.74) + + +class TestFalconH1TP4(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "tiiuae/Falcon-H1-0.5B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tensor-parallel-size", + "4", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.74) + + +class TestFalconH1NoGatedRMS(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "tiiuae/Falcon-H1-1.5B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tensor-parallel-size", + "1", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.74) + + +class TestFalconH1NoGatedTP4(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = "tiiuae/Falcon-H1-1.5B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--tensor-parallel-size", + "4", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.74) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 71862a7e8..e474feeeb 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -140,6 +140,7 @@ suites = { TestFile("test_local_attn.py", 250), TestFile("test_pp_single_node.py", 372), TestFile("models/test_qwen3_next_models.py", 200), + TestFile("models/test_falcon_h1_models.py", 200), TestFile("test_multi_instance_release_memory_occupation.py", 64), ], "per-commit-8-gpu": [