[model] Add mamba2 and Falcon-H1 support. (#10988)
Co-authored-by: Younes Belkada <younes.belkada@tii.ae> Co-authored-by: Younes B <49240599+younesbelkada@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@ from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
||||
from sglang.srt.configs.dots_ocr import DotsOCRConfig
|
||||
from sglang.srt.configs.dots_vlm import DotsVLMConfig
|
||||
from sglang.srt.configs.exaone import ExaoneConfig
|
||||
from sglang.srt.configs.falcon_h1 import FalconH1Config
|
||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
@@ -30,4 +31,5 @@ __all__ = [
|
||||
"Qwen3NextConfig",
|
||||
"DotsVLMConfig",
|
||||
"DotsOCRConfig",
|
||||
"FalconH1Config",
|
||||
]
|
||||
|
||||
360
python/sglang/srt/configs/falcon_h1.py
Normal file
360
python/sglang/srt/configs/falcon_h1.py
Normal file
@@ -0,0 +1,360 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 TII and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Falcon-H1 model configuration"""
|
||||
|
||||
import enum
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_size,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FalconH1Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`FalconH1Model`]. It is used to instantiate a
|
||||
FalconH1Model model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
||||
with defaults taken from [ibm-fms/FalconH1-9.8b-2.2T-hf](https://huggingface.co/ibm-fms/FalconH1-9.8b-2.2T-hf).
|
||||
The FalconH1Model is a hybrid [mamba2](https://github.com/state-spaces/mamba) architecture with SwiGLU.
|
||||
The checkpoints are jointly trained by IBM, Princeton, and UIUC.
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 128000):
|
||||
Vocabulary size of the FalconH1 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`FalconH1Model`]
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the
|
||||
model has a output word embedding layer.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 14336):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 8):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `8`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
|
||||
Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
|
||||
integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
|
||||
logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
|
||||
sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
|
||||
significantly.
|
||||
pad_token_id (`int`, *optional*, defaults to 0):
|
||||
The id of the padding token.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
The id of the "beginning-of-sequence" token.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the "end-of-sequence" token.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8192):
|
||||
Max cached sequence length for the model
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
mamba_d_ssm (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the SSM state space latents.
|
||||
mamba_n_heads (`int`, *optional*, defaults to 128):
|
||||
The number of mamba heads used in the v2 implementation.
|
||||
mamba_d_head (`int`, *optional*, defaults to `"auto"`):
|
||||
Head embedding dimension size
|
||||
mamba_n_groups (`int`, *optional*, defaults to 1):
|
||||
The number of the mamba groups used in the v2 implementation.
|
||||
mamba_d_state (`int`, *optional*, defaults to 256):
|
||||
The dimension the mamba state space latents
|
||||
mamba_d_conv (`int`, *optional*, defaults to 4):
|
||||
The size of the mamba convolution kernel
|
||||
mamba_expand (`int`, *optional*, defaults to 2):
|
||||
Expanding factor (relative to hidden_size) used to determine the mamba intermediate size
|
||||
mamba_chunk_size (`int`, *optional*, defaults to 256):
|
||||
The chunks in which to break the sequence when doing prefill/training
|
||||
mamba_conv_bias (`bool`, *optional*, defaults to `True`):
|
||||
Flag indicating whether or not to use bias in the convolution layer of the mamba mixer block.
|
||||
mamba_proj_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the mamba mixer block
|
||||
mamba_norm_before_gate (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use RMSNorm before the gate in the Mamba block
|
||||
mamba_rms_norm (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use RMSNorm instead of LayerNorm in the Mamba block
|
||||
projectors_bias (`bool`, *optional*, defaults to `False`):
|
||||
Flag indicating whether or not to use bias in the input and output projections (["in_proj", "out_proj"]) of the attention block
|
||||
rope_theta (`float`, *optional*, defaults to 100000.0):
|
||||
The theta value used for the RoPE embeddings.
|
||||
rope_scaling (`float`, *optional*):
|
||||
The scaling value used for the RoPE embeddings. If `None`, no scaling is applied.
|
||||
lm_head_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The multiplier for the LM head. This is used to scale the output of the LM head.
|
||||
embedding_multiplier (`float`, *optional*, defaults to 1.0):
|
||||
The multiplier for the embedding layer. This is used to scale the output of the embedding layer.
|
||||
mlp_multipliers (`list[float]`, *optional*):
|
||||
The multipliers for the MLP layers. This is used to scale the output of the MLP layers. The first value is
|
||||
the multiplier of gate layer, the second value is the multiplier of the down_proj layer.
|
||||
key_multiplier (`float`, *optional*):
|
||||
The multiplier for the key layer. This is used to scale the output of the key layer.
|
||||
attention_out_multiplier (`float`, *optional*):
|
||||
The multiplier for the attention output layer. This is used to scale the output of the attention output
|
||||
attention_in_multiplier (`float`, *optional*):
|
||||
The multiplier for the attention input layer. This is used to scale the output of the attention input layer.
|
||||
ssm_multipliers (`list[float]`, *optional*):
|
||||
The multipliers for the SSM layers. This is used to scale the output of the SSM layers.
|
||||
ssm_in_multiplier (`float`, *optional*):
|
||||
The multiplier for the SSM input layer. This is used to scale the output of the SSM input layer.
|
||||
ssm_out_multiplier (`float`, *optional*):
|
||||
The multiplier for the SSM output layer. This is used to scale the output of the SSM output layer.
|
||||
"""
|
||||
|
||||
model_type = "falcon_h1"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=128000,
|
||||
tie_word_embeddings=False,
|
||||
hidden_size=4096,
|
||||
intermediate_size=14336,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=8,
|
||||
hidden_act="silu",
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-5,
|
||||
use_cache=True,
|
||||
num_logits_to_keep=1,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
max_position_embeddings=8192,
|
||||
attention_dropout=0.0,
|
||||
mamba_d_ssm=1024,
|
||||
mamba_n_heads=128,
|
||||
mamba_d_head="auto",
|
||||
mamba_n_groups=1,
|
||||
mamba_d_state=256,
|
||||
mamba_d_conv=4,
|
||||
mamba_expand=2,
|
||||
mamba_chunk_size=256,
|
||||
mamba_conv_bias=True,
|
||||
mamba_proj_bias=False,
|
||||
mamba_norm_before_gate=True,
|
||||
mamba_rms_norm=False,
|
||||
projectors_bias=False,
|
||||
rope_theta=100000.0,
|
||||
rope_scaling=None,
|
||||
lm_head_multiplier=1.0,
|
||||
embedding_multiplier=1.0,
|
||||
mlp_multipliers=None,
|
||||
key_multiplier=None,
|
||||
attention_out_multiplier=None,
|
||||
attention_in_multiplier=None,
|
||||
ssm_multipliers=None,
|
||||
ssm_in_multiplier=None,
|
||||
ssm_out_multiplier=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.attention_dropout = attention_dropout
|
||||
self.attention_bias = False
|
||||
self.mlp_bias = False
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
|
||||
self.use_cache = use_cache
|
||||
self.num_logits_to_keep = num_logits_to_keep
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = None
|
||||
self.rope_scaling = rope_scaling
|
||||
self.projectors_bias = projectors_bias
|
||||
mamba_intermediate = (
|
||||
mamba_expand * hidden_size if mamba_d_ssm is None else mamba_d_ssm
|
||||
)
|
||||
|
||||
if mamba_intermediate % mamba_n_heads != 0:
|
||||
raise ValueError("mamba_n_heads must divide mamba_expand * hidden_size")
|
||||
|
||||
# for the mamba_v2, must satisfy the following
|
||||
if mamba_d_head == "auto":
|
||||
mamba_d_head = mamba_intermediate // mamba_n_heads
|
||||
|
||||
if mamba_d_head * mamba_n_heads != mamba_intermediate:
|
||||
raise ValueError(
|
||||
"The dimensions for the Mamba head state do not match the model intermediate_size"
|
||||
)
|
||||
|
||||
self.mamba_d_ssm = mamba_d_ssm
|
||||
self.mamba_n_heads = mamba_n_heads
|
||||
self.mamba_d_head = mamba_d_head
|
||||
self.mamba_n_groups = mamba_n_groups
|
||||
self.mamba_d_state = mamba_d_state
|
||||
self.mamba_d_conv = mamba_d_conv
|
||||
self.mamba_expand = mamba_expand
|
||||
self.mamba_chunk_size = mamba_chunk_size
|
||||
self.mamba_conv_bias = mamba_conv_bias
|
||||
self.mamba_proj_bias = mamba_proj_bias
|
||||
|
||||
self.mamba_norm_before_gate = mamba_norm_before_gate
|
||||
self.mamba_rms_norm = mamba_rms_norm
|
||||
|
||||
self.lm_head_multiplier = lm_head_multiplier
|
||||
self.embedding_multiplier = embedding_multiplier
|
||||
|
||||
if mlp_multipliers is not None:
|
||||
self.mlp_multipliers = mlp_multipliers
|
||||
else:
|
||||
self.mlp_multipliers = [1.0, 1.0]
|
||||
|
||||
if attention_out_multiplier is not None:
|
||||
self.attention_out_multiplier = attention_out_multiplier
|
||||
else:
|
||||
self.attention_out_multiplier = 1.0
|
||||
|
||||
if attention_in_multiplier is not None:
|
||||
self.attention_in_multiplier = attention_in_multiplier
|
||||
else:
|
||||
self.attention_in_multiplier = 1.0
|
||||
|
||||
if key_multiplier is not None:
|
||||
self.key_multiplier = key_multiplier
|
||||
else:
|
||||
self.key_multiplier = 1.0
|
||||
|
||||
if ssm_multipliers is not None:
|
||||
self.ssm_multipliers = ssm_multipliers
|
||||
else:
|
||||
self.ssm_multipliers = [1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
if ssm_in_multiplier is not None:
|
||||
self.ssm_in_multiplier = ssm_in_multiplier
|
||||
else:
|
||||
self.ssm_in_multiplier = 1.0
|
||||
|
||||
if ssm_out_multiplier is not None:
|
||||
self.ssm_out_multiplier = ssm_out_multiplier
|
||||
else:
|
||||
self.ssm_out_multiplier = 1.0
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def layers_block_type(self):
|
||||
return ["falcon_h1" for i in range(self.num_hidden_layers)]
|
||||
|
||||
@property
|
||||
def mamba_cache_per_req(self):
|
||||
conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
|
||||
self.hybrid_gdn_params
|
||||
)
|
||||
mamba_layers_len = len(mamba_layers)
|
||||
|
||||
return (
|
||||
int(np.prod(conv_state_shape)) * conv_dtype.itemsize
|
||||
+ int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
|
||||
) * mamba_layers_len
|
||||
|
||||
@property
|
||||
def full_attention_layer_ids(self):
|
||||
# For Falcon-H1, we do have attention on all layers
|
||||
return range(self.num_hidden_layers)
|
||||
|
||||
@property
|
||||
def linear_layer_ids(self):
|
||||
# For Falcon-H1, we do have mamba on all layers
|
||||
return range(self.num_hidden_layers)
|
||||
|
||||
@property
|
||||
def hybrid_gdn_params(self):
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
n_groups = self.mamba_n_groups
|
||||
if self.mamba_n_groups % world_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
extra_groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
self.mamba_n_groups, world_size
|
||||
)
|
||||
n_groups += extra_groups
|
||||
|
||||
conv_dim = self.mamba_d_ssm + 2 * n_groups * self.mamba_d_state
|
||||
|
||||
conv_state_shape = (
|
||||
divide(conv_dim, world_size),
|
||||
self.mamba_d_conv - 1,
|
||||
)
|
||||
|
||||
# we TP-ize on the heads dimension
|
||||
temporal_state_shape = (
|
||||
self.mamba_d_state,
|
||||
self.mamba_d_head,
|
||||
divide(self.mamba_n_heads, world_size),
|
||||
)
|
||||
conv_dtype = torch.bfloat16
|
||||
dtype_map = {
|
||||
"float32": torch.float32,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
|
||||
mamba_layers = self.linear_layer_ids
|
||||
|
||||
return (
|
||||
conv_state_shape,
|
||||
temporal_state_shape,
|
||||
conv_dtype,
|
||||
ssm_dtype,
|
||||
mamba_layers,
|
||||
)
|
||||
@@ -41,6 +41,7 @@ from sglang.srt.configs import (
|
||||
DotsOCRConfig,
|
||||
DotsVLMConfig,
|
||||
ExaoneConfig,
|
||||
FalconH1Config,
|
||||
KimiVLConfig,
|
||||
LongcatFlashConfig,
|
||||
MultiModalityConfig,
|
||||
@@ -62,6 +63,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
Step3VLConfig.model_type: Step3VLConfig,
|
||||
LongcatFlashConfig.model_type: LongcatFlashConfig,
|
||||
Qwen3NextConfig.model_type: Qwen3NextConfig,
|
||||
FalconH1Config.model_type: FalconH1Config,
|
||||
DotsVLMConfig.model_type: DotsVLMConfig,
|
||||
DotsOCRConfig.model_type: DotsOCRConfig,
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
query_start_loc = torch.arange(
|
||||
0, bs + 1, dtype=torch.int32, device=self.device
|
||||
|
||||
@@ -1,6 +1,39 @@
|
||||
from typing import Callable, List, Tuple
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
|
||||
from sglang.srt.layers.attention.mamba.ops import (
|
||||
mamba_chunk_scan_combined,
|
||||
selective_state_update,
|
||||
)
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
composed_weight_loader,
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
|
||||
|
||||
@@ -62,3 +95,535 @@ def mamba_v2_sharded_weight_loader(
|
||||
loaded_boundary += full_dim - extra
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.full_hidden_size = full_hidden_size
|
||||
self.group_size = full_hidden_size // full_n_groups
|
||||
self.per_rank_hidden_size = full_hidden_size // self.tp_size
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert (
|
||||
self.full_hidden_size % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide hidden size."
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
):
|
||||
# Three tensor-parallel cases:
|
||||
# 1. n_groups is 1
|
||||
# In this case we parallelize along the reduction dim.
|
||||
# Each rank computes a local sum of squares followed by AllReduce
|
||||
# 2. tp_size divides n_groups
|
||||
# Each rank only reduces within its local group(s).
|
||||
# No collective ops necessary.
|
||||
# 3. The general case can be pretty complicated so we AllGather
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x.to(input_dtype)
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
# Compute local sum and then reduce to obtain global sum
|
||||
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
else:
|
||||
redundant_tp: bool = self.n_groups % self.tp_size != 0
|
||||
if redundant_tp:
|
||||
# To handle the general case, redundantly apply the variance
|
||||
x = tensor_model_parallel_all_gather(x, -1)
|
||||
|
||||
*prefix_dims, hidden_dim = x.shape
|
||||
group_count = hidden_dim // self.group_size
|
||||
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
|
||||
variance = x_grouped.pow(2).mean(-1, keepdim=True)
|
||||
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x_grouped.view(*prefix_dims, hidden_dim)
|
||||
|
||||
if redundant_tp:
|
||||
start = self.per_rank_hidden_size * self.tp_rank
|
||||
end = start + self.per_rank_hidden_size
|
||||
x = x[..., start:end]
|
||||
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
|
||||
|
||||
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight.data,
|
||||
bias=None,
|
||||
z=gate,
|
||||
eps=self.variance_epsilon,
|
||||
norm_before_gate=False,
|
||||
)
|
||||
|
||||
|
||||
class MambaMixer2(torch.nn.Module):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
the `contextualized_states`. A, D are input independent
|
||||
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
|
||||
for why A isn't selective) ∆, B, C are input-dependent
|
||||
(this is a key difference between Mamba and the linear time
|
||||
invariant S4, and is why Mamba is called
|
||||
**selective** state spaces)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
chunk_size: int,
|
||||
layer_id: int,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
# cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# For TP, the sharding plan is as follows:
|
||||
# - for the conv modules, since
|
||||
# conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
|
||||
# we shard intermediate_size and n_groups
|
||||
# - since intermediate_size = n_heads * head_dim, sharding on
|
||||
# intermediate_size is achieved by sharding on n_heads.
|
||||
# - IF, world_size divides groups, then sharding
|
||||
# (n_groups / world_size, n_heads / world_size)
|
||||
# also maintains the invariant n_heads % n_groups == 0
|
||||
# - HOWEVER IF, world_size DOES NOT divide groups, then we need
|
||||
# to allocate extra space in the shard, such that groups
|
||||
# may be replicated to follow the head shard.
|
||||
# - NOTE: currently for the world size DOES NOT divide groups
|
||||
# case, we only support the case when n_groups == 1
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
assert (
|
||||
num_heads % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide num heads."
|
||||
|
||||
assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
|
||||
"If tensor parallel world size does not divide num_groups, "
|
||||
"then num_groups must equal 1."
|
||||
)
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
self.activation = activation
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
self.n_groups = n_groups
|
||||
if n_groups % self.tp_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size
|
||||
)
|
||||
self.n_groups = n_groups + groups
|
||||
|
||||
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
|
||||
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
|
||||
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
if n_groups % self.tp_size != 0:
|
||||
# This is the n_groups == 1 case,
|
||||
# where we need to duplicate groups if TP>1.
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
# - use the custom weight loader mamba_v2_sharded_weight_loader
|
||||
# for conv1d.bias, covn1d.weight and in_proj.weight
|
||||
# - need to set these settings, to assign the groups
|
||||
# to the head shards
|
||||
group_shard_settings = (
|
||||
self.groups_ssm_state_size, # expected model size
|
||||
(self.n_groups - n_groups) * self.ssm_state_size, # extra dims assigned
|
||||
n_groups == 1, # if there was only one group
|
||||
)
|
||||
intermediate_settings = (intermediate_size, 0, False)
|
||||
head_settings = (self.num_heads, 0, False)
|
||||
|
||||
# - the weight already has a "weight_loader" attribute
|
||||
# which set_weight_attrs will raise if we do not
|
||||
# delete before trying to override it
|
||||
# - ditto for the other two weights below
|
||||
delattr(self.conv1d.bias, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.bias,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
delattr(self.conv1d.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.conv1d.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
],
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
if quant_config is None:
|
||||
# - quant layers do not have a weight loader
|
||||
delattr(self.in_proj.weight, "weight_loader")
|
||||
set_weight_attrs(
|
||||
self.in_proj.weight,
|
||||
{
|
||||
"weight_loader": mamba_v2_sharded_weight_loader(
|
||||
[
|
||||
intermediate_settings, # for gate
|
||||
intermediate_settings,
|
||||
group_shard_settings,
|
||||
group_shard_settings,
|
||||
head_settings, # for dt
|
||||
],
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# unsqueeze to fit conv1d weights shape into the linear weights shape.
|
||||
# Can't do this in `weight_loader` since it already exists in
|
||||
# `ColumnParallelLinear` and `MergedColumnParallelLinear`,
|
||||
# and `set_weight_attrs` doesn't allow to override it
|
||||
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
|
||||
|
||||
# - these are TPed by heads to reduce the size of the
|
||||
# temporal shape
|
||||
self.A = nn.Parameter(
|
||||
torch.empty(
|
||||
divide(num_heads, self.tp_size),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
)
|
||||
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
|
||||
self.use_rms_norm = use_rms_norm
|
||||
|
||||
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
|
||||
a_weight_loader = composed_weight_loader(
|
||||
sharded_weight_loader(0), lambda x: -torch.exp(x.float())
|
||||
)
|
||||
set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
|
||||
set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
|
||||
|
||||
self.out_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=use_bias,
|
||||
input_is_parallel=True,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.out_proj",
|
||||
reduce_results=False,
|
||||
)
|
||||
|
||||
self.norm = Mixer2RMSNormGated(
|
||||
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
|
||||
)
|
||||
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# attn_backend_list[-1] gives access to MambaAttnBackend
|
||||
mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
|
||||
attn_metadata = mamba_backend.forward_metadata
|
||||
state_indices_tensor = attn_metadata.mamba_cache_indices
|
||||
chunk_size = self.chunk_size
|
||||
|
||||
conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
|
||||
self.layer_id
|
||||
)
|
||||
|
||||
assert (
|
||||
ssm_state.size(1) == self.ssm_state_size
|
||||
), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
|
||||
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
|
||||
chunk_size = self.chunk_size
|
||||
|
||||
# TODO: properly support this
|
||||
prep_initial_states = False
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
|
||||
if mup_vector is not None:
|
||||
projected_states = projected_states * mup_vector
|
||||
|
||||
gate, hidden_states_B_C, dt = torch.split(
|
||||
projected_states,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.conv_dim // self.tp_size,
|
||||
self.num_heads // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
conv_weights = self.conv1d.weight.view(
|
||||
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
|
||||
)
|
||||
|
||||
# - get hidden_states, B and C after depthwise convolution.
|
||||
split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
|
||||
hidden_states_B_C,
|
||||
[
|
||||
self.intermediate_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
self.groups_ssm_state_size // self.tp_size,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
preallocated_ssm_out = torch.empty(
|
||||
[
|
||||
projected_states.shape[0],
|
||||
(self.num_heads * self.head_dim) // self.tp_size,
|
||||
],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
num_prefill_tokens = forward_batch.extend_num_tokens or 0
|
||||
has_initial_states = forward_batch.extend_prefix_lens > 0
|
||||
cache_indices = attn_metadata.mamba_cache_indices
|
||||
|
||||
x = hidden_states_B_C.transpose(
|
||||
0, 1
|
||||
) # this is the form that causal-conv see
|
||||
hidden_states_B_C = causal_conv1d_fn(
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states,
|
||||
cache_indices=cache_indices,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
).transpose(0, 1)
|
||||
|
||||
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
|
||||
if has_initial_states is not None and prep_initial_states:
|
||||
initial_states = torch.where(
|
||||
has_initial_states[:, None, None, None],
|
||||
ssm_state[state_indices_tensor],
|
||||
0,
|
||||
)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states.view(
|
||||
1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
|
||||
),
|
||||
dt.unsqueeze(0),
|
||||
self.A,
|
||||
B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
cu_seqlens=query_start_loc,
|
||||
initial_states=initial_states,
|
||||
return_varlen_states=True,
|
||||
return_final_states=False,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
|
||||
state_dtype=ssm_state.dtype,
|
||||
)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
|
||||
elif forward_batch.forward_mode.is_decode():
|
||||
num_decodes = len(query_start_loc) - 1
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C = causal_conv1d_update(
|
||||
hidden_states_B_C,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor,
|
||||
)
|
||||
|
||||
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
A = (
|
||||
self.A[:, None, ...][:, :, None]
|
||||
.expand(-1, self.head_dim, self.ssm_state_size)
|
||||
.to(dtype=torch.float32)
|
||||
)
|
||||
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||
D = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B = B.view(-1, n_groups, B.shape[1] // n_groups)
|
||||
C = C.view(-1, n_groups, C.shape[1] // n_groups)
|
||||
hidden_states = hidden_states.view(
|
||||
-1, self.num_heads // self.tp_size, self.head_dim
|
||||
)
|
||||
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# using state_indices_tensor_d
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
selective_state_update(
|
||||
ssm_state.permute(0, 3, 2, 1),
|
||||
hidden_states,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor,
|
||||
out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
|
||||
)
|
||||
elif forward_batch.forward_mode.is_idle():
|
||||
preallocated_ssm_out = preallocated_ssm_out
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(preallocated_ssm_out, gate)
|
||||
|
||||
# 5. Final linear projection
|
||||
output[:], _ = self.out_proj(hidden_states)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
return "mamba2"
|
||||
|
||||
81
python/sglang/srt/layers/attention/mamba/mamba_utils.py
Normal file
81
python/sglang/srt/layers/attention/mamba/mamba_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
|
||||
from sglang.srt.distributed.utils import divide
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
|
||||
|
||||
@classmethod
|
||||
def linear_attention_state_shape(
|
||||
cls,
|
||||
num_heads: int,
|
||||
tp_size: int,
|
||||
head_dim: int,
|
||||
) -> tuple[tuple[int, int, int], ...]:
|
||||
|
||||
state_shape = (num_heads // tp_size, head_dim, head_dim)
|
||||
return (state_shape,)
|
||||
|
||||
@classmethod
|
||||
def mamba1_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int]]:
|
||||
conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
|
||||
|
||||
temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
|
||||
|
||||
conv_state_shape = conv_state_shape[1], conv_state_shape[0]
|
||||
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def mamba2_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
n_groups: int,
|
||||
num_heads: int,
|
||||
head_dim: int,
|
||||
state_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int], tuple[int, int, int]]:
|
||||
# if n_groups is not divisible by world_size, need to extend the shards
|
||||
# to ensure all groups needed by a head is sharded along with it
|
||||
n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
|
||||
# heads and n_groups are TP-ed
|
||||
conv_dim = intermediate_size + 2 * n_groups * state_size
|
||||
|
||||
# contiguous along 'dim' axis
|
||||
conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
|
||||
|
||||
# These are not TP-ed as they depend on A, dt_bias, D
|
||||
# - they are typically small
|
||||
# e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
|
||||
temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
|
||||
return conv_state_shape, temporal_state_shape
|
||||
|
||||
@classmethod
|
||||
def short_conv_state_shape(
|
||||
cls,
|
||||
tp_world_size: int,
|
||||
intermediate_size: int,
|
||||
conv_kernel: int,
|
||||
) -> tuple[tuple[int, int]]:
|
||||
conv_dim = divide(intermediate_size, tp_world_size)
|
||||
conv_state_shape = (conv_kernel - 1, conv_dim)
|
||||
return (conv_state_shape,)
|
||||
|
||||
@classmethod
|
||||
def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
|
||||
"""Compute the increase in group numbers to account for
|
||||
replication in order to accompany the head shards."""
|
||||
|
||||
# in the case ngoups % tp_size == 0, this will be zero
|
||||
if ngroups % tp_size == 0:
|
||||
return 0
|
||||
|
||||
# for n_groups == 1, this is exactly tp_size - n_groups
|
||||
return tp_size - ngroups
|
||||
2
python/sglang/srt/layers/attention/mamba/ops/__init__.py
Normal file
2
python/sglang/srt/layers/attention/mamba/ops/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .mamba_ssm import selective_state_update
|
||||
from .ssd_combined import mamba_chunk_scan_combined
|
||||
172
python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py
Normal file
172
python/sglang/srt/layers/attention/mamba/ops/layernorm_gated.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# Copyright (c) 2024, Tri Dao.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/60dadf2e0ee730ac337035d5533de10bc26e4847/mamba_ssm/ops/triton/layernorm_gated.py
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["Z"] is not None})
|
||||
@triton.jit
|
||||
def _layer_norm_fwd_1pass_kernel(
|
||||
X, # pointer to the input
|
||||
Y, # pointer to the output
|
||||
W, # pointer to the weights
|
||||
B, # pointer to the biases
|
||||
Z, # pointer to the other branch
|
||||
Mean, # pointer to the mean
|
||||
Rstd, # pointer to the 1/std
|
||||
stride_x_row: tl.int64,
|
||||
stride_y_row: tl.int64,
|
||||
stride_z_row: tl.int64,
|
||||
M: tl.int64, # number of rows in X
|
||||
N: tl.int64, # number of columns in X
|
||||
eps, # epsilon to avoid division by zero
|
||||
BLOCK_N: tl.constexpr,
|
||||
HAS_BIAS: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
NORM_BEFORE_GATE: tl.constexpr,
|
||||
IS_RMS_NORM: tl.constexpr,
|
||||
):
|
||||
# Map the program id to the row of X and Y it should compute.
|
||||
row = tl.program_id(0)
|
||||
group = tl.program_id(1)
|
||||
X += row * stride_x_row + group * N
|
||||
Y += row * stride_y_row + group * N
|
||||
if HAS_Z:
|
||||
Z += row * stride_z_row + group * N
|
||||
if not IS_RMS_NORM:
|
||||
Mean += group * M
|
||||
Rstd += group * M
|
||||
W += group * N
|
||||
if HAS_BIAS:
|
||||
B += group * N
|
||||
# Compute mean and variance
|
||||
cols = tl.arange(0, BLOCK_N)
|
||||
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
|
||||
if HAS_Z and not NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
|
||||
x *= z * tl.sigmoid(z)
|
||||
if not IS_RMS_NORM:
|
||||
mean = tl.sum(x, axis=0) / N
|
||||
tl.store(Mean + row, mean)
|
||||
xbar = tl.where(cols < N, x - mean, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
else:
|
||||
xbar = tl.where(cols < N, x, 0.0)
|
||||
var = tl.sum(xbar * xbar, axis=0) / N
|
||||
rstd = 1 / tl.sqrt(var + eps)
|
||||
tl.store(Rstd + row, rstd)
|
||||
# Normalize and apply linear transformation
|
||||
mask = cols < N
|
||||
w = tl.load(W + cols, mask=mask).to(tl.float32)
|
||||
if HAS_BIAS:
|
||||
b = tl.load(B + cols, mask=mask).to(tl.float32)
|
||||
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
|
||||
y = x_hat * w + b if HAS_BIAS else x_hat * w
|
||||
if HAS_Z and NORM_BEFORE_GATE:
|
||||
z = tl.load(Z + cols, mask=mask).to(tl.float32)
|
||||
y *= z * tl.sigmoid(z)
|
||||
# Write output
|
||||
tl.store(Y + cols, y, mask=mask)
|
||||
|
||||
|
||||
def _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=None,
|
||||
out=None,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
M, N = x.shape
|
||||
if group_size is None:
|
||||
group_size = N
|
||||
assert N % group_size == 0
|
||||
ngroups = N // group_size
|
||||
assert x.stride(-1) == 1
|
||||
if z is not None:
|
||||
assert z.stride(-1) == 1
|
||||
assert z.shape == (M, N)
|
||||
assert weight.shape == (N,)
|
||||
assert weight.stride(-1) == 1
|
||||
if bias is not None:
|
||||
assert bias.stride(-1) == 1
|
||||
assert bias.shape == (N,)
|
||||
# allocate output
|
||||
if out is not None:
|
||||
assert out.shape == x.shape
|
||||
else:
|
||||
out = torch.empty_like(x)
|
||||
assert out.stride(-1) == 1
|
||||
mean = (
|
||||
torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
if not is_rms_norm
|
||||
else None
|
||||
)
|
||||
rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)
|
||||
# Less than 64KB per feature: enqueue fused kernel
|
||||
MAX_FUSED_SIZE = 65536 // x.element_size()
|
||||
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
|
||||
if group_size > BLOCK_N:
|
||||
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
|
||||
# heuristics for number of warps
|
||||
num_warps = min(max(BLOCK_N // 256, 1), 8)
|
||||
grid = (M, ngroups)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_layer_norm_fwd_1pass_kernel[grid](
|
||||
x,
|
||||
out,
|
||||
weight,
|
||||
bias,
|
||||
z,
|
||||
mean,
|
||||
rstd,
|
||||
x.stride(0),
|
||||
out.stride(0),
|
||||
z.stride(0) if z is not None else 0,
|
||||
M,
|
||||
group_size,
|
||||
eps,
|
||||
BLOCK_N=BLOCK_N,
|
||||
NORM_BEFORE_GATE=norm_before_gate,
|
||||
IS_RMS_NORM=is_rms_norm,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, _, _ = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
|
||||
return y.reshape(x_shape_og)
|
||||
442
python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
Normal file
442
python/sglang/srt/layers/attention/mamba/ops/mamba_ssm.py
Normal file
@@ -0,0 +1,442 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/selective_state_update.py
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from packaging import version
|
||||
|
||||
from sglang.srt import _custom_ops as ops
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
|
||||
|
||||
if TRITON3:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
|
||||
return dt
|
||||
|
||||
else:
|
||||
|
||||
@triton.jit
|
||||
def softplus(dt):
|
||||
dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
|
||||
return dt
|
||||
|
||||
|
||||
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||
@triton.heuristics(
|
||||
{
|
||||
"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"]
|
||||
is not None
|
||||
}
|
||||
)
|
||||
@triton.heuristics(
|
||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}
|
||||
)
|
||||
@triton.jit
|
||||
def _selective_scan_update_kernel(
|
||||
# Pointers to matrices
|
||||
state_ptr,
|
||||
x_ptr,
|
||||
dt_ptr,
|
||||
dt_bias_ptr,
|
||||
A_ptr,
|
||||
B_ptr,
|
||||
C_ptr,
|
||||
D_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
state_batch_indices_ptr,
|
||||
pad_slot_id,
|
||||
# Matrix dimensions
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_state_batch,
|
||||
stride_state_head,
|
||||
stride_state_dim,
|
||||
stride_state_dstate,
|
||||
stride_x_batch,
|
||||
stride_x_head,
|
||||
stride_x_dim,
|
||||
stride_dt_batch,
|
||||
stride_dt_head,
|
||||
stride_dt_dim,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_bias_dim,
|
||||
stride_A_head,
|
||||
stride_A_dim,
|
||||
stride_A_dstate,
|
||||
stride_B_batch,
|
||||
stride_B_group,
|
||||
stride_B_dstate,
|
||||
stride_C_batch,
|
||||
stride_C_group,
|
||||
stride_C_dstate,
|
||||
stride_D_head,
|
||||
stride_D_dim,
|
||||
stride_z_batch,
|
||||
stride_z_head,
|
||||
stride_z_dim,
|
||||
stride_out_batch,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
TIE_HDIM: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
):
|
||||
pid_m = tl.program_id(axis=0)
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
|
||||
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
|
||||
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||
# is the same as the batch id.
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
state_batch_indices_ptr += pid_b
|
||||
state_batch_idx = tl.load(state_batch_indices_ptr).to(tl.int64)
|
||||
state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
|
||||
else:
|
||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||
|
||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptr += pid_h * stride_dt_bias_head
|
||||
A_ptr += pid_h * stride_A_head
|
||||
B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
|
||||
C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
|
||||
if HAS_Z:
|
||||
z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)
|
||||
state_ptrs = state_ptr + (
|
||||
offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
|
||||
)
|
||||
x_ptrs = x_ptr + offs_m * stride_x_dim
|
||||
dt_ptrs = dt_ptr + offs_m * stride_dt_dim
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
|
||||
if HAS_D:
|
||||
D_ptr += pid_h * stride_D_head
|
||||
A_ptrs = A_ptr + (
|
||||
offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate
|
||||
)
|
||||
B_ptrs = B_ptr + offs_n * stride_B_dstate
|
||||
C_ptrs = C_ptr + offs_n * stride_C_dstate
|
||||
if HAS_D:
|
||||
D_ptrs = D_ptr + offs_m * stride_D_dim
|
||||
if HAS_Z:
|
||||
z_ptrs = z_ptr + offs_m * stride_z_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
state = tl.load(state_ptrs, mask=mask, other=0.0)
|
||||
|
||||
x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if not TIE_HDIM:
|
||||
dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(
|
||||
A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
|
||||
).to(tl.float32)
|
||||
dA = tl.exp(A * dt[:, None])
|
||||
else:
|
||||
dt = tl.load(dt_ptr).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt += tl.load(dt_bias_ptr).to(tl.float32)
|
||||
if DT_SOFTPLUS:
|
||||
dt = softplus(dt)
|
||||
A = tl.load(A_ptr).to(tl.float32)
|
||||
dA = tl.exp(A * dt) # scalar, not a matrix
|
||||
|
||||
B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
|
||||
if HAS_D:
|
||||
D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
if HAS_Z:
|
||||
z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt
|
||||
state = state * dA + dB * x[:, None]
|
||||
|
||||
mask = (offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
|
||||
if HAS_STATE_BATCH_INDICES:
|
||||
mask &= state_batch_idx != pad_slot_id
|
||||
tl.store(state_ptrs, state, mask=mask)
|
||||
out = tl.sum(state * C[None, :], axis=1)
|
||||
if HAS_D:
|
||||
out += x * D
|
||||
if HAS_Z:
|
||||
out *= z * tl.sigmoid(z)
|
||||
tl.store(out_ptrs, out, mask=offs_m < dim)
|
||||
|
||||
|
||||
def selective_state_update(
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
dt_softplus=False,
|
||||
state_batch_indices=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
out=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
x: (batch, dim) or (batch, nheads, dim)
|
||||
dt: (batch, dim) or (batch, nheads, dim)
|
||||
A: (dim, dstate) or (nheads, dim, dstate)
|
||||
B: (batch, dstate) or (batch, ngroups, dstate)
|
||||
C: (batch, dstate) or (batch, ngroups, dstate)
|
||||
D: (dim,) or (nheads, dim)
|
||||
z: (batch, dim) or (batch, nheads, dim)
|
||||
dt_bias: (dim,) or (nheads, dim)
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padded
|
||||
entries that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id]
|
||||
in this case, the kernel will not process entries at
|
||||
indices 0 and 3
|
||||
out: Preallocated ssm output tensor. Assume same shape as x.
|
||||
In-place updated.
|
||||
"""
|
||||
if state.dim() == 3:
|
||||
state = state.unsqueeze(1)
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
if dt.dim() == 2:
|
||||
dt = dt.unsqueeze(1)
|
||||
if A.dim() == 2:
|
||||
A = A.unsqueeze(0)
|
||||
if B.dim() == 2:
|
||||
B = B.unsqueeze(1)
|
||||
if C.dim() == 2:
|
||||
C = C.unsqueeze(1)
|
||||
if D is not None and D.dim() == 1:
|
||||
D = D.unsqueeze(0)
|
||||
if z is not None and z.dim() == 2:
|
||||
z = z.unsqueeze(1)
|
||||
if dt_bias is not None and dt_bias.dim() == 1:
|
||||
dt_bias = dt_bias.unsqueeze(0)
|
||||
if out.dim() == 2:
|
||||
out = out.unsqueeze(1)
|
||||
|
||||
_, nheads, dim, dstate = state.shape
|
||||
batch = x.shape[0]
|
||||
|
||||
assert x.shape == (batch, nheads, dim)
|
||||
assert dt.shape == x.shape
|
||||
assert A.shape == (nheads, dim, dstate)
|
||||
ngroups = B.shape[1]
|
||||
assert nheads % ngroups == 0, "nheads must be divisible by ngroups"
|
||||
assert B.shape == (batch, ngroups, dstate)
|
||||
assert C.shape == B.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, dim)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads, dim)
|
||||
if state_batch_indices is not None:
|
||||
assert state_batch_indices.shape == (batch,)
|
||||
assert out.shape == x.shape
|
||||
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE_M"]), batch, nheads)
|
||||
z_strides = (z.stride(0), z.stride(1), z.stride(2)) if z is not None else (0, 0, 0)
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = (
|
||||
(32, 4)
|
||||
if dstate <= 16
|
||||
else (
|
||||
(16, 4)
|
||||
if dstate <= 32
|
||||
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
|
||||
)
|
||||
)
|
||||
tie_hdim = (
|
||||
A.stride(-1) == 0
|
||||
and A.stride(-2) == 0
|
||||
and dt.stride(-1) == 0
|
||||
and dt_bias.stride(-1) == 0
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_selective_scan_update_kernel[grid](
|
||||
state,
|
||||
x,
|
||||
dt,
|
||||
dt_bias,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
out,
|
||||
state_batch_indices,
|
||||
pad_slot_id,
|
||||
batch,
|
||||
nheads,
|
||||
dim,
|
||||
dstate,
|
||||
nheads // ngroups,
|
||||
state.stride(0),
|
||||
state.stride(1),
|
||||
state.stride(2),
|
||||
state.stride(3),
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
*(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else 0,
|
||||
A.stride(0),
|
||||
A.stride(1),
|
||||
A.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
*(D.stride(0), D.stride(1)) if D is not None else 0,
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
dt_softplus,
|
||||
tie_hdim,
|
||||
BLOCK_SIZE_M,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
|
||||
|
||||
def selective_scan_fn(
|
||||
u,
|
||||
ssm_states,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D=None,
|
||||
z=None,
|
||||
delta_bias=None,
|
||||
delta_softplus=False,
|
||||
query_start_loc=None,
|
||||
cache_indices=None,
|
||||
has_initial_state=None,
|
||||
pad_slot_id=PAD_SLOT_ID,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
u: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
applies changes in place.
|
||||
ssm_states: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||
applies changes in place.
|
||||
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
A: (dim, dstate)
|
||||
B: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
C: (ngroups, dstate, total_length) for varlen or
|
||||
(batch,ngroups,dstate,seqlen)
|
||||
D: (dim,)
|
||||
z: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
dt_bias: (dim,) or (dim)
|
||||
query_start_loc: (batch + 1) int32
|
||||
The cumulative sequence lengths of the sequences in
|
||||
the batch, used to index into sequence. prepended with 0.
|
||||
for example: query_start_loc = torch.Tensor([0,10,16,17]),
|
||||
x.shape=(dim,17)
|
||||
cache_indices: (batch) int32
|
||||
A tensor with each cell is a correspondent
|
||||
input and output ssm_state index
|
||||
has_initial_state: (batch) bool
|
||||
A tensor populated with ones and zeros,
|
||||
indicate if the ssm_state at the corresponding index should be
|
||||
used as initial state. Not providing argument assumes
|
||||
there's no initial state
|
||||
pad_slot_id: int
|
||||
if cache_indices is passed, lets the kernel identify padding entries
|
||||
that will not be processed,
|
||||
for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id]
|
||||
in this case, the kernel will not process entries at indices 0 and 3
|
||||
returns
|
||||
output: (dim, total_length) for varlen or (batch, dim, seqlen)
|
||||
supports inplace replacement
|
||||
"""
|
||||
if u.stride(-1) != 1:
|
||||
u = u.contiguous()
|
||||
if delta.stride(-1) != 1:
|
||||
delta = delta.contiguous()
|
||||
if D is not None:
|
||||
D = D.contiguous()
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if z is not None and z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
if B.dim() == 3 and query_start_loc is None:
|
||||
B = B.unsqueeze(1)
|
||||
if B.dim() == 2 and query_start_loc is not None:
|
||||
B = B.unsqueeze(0)
|
||||
if C.dim() == 3 and query_start_loc is None:
|
||||
C = C.unsqueeze(1)
|
||||
if C.dim() == 2 and query_start_loc is not None:
|
||||
C = C.unsqueeze(0)
|
||||
|
||||
ops.selective_scan_fwd(
|
||||
u,
|
||||
delta,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
z,
|
||||
delta_bias,
|
||||
delta_softplus,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state,
|
||||
ssm_states,
|
||||
pad_slot_id,
|
||||
)
|
||||
|
||||
if z is None:
|
||||
return delta # output written inplace to delta
|
||||
else:
|
||||
return z # output written inplace to z
|
||||
264
python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py
Normal file
264
python/sglang/srt/layers/attention/mamba/ops/ssd_bmm.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_bmm.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["chunk_size", "K", "IS_CAUSAL"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
out_ptr,
|
||||
seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
seqlen,
|
||||
chunk_size,
|
||||
K,
|
||||
ngroups,
|
||||
stride_a_batch,
|
||||
stride_a_seqlen,
|
||||
stride_a_head,
|
||||
stride_ak,
|
||||
stride_b_batch,
|
||||
stride_b_seqlen,
|
||||
stride_b_head,
|
||||
stride_bk,
|
||||
stride_out_batch,
|
||||
stride_out_chunk,
|
||||
stride_out_head,
|
||||
stride_outm,
|
||||
stride_outn,
|
||||
stride_seq_idx_batch,
|
||||
stride_seq_idx_seqlen,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
dot_dtype: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr = 16,
|
||||
BLOCK_SIZE_N: tl.constexpr = 16,
|
||||
BLOCK_SIZE_K: tl.constexpr = 16,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_ch = tl.program_id(axis=2).to(tl.int64)
|
||||
pid_c = pid_ch // ngroups
|
||||
pid_h = pid_ch - pid_c * ngroups
|
||||
num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
if IS_CAUSAL:
|
||||
if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
|
||||
return
|
||||
a_ptr += (
|
||||
pid_b * stride_a_batch
|
||||
+ pid_c * chunk_size * stride_a_seqlen
|
||||
+ pid_h * stride_a_head
|
||||
)
|
||||
b_ptr += (
|
||||
pid_b * stride_b_batch
|
||||
+ pid_c * chunk_size * stride_b_seqlen
|
||||
+ pid_h * stride_b_head
|
||||
)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += (
|
||||
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
)
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
|
||||
b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
|
||||
a = tl.load(
|
||||
a_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k[None, :] < K - k * BLOCK_SIZE_K),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K)
|
||||
& (offs_n[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(dot_dtype)
|
||||
acc += tl.dot(a, b)
|
||||
a_ptrs += BLOCK_SIZE_K * stride_ak
|
||||
b_ptrs += BLOCK_SIZE_K * stride_bk
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
if HAS_SEQ_IDX:
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
seq_idx_m = tl.load(
|
||||
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1,
|
||||
)
|
||||
seq_idx_n = tl.load(
|
||||
seq_idx_ptr + offs_n * stride_seq_idx_seqlen,
|
||||
mask=offs_n < chunk_size_limit,
|
||||
other=-2,
|
||||
)
|
||||
acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)
|
||||
out = acc.to(out_ptr.dtype.element_ty)
|
||||
|
||||
out_ptr += (
|
||||
pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head
|
||||
)
|
||||
out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
out,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
def _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):
|
||||
"""
|
||||
Argument:
|
||||
a: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
b: (batch, seqlen, k) or (batch, seqlen, ngroups, k)
|
||||
seq_idx: (batch, seqlen) or None. out[i, j] for seq_idx[i] != seq_idx[j] will be zeroed out.
|
||||
causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
|
||||
guaranteed to be correct.
|
||||
Return:
|
||||
out: (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, ngroups, chunk_size, chunk_size)
|
||||
"""
|
||||
# Check constraints.
|
||||
has_groups = a.dim() == 4
|
||||
if not has_groups:
|
||||
batch, seqlen, k = a.shape
|
||||
else:
|
||||
batch, seqlen, ngroups, k = a.shape
|
||||
assert b.shape == a.shape
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if a.stride(-1) != 1 and a.stride(1) != 1:
|
||||
a = a.contiguous()
|
||||
if b.stride(-1) != 1 and b.stride(1) != 1:
|
||||
b = b.contiguous()
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
# Allocates output.
|
||||
out_dtype = a.dtype if output_dtype is None else output_dtype
|
||||
out = torch.empty(
|
||||
(
|
||||
(batch, nchunks, chunk_size, chunk_size)
|
||||
if not has_groups
|
||||
else (batch, nchunks, ngroups, chunk_size, chunk_size)
|
||||
),
|
||||
device=a.device,
|
||||
dtype=out_dtype,
|
||||
)
|
||||
dot_dtype = (
|
||||
tl.bfloat16
|
||||
if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
|
||||
else (
|
||||
tl.float16
|
||||
if a.dtype == torch.float16 or b.dtype == torch.float16
|
||||
else tl.float32
|
||||
)
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
|
||||
batch,
|
||||
nchunks if not has_groups else nchunks * ngroups,
|
||||
)
|
||||
with torch.cuda.device(a.device.index):
|
||||
_bmm_chunk_fwd_kernel[grid](
|
||||
a,
|
||||
b,
|
||||
out,
|
||||
seq_idx,
|
||||
seqlen,
|
||||
chunk_size,
|
||||
k,
|
||||
ngroups if has_groups else 1,
|
||||
a.stride(0),
|
||||
a.stride(1),
|
||||
0 if not has_groups else a.stride(2),
|
||||
a.stride(-1),
|
||||
b.stride(0),
|
||||
b.stride(1),
|
||||
0 if not has_groups else b.stride(2),
|
||||
b.stride(-1),
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
0 if not has_groups else out.stride(2),
|
||||
out.stride(-2),
|
||||
out.stride(-1),
|
||||
*(
|
||||
(seq_idx.stride(0), seq_idx.stride(1))
|
||||
if seq_idx is not None
|
||||
else (0, 0)
|
||||
),
|
||||
causal,
|
||||
dot_dtype,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
)
|
||||
return out
|
||||
622
python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py
Normal file
622
python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py
Normal file
@@ -0,0 +1,622 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_scan.py
|
||||
|
||||
# ruff: noqa: E501,SIM102
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from packaging import version
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
cb_ptr,
|
||||
x_ptr,
|
||||
z_ptr,
|
||||
out_ptr,
|
||||
out_x_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
C_ptr,
|
||||
states_ptr,
|
||||
D_ptr,
|
||||
initstates_ptr,
|
||||
chunk_indices_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
# Matrix dimensions
|
||||
chunk_size,
|
||||
hdim,
|
||||
dstate,
|
||||
batch,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_cb_batch,
|
||||
stride_cb_chunk,
|
||||
stride_cb_head,
|
||||
stride_cb_csize_m,
|
||||
stride_cb_csize_k,
|
||||
stride_x_batch,
|
||||
stride_x_seqlen,
|
||||
stride_x_head,
|
||||
stride_x_hdim,
|
||||
stride_z_batch,
|
||||
stride_z_seqlen,
|
||||
stride_z_head,
|
||||
stride_z_hdim,
|
||||
stride_out_batch,
|
||||
stride_out_seqlen,
|
||||
stride_out_head,
|
||||
stride_out_hdim,
|
||||
stride_dt_batch,
|
||||
stride_dt_chunk,
|
||||
stride_dt_head,
|
||||
stride_dt_csize,
|
||||
stride_dA_cs_batch,
|
||||
stride_dA_cs_chunk,
|
||||
stride_dA_cs_head,
|
||||
stride_dA_cs_csize,
|
||||
stride_seq_idx_batch,
|
||||
stride_seq_idx_seqlen,
|
||||
stride_C_batch,
|
||||
stride_C_seqlen,
|
||||
stride_C_head,
|
||||
stride_C_dstate,
|
||||
stride_states_batch,
|
||||
stride_states_chunk,
|
||||
stride_states_head,
|
||||
stride_states_hdim,
|
||||
stride_states_dstate,
|
||||
stride_init_states_batch,
|
||||
stride_init_states_head,
|
||||
stride_init_states_hdim,
|
||||
stride_init_states_dstate,
|
||||
stride_D_head,
|
||||
# Meta-parameters
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
HAS_D: tl.constexpr,
|
||||
D_HAS_HDIM: tl.constexpr,
|
||||
HAS_Z: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||
IS_TRITON_22: tl.constexpr,
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr = 16,
|
||||
BLOCK_SIZE_N: tl.constexpr = 16,
|
||||
BLOCK_SIZE_K: tl.constexpr = 16,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
if not HAS_INITSTATES:
|
||||
c_idx = pid_c
|
||||
c_off = 0
|
||||
else:
|
||||
c_idx = tl.load(chunk_indices_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
c_off = tl.load(chunk_offsets_ptr + pid_c, mask=pid_c > -1, other=0)
|
||||
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
cb_ptr += (
|
||||
pid_b * stride_cb_batch
|
||||
+ c_idx * stride_cb_chunk
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_cb_head
|
||||
)
|
||||
x_ptr += (
|
||||
pid_b * stride_x_batch
|
||||
+ c_idx * chunk_size * stride_x_seqlen
|
||||
+ pid_h * stride_x_head
|
||||
)
|
||||
dt_ptr += pid_b * stride_dt_batch + c_idx * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += (
|
||||
pid_b * stride_dA_cs_batch
|
||||
+ c_idx * stride_dA_cs_chunk
|
||||
+ pid_h * stride_dA_cs_head
|
||||
)
|
||||
C_ptr += (
|
||||
pid_b * stride_C_batch
|
||||
+ c_idx * chunk_size * stride_C_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_C_head
|
||||
)
|
||||
|
||||
# M-block offsets and prev states
|
||||
# - logic in next block may override these if there is an active offset
|
||||
offs_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
prev_states_ptr = (
|
||||
states_ptr
|
||||
+ pid_b * stride_states_batch
|
||||
+ c_idx * stride_states_chunk
|
||||
+ pid_h * stride_states_head
|
||||
)
|
||||
prev_states_hdim = stride_states_hdim
|
||||
prev_states_dstate = stride_states_dstate
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - c_idx * chunk_size)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += (
|
||||
pid_b * stride_seq_idx_batch + c_idx * chunk_size * stride_seq_idx_seqlen
|
||||
)
|
||||
|
||||
# - we only need seq_idx_prev to be aligned to chunk boundary
|
||||
seq_idx_prev = tl.load(
|
||||
seq_idx_ptr - stride_seq_idx_seqlen, mask=c_idx >= 1, other=0
|
||||
)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states, we only need seq_idx_m to point
|
||||
# what is the current seq_idx
|
||||
|
||||
# get current seq idx
|
||||
if (pid_m * BLOCK_SIZE_M + c_off) < chunk_size_limit:
|
||||
seq_idx_m = tl.load(
|
||||
seq_idx_ptr
|
||||
+ (pid_m * BLOCK_SIZE_M + c_off) * stride_seq_idx_seqlen,
|
||||
)
|
||||
|
||||
# - recall that in ssd_state_passing, for the case c_off == 0
|
||||
# i.e., the very first sequence, we made states_ptr hold its initial state
|
||||
# so this edge case is taken care of
|
||||
if (
|
||||
(c_off == 0)
|
||||
and (
|
||||
seq_idx_prev != seq_idx_m
|
||||
) # if a seq is changed exactly on boundary
|
||||
or (c_off > 0) # implies a new example (pseudo chunk)
|
||||
):
|
||||
|
||||
# - replace prev_states_ptr with init_states
|
||||
prev_states_ptr = (
|
||||
initstates_ptr
|
||||
+ seq_idx_m * stride_init_states_batch
|
||||
+ pid_h * stride_init_states_head
|
||||
)
|
||||
prev_states_hdim = stride_init_states_hdim # override strides
|
||||
prev_states_dstate = stride_init_states_dstate
|
||||
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
dA_cs_m = tl.load(
|
||||
dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
|
||||
).to(tl.float32)
|
||||
|
||||
# - handle chunk state limit
|
||||
if HAS_INITSTATES:
|
||||
|
||||
# have to split this if otherwise compilation will have problems
|
||||
dA_cs_m_boundary = 0.0
|
||||
|
||||
# get the c_idx for the next (logica) chunk
|
||||
c_idx_n = tl.load(
|
||||
chunk_indices_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=-1, # to trigger different chunk
|
||||
)
|
||||
|
||||
# - there are things to consider
|
||||
# A. if c_off > 0 then we need to move the dA_cs boundary to ensure correct
|
||||
# contribution of past states
|
||||
# B. if c_off_n < chunk_size_limit, then we need to adjust this so as not to
|
||||
# encroach into the next sequence, where c_off_n is the offset of the next
|
||||
# (logical) chunk.
|
||||
# An equivalent check for B is c_idx == c_idx_n, where there is repetition in
|
||||
# (logical) chunk indices.
|
||||
|
||||
if (c_idx == c_idx_n) or c_off > 0:
|
||||
|
||||
# get the next offset
|
||||
c_off_n = tl.load(
|
||||
chunk_offsets_ptr + (pid_c + 1),
|
||||
mask=pid_c > -1 and (pid_c + 1) < chunk_meta_num,
|
||||
other=chunk_size,
|
||||
)
|
||||
|
||||
# in this case, adjust down the chunk_size_limit
|
||||
if c_idx == c_idx_n:
|
||||
chunk_size_limit = min(c_off_n, chunk_size_limit)
|
||||
|
||||
# get the cs at the offset boundary
|
||||
# - c_off == 0 is a passthrough
|
||||
# - We need dA_cs at the boundary, defined by c_off - no need
|
||||
# to increase pointer by pid_m (it is a constant offset,
|
||||
# i.e. the same for all blocks)
|
||||
dA_cs_m_boundary = tl.load(
|
||||
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
if HAS_SEQ_IDX:
|
||||
# - handle seq idx when HAS_INITSTATES==False
|
||||
if not HAS_INITSTATES:
|
||||
seq_idx_m = tl.load(
|
||||
seq_idx_ptr + offs_m * stride_seq_idx_seqlen,
|
||||
mask=offs_m < chunk_size_limit,
|
||||
other=-1,
|
||||
)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
|
||||
# Without the if (pid_c > -1), with Triton 2.1.0, I get
|
||||
# Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed.
|
||||
# With Triton 2.2.0, this works
|
||||
if IS_TRITON_22 or c_idx > -1:
|
||||
# Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
|
||||
offs_k_dstate = tl.arange(
|
||||
0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K
|
||||
)
|
||||
C_ptrs = C_ptr + (
|
||||
offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate
|
||||
)
|
||||
|
||||
prev_states_ptrs = prev_states_ptr + (
|
||||
offs_n[None, :] * prev_states_hdim
|
||||
+ offs_k_dstate[:, None] * prev_states_dstate
|
||||
)
|
||||
if HAS_SEQ_IDX:
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
# - this is for continuous batching where there is no init states
|
||||
scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0)
|
||||
else:
|
||||
# - if there is initstates, we will rely on prev_states, no zeroing
|
||||
# required.
|
||||
scale_m = tl.exp(dA_cs_m - dA_cs_m_boundary)
|
||||
else:
|
||||
scale_m = tl.exp(dA_cs_m)
|
||||
if BLOCK_SIZE_DSTATE <= 128:
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k_dstate[None, :] < dstate),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc = tl.dot(C, prev_states) * scale_m[:, None]
|
||||
else:
|
||||
for k in range(0, dstate, BLOCK_SIZE_K):
|
||||
C = tl.load(
|
||||
C_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size_limit)
|
||||
& (offs_k_dstate[None, :] < dstate - k),
|
||||
other=0.0,
|
||||
)
|
||||
# C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty)
|
||||
prev_states = tl.load(
|
||||
prev_states_ptrs,
|
||||
mask=(offs_k_dstate[:, None] < dstate - k)
|
||||
& (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
prev_states = prev_states.to(C_ptr.dtype.element_ty)
|
||||
acc += tl.dot(C, prev_states)
|
||||
C_ptrs += BLOCK_SIZE_K
|
||||
prev_states_ptrs += BLOCK_SIZE_K
|
||||
acc *= scale_m[:, None]
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K) + c_off
|
||||
cb_ptrs = cb_ptr + (
|
||||
offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k
|
||||
)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
K_MAX = (
|
||||
chunk_size_limit
|
||||
if not IS_CAUSAL
|
||||
else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
|
||||
)
|
||||
for k in range(0, K_MAX, BLOCK_SIZE_K):
|
||||
cb = tl.load(
|
||||
cb_ptrs,
|
||||
mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
# If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
|
||||
# So we don't need masking wrt seq_idx here.
|
||||
cb *= tl.exp(dA_cs_m[:, None] - dA_cs_k[None, :])
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
|
||||
cb *= dt_k
|
||||
if IS_CAUSAL:
|
||||
mask = offs_m[:, None] >= k + offs_k[None, :]
|
||||
cb = tl.where(mask, cb, 0.0)
|
||||
cb = cb.to(x_ptr.dtype.element_ty)
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
)
|
||||
acc += tl.dot(cb, x)
|
||||
cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
offs_out_m = pid_m * BLOCK_SIZE_M + c_off + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
|
||||
if HAS_D:
|
||||
if D_HAS_HDIM:
|
||||
D = tl.load(
|
||||
D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0
|
||||
).to(tl.float32)
|
||||
else:
|
||||
D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
|
||||
x_residual = tl.load(
|
||||
x_ptr
|
||||
+ (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
|
||||
mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
acc += x_residual * D
|
||||
|
||||
if HAS_Z:
|
||||
out_x_ptr += (
|
||||
pid_b * stride_out_batch
|
||||
+ c_idx * chunk_size * stride_out_seqlen
|
||||
+ pid_h * stride_out_head
|
||||
)
|
||||
out_x_ptrs = out_x_ptr + (
|
||||
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]
|
||||
)
|
||||
tl.store(
|
||||
out_x_ptrs,
|
||||
acc,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit)
|
||||
& (offs_out_n[None, :] < hdim),
|
||||
)
|
||||
|
||||
z_ptr += (
|
||||
pid_b * stride_z_batch
|
||||
+ c_idx * chunk_size * stride_z_seqlen
|
||||
+ pid_h * stride_z_head
|
||||
)
|
||||
z_ptrs = z_ptr + (
|
||||
stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
|
||||
)
|
||||
z = tl.load(
|
||||
z_ptrs,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit)
|
||||
& (offs_out_n[None, :] < hdim),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
acc *= z * tl.sigmoid(z)
|
||||
|
||||
out_ptr += (
|
||||
pid_b * stride_out_batch
|
||||
+ c_idx * chunk_size * stride_out_seqlen
|
||||
+ pid_h * stride_out_head
|
||||
)
|
||||
out_ptrs = out_ptr + (
|
||||
stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
|
||||
)
|
||||
tl.store(
|
||||
out_ptrs,
|
||||
acc,
|
||||
mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
|
||||
)
|
||||
|
||||
|
||||
def _chunk_scan_fwd(
|
||||
cb,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
D=None,
|
||||
z=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
initial_states=None,
|
||||
out=None,
|
||||
):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = C.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert C.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
|
||||
if initial_states is not None:
|
||||
# with initial states, we need to take care of how
|
||||
# seq_idx crosses the boundaries
|
||||
assert batch == 1, "chunk scan only supports initial states with batch 1"
|
||||
assert (
|
||||
chunk_indices is not None and chunk_offsets is not None
|
||||
), "chunk_indices and chunk_offsets should have been set"
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
else:
|
||||
chunk_indices, chunk_offsets = None, None
|
||||
|
||||
assert out.shape == x.shape
|
||||
|
||||
if z is not None:
|
||||
out_x = torch.empty_like(x)
|
||||
assert out_x.stride() == out.stride()
|
||||
else:
|
||||
out_x = None
|
||||
|
||||
grid = lambda META: (
|
||||
triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(headdim, META["BLOCK_SIZE_N"]),
|
||||
batch * nchunks if chunk_offsets is None else len(chunk_offsets),
|
||||
nheads,
|
||||
)
|
||||
z_strides = (
|
||||
(z.stride(0), z.stride(1), z.stride(2), z.stride(3))
|
||||
if z is not None
|
||||
else (0, 0, 0, 0)
|
||||
)
|
||||
_chunk_scan_fwd_kernel[grid](
|
||||
cb,
|
||||
x,
|
||||
z,
|
||||
out,
|
||||
out_x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
seq_idx,
|
||||
C,
|
||||
states,
|
||||
D,
|
||||
initial_states,
|
||||
chunk_indices,
|
||||
chunk_offsets,
|
||||
len(chunk_indices) if chunk_indices is not None else 0,
|
||||
chunk_size,
|
||||
headdim,
|
||||
dstate,
|
||||
batch,
|
||||
seqlen,
|
||||
nheads // ngroups,
|
||||
cb.stride(0),
|
||||
cb.stride(1),
|
||||
cb.stride(2),
|
||||
cb.stride(3),
|
||||
cb.stride(4),
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
x.stride(3),
|
||||
z_strides[0],
|
||||
z_strides[1],
|
||||
z_strides[2],
|
||||
z_strides[3],
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
out.stride(3),
|
||||
dt.stride(0),
|
||||
dt.stride(2),
|
||||
dt.stride(1),
|
||||
dt.stride(3),
|
||||
dA_cumsum.stride(0),
|
||||
dA_cumsum.stride(2),
|
||||
dA_cumsum.stride(1),
|
||||
dA_cumsum.stride(3),
|
||||
*((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),
|
||||
C.stride(0),
|
||||
C.stride(1),
|
||||
C.stride(2),
|
||||
C.stride(3),
|
||||
states.stride(0),
|
||||
states.stride(1),
|
||||
states.stride(2),
|
||||
states.stride(3),
|
||||
states.stride(4),
|
||||
*(
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0, 0)
|
||||
),
|
||||
D.stride(0) if D is not None else 0,
|
||||
True,
|
||||
D is not None,
|
||||
D.dim() == 2 if D is not None else True,
|
||||
BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
|
||||
HAS_Z=z is not None,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
IS_TRITON_22=TRITON_22,
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return out_x
|
||||
757
python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py
Normal file
757
python/sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py
Normal file
@@ -0,0 +1,757 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_chunk_state.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_SIZE_H": 2}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 4}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 8}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 16}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 32}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 64}),
|
||||
# ],
|
||||
# key=["chunk_size", "nheads"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
dt_ptr,
|
||||
A_ptr,
|
||||
dt_bias_ptr,
|
||||
dt_out_ptr,
|
||||
dA_cumsum_ptr,
|
||||
# Matrix dimension
|
||||
batch,
|
||||
seqlen,
|
||||
nheads,
|
||||
chunk_size,
|
||||
dt_min,
|
||||
dt_max,
|
||||
# Strides
|
||||
stride_dt_batch,
|
||||
stride_dt_seqlen,
|
||||
stride_dt_head,
|
||||
stride_A_head,
|
||||
stride_dt_bias_head,
|
||||
stride_dt_out_batch,
|
||||
stride_dt_out_chunk,
|
||||
stride_dt_out_head,
|
||||
stride_dt_out_csize,
|
||||
stride_dA_cs_batch,
|
||||
stride_dA_cs_chunk,
|
||||
stride_dA_cs_head,
|
||||
stride_dA_cs_csize,
|
||||
# Meta-parameters
|
||||
DT_SOFTPLUS: tl.constexpr,
|
||||
HAS_DT_BIAS: tl.constexpr,
|
||||
BLOCK_SIZE_CHUNK: tl.constexpr,
|
||||
BLOCK_SIZE_H: tl.constexpr = 16,
|
||||
):
|
||||
pid_b = tl.program_id(axis=0)
|
||||
|
||||
# if dt is long, may cause problems, so use 64 bit
|
||||
# https://github.com/triton-lang/triton/issues/1058
|
||||
pid_c = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen
|
||||
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
|
||||
dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk
|
||||
|
||||
offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
|
||||
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
|
||||
dt_ptrs = dt_ptr + (
|
||||
offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen
|
||||
)
|
||||
A_ptrs = A_ptr + offs_h * stride_A_head
|
||||
dt_out_ptrs = dt_out_ptr + (
|
||||
offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
|
||||
)
|
||||
dA_cs_ptrs = dA_cumsum_ptr + (
|
||||
offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
|
||||
)
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
|
||||
dt = tl.load(
|
||||
dt_ptrs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
if HAS_DT_BIAS:
|
||||
dt_bias = tl.load(
|
||||
dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
|
||||
).to(tl.float32)
|
||||
dt += dt_bias[:, None]
|
||||
if DT_SOFTPLUS:
|
||||
dt = tl.where(dt <= 20.0, softplus(dt), dt)
|
||||
# As of Triton 2.2.0, tl.clamp is not available yet
|
||||
# dt = tl.clamp(dt, dt_min, dt_max)
|
||||
dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)
|
||||
dt = tl.where(
|
||||
(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0
|
||||
)
|
||||
tl.store(
|
||||
dt_out_ptrs,
|
||||
dt,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
||||
)
|
||||
A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
|
||||
dA = dt * A[:, None]
|
||||
dA_cs = tl.cumsum(dA, axis=1)
|
||||
tl.store(
|
||||
dA_cs_ptrs,
|
||||
dA_cs,
|
||||
mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size),
|
||||
)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["hdim", "dstate", "chunk_size"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
states_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
seq_idx_ptr,
|
||||
# Matrix dimensions
|
||||
hdim,
|
||||
dstate,
|
||||
chunk_size,
|
||||
batch,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_x_batch,
|
||||
stride_x_seqlen,
|
||||
stride_x_head,
|
||||
stride_x_hdim,
|
||||
stride_b_batch,
|
||||
stride_b_seqlen,
|
||||
stride_b_head,
|
||||
stride_b_dstate,
|
||||
stride_states_batch,
|
||||
stride_states_chunk,
|
||||
stride_states_head,
|
||||
stride_states_hdim,
|
||||
stride_states_dstate,
|
||||
stride_dt_batch,
|
||||
stride_dt_chunk,
|
||||
stride_dt_head,
|
||||
stride_dt_csize,
|
||||
stride_dA_cs_batch,
|
||||
stride_dA_cs_chunk,
|
||||
stride_dA_cs_head,
|
||||
stride_dA_cs_csize,
|
||||
stride_seq_idx_batch,
|
||||
stride_seq_idx_seqlen,
|
||||
# Meta-parameters
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr = 16,
|
||||
BLOCK_SIZE_N: tl.constexpr = 16,
|
||||
BLOCK_SIZE_K: tl.constexpr = 16,
|
||||
):
|
||||
pid_bc = tl.program_id(axis=1).to(tl.int64)
|
||||
pid_c = pid_bc // batch
|
||||
pid_b = pid_bc - pid_c * batch
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
b_ptr += (
|
||||
pid_b * stride_b_batch
|
||||
+ pid_c * chunk_size * stride_b_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
)
|
||||
x_ptr += (
|
||||
pid_b * stride_x_batch
|
||||
+ pid_c * chunk_size * stride_x_seqlen
|
||||
+ pid_h * stride_x_head
|
||||
)
|
||||
dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += (
|
||||
pid_b * stride_dA_cs_batch
|
||||
+ pid_c * stride_dA_cs_chunk
|
||||
+ pid_h * stride_dA_cs_head
|
||||
)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += (
|
||||
pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen
|
||||
)
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(
|
||||
tl.float32
|
||||
)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen
|
||||
|
||||
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_last = tl.load(
|
||||
seq_idx_ptr + (chunk_size_limit - 1) * stride_seq_idx_seqlen
|
||||
)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(
|
||||
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
||||
).to(tl.float32)
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_k = tl.load(
|
||||
seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1
|
||||
)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
if not HAS_SEQ_IDX:
|
||||
scale = tl.exp(dA_cs_last - dA_cs_k) * dt_k
|
||||
else:
|
||||
scale = tl.where(
|
||||
seq_idx_k == seq_idx_last, tl.exp(dA_cs_last - dA_cs_k) * dt_k, 0.0
|
||||
)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += (
|
||||
pid_b * stride_states_batch
|
||||
+ pid_c * stride_states_chunk
|
||||
+ pid_h * stride_states_head
|
||||
)
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (
|
||||
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
||||
)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["hdim", "dstate", "chunk_size"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_state_varlen_kernel(
|
||||
# Pointers to matrices
|
||||
x_ptr,
|
||||
b_ptr,
|
||||
dt_ptr,
|
||||
dA_cumsum_ptr,
|
||||
chunk_states_ptr,
|
||||
cu_seqlens_ptr,
|
||||
states_ptr,
|
||||
initstates_ptr,
|
||||
# Matrix dimensions
|
||||
hdim,
|
||||
dstate,
|
||||
chunk_size,
|
||||
seqlen,
|
||||
nheads_ngroups_ratio,
|
||||
# Strides
|
||||
stride_x_seqlen,
|
||||
stride_x_head,
|
||||
stride_x_hdim,
|
||||
stride_b_seqlen,
|
||||
stride_b_head,
|
||||
stride_b_dstate,
|
||||
stride_dt_chunk,
|
||||
stride_dt_head,
|
||||
stride_dt_csize,
|
||||
stride_dA_cs_chunk,
|
||||
stride_dA_cs_head,
|
||||
stride_dA_cs_csize,
|
||||
stride_chunk_states_chunk,
|
||||
stride_chunk_states_head,
|
||||
stride_chunk_states_hdim,
|
||||
stride_chunk_states_dstate,
|
||||
stride_states_batch,
|
||||
stride_states_head,
|
||||
stride_states_hdim,
|
||||
stride_states_dstate,
|
||||
stride_init_states_batch,
|
||||
stride_init_states_head,
|
||||
stride_init_states_hdim,
|
||||
stride_init_states_dstate,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
BLOCK_SIZE_M: tl.constexpr = 16,
|
||||
BLOCK_SIZE_N: tl.constexpr = 16,
|
||||
BLOCK_SIZE_K: tl.constexpr = 16,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
|
||||
pid_m = tl.program_id(axis=0) // num_pid_n
|
||||
pid_n = tl.program_id(axis=0) % num_pid_n
|
||||
end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
|
||||
pid_c = (end_idx - 1) // chunk_size
|
||||
b_ptr += (
|
||||
pid_c * chunk_size * stride_b_seqlen
|
||||
+ (pid_h // nheads_ngroups_ratio) * stride_b_head
|
||||
)
|
||||
x_ptr += pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head
|
||||
dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
|
||||
dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
|
||||
chunk_states_ptr += (
|
||||
pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head
|
||||
)
|
||||
|
||||
if HAS_INITSTATES:
|
||||
# if there are init states provided, we differentiate between states (which
|
||||
# are boundary conditions at a chunk boundary) and initstates (which are boundary
|
||||
# conditions when a new example in a cont batch starts)
|
||||
initstates_ptr += pid_h * stride_init_states_head
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
||||
x_ptrs = x_ptr + (
|
||||
offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen
|
||||
)
|
||||
b_ptrs = b_ptr + (
|
||||
offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen
|
||||
)
|
||||
dt_ptrs = dt_ptr + offs_k * stride_dt_csize
|
||||
dA_cs_last = tl.load(
|
||||
dA_cumsum_ptr + (end_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
||||
).to(tl.float32)
|
||||
dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
|
||||
|
||||
chunk_size_limit = end_idx - pid_c * chunk_size
|
||||
start_idx = tl.load(cu_seqlens_ptr + pid_b)
|
||||
start_idx_cur = tl.maximum(start_idx - pid_c * chunk_size, 0)
|
||||
|
||||
acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
||||
for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
|
||||
x = tl.load(
|
||||
x_ptrs,
|
||||
mask=(offs_m[:, None] < hdim)
|
||||
& (offs_k[None, :] < chunk_size_limit - k)
|
||||
& (offs_k[None, :] >= start_idx_cur - k),
|
||||
other=0.0,
|
||||
)
|
||||
b = tl.load(
|
||||
b_ptrs,
|
||||
mask=(offs_k[:, None] < chunk_size_limit - k)
|
||||
& (offs_n[None, :] < dstate)
|
||||
& (offs_k[:, None] >= start_idx_cur - k),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
dA_cs_k = tl.load(
|
||||
dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0
|
||||
).to(tl.float32)
|
||||
dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
scale = tl.where(
|
||||
(offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
|
||||
tl.exp(dA_cs_last - dA_cs_k) * dt_k,
|
||||
0.0,
|
||||
)
|
||||
b *= scale[:, None]
|
||||
b = b.to(x_ptr.dtype.element_ty)
|
||||
acc += tl.dot(x, b)
|
||||
x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
|
||||
b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
|
||||
dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
|
||||
dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize
|
||||
|
||||
# If the sequence starts after the last chunk idx, we don't need to add the contribution from the last chunk
|
||||
# If HAS_INITSTATES==True need to consider two possibilities
|
||||
# - if start_idx < pid_c * chunk_size, then we need to take the past_states_ptrs
|
||||
# - if state_idx >= pid * chunk_size, then we need to insert initstates
|
||||
if (start_idx < pid_c * chunk_size) or (HAS_INITSTATES): # first chunk
|
||||
|
||||
dA_cs_boundary = 0.0 # default
|
||||
|
||||
if not HAS_INITSTATES:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim
|
||||
+ offs_n[None, :] * stride_chunk_states_dstate
|
||||
)
|
||||
else:
|
||||
|
||||
# - this seems repetitive, buts its to help the compiler
|
||||
if start_idx < pid_c * chunk_size:
|
||||
past_states_ptrs = chunk_states_ptr + (
|
||||
offs_m[:, None] * stride_chunk_states_hdim
|
||||
+ offs_n[None, :] * stride_chunk_states_dstate
|
||||
)
|
||||
else:
|
||||
past_states_ptrs = initstates_ptr + (
|
||||
pid_b * stride_init_states_batch
|
||||
+ offs_m[:, None] * stride_init_states_hdim
|
||||
+ offs_n[None, :] * stride_init_states_dstate
|
||||
)
|
||||
|
||||
# need to adjust the boundary
|
||||
if start_idx > pid_c * chunk_size:
|
||||
dA_cs_boundary = tl.load(
|
||||
dA_cumsum_ptr
|
||||
+ (start_idx - pid_c * chunk_size - 1) * stride_dA_cs_csize
|
||||
).to(tl.float32)
|
||||
|
||||
past_states = tl.load(
|
||||
past_states_ptrs,
|
||||
mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate),
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
|
||||
scale = tl.exp(dA_cs_last - dA_cs_boundary)
|
||||
acc += past_states * scale
|
||||
|
||||
states = acc.to(states_ptr.dtype.element_ty)
|
||||
|
||||
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
||||
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
||||
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
||||
states_ptrs = states_ptr + (
|
||||
offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
|
||||
)
|
||||
c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
def _chunk_cumsum_fwd(
|
||||
dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float("inf"))
|
||||
):
|
||||
batch, seqlen, nheads = dt.shape
|
||||
assert A.shape == (nheads,)
|
||||
if dt_bias is not None:
|
||||
assert dt_bias.shape == (nheads,)
|
||||
nchunks = math.ceil(seqlen / chunk_size)
|
||||
dt_out = torch.empty(
|
||||
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
||||
)
|
||||
dA_cumsum = torch.empty(
|
||||
batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32
|
||||
)
|
||||
grid_chunk_cs = lambda META: (
|
||||
batch,
|
||||
nchunks,
|
||||
triton.cdiv(nheads, META["BLOCK_SIZE_H"]),
|
||||
)
|
||||
with torch.cuda.device(dt.device.index):
|
||||
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
|
||||
dt,
|
||||
A,
|
||||
dt_bias,
|
||||
dt_out,
|
||||
dA_cumsum,
|
||||
batch,
|
||||
seqlen,
|
||||
nheads,
|
||||
chunk_size,
|
||||
dt_limit[0],
|
||||
dt_limit[1],
|
||||
dt.stride(0),
|
||||
dt.stride(1),
|
||||
dt.stride(2),
|
||||
A.stride(0),
|
||||
dt_bias.stride(0) if dt_bias is not None else 0,
|
||||
dt_out.stride(0),
|
||||
dt_out.stride(2),
|
||||
dt_out.stride(1),
|
||||
dt_out.stride(3),
|
||||
dA_cumsum.stride(0),
|
||||
dA_cumsum.stride(2),
|
||||
dA_cumsum.stride(1),
|
||||
dA_cumsum.stride(3),
|
||||
dt_softplus,
|
||||
HAS_DT_BIAS=dt_bias is not None,
|
||||
BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
|
||||
)
|
||||
return dA_cumsum, dt_out
|
||||
|
||||
|
||||
def _chunk_state_fwd(
|
||||
B, x, dt, dA_cumsum, seq_idx=None, states=None, states_in_fp32=True
|
||||
):
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, nchunks, chunk_size = dt.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert dt.shape == (batch, nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if states is not None:
|
||||
assert states.shape == (batch, nchunks, nheads, headdim, dstate)
|
||||
else:
|
||||
states_dtype = torch.float32 if states_in_fp32 else B.dtype
|
||||
states = torch.empty(
|
||||
(batch, nchunks, nheads, headdim, dstate),
|
||||
device=x.device,
|
||||
dtype=states_dtype,
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
||||
batch * nchunks,
|
||||
nheads,
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_fwd_kernel[grid](
|
||||
x,
|
||||
B,
|
||||
states,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
seq_idx,
|
||||
headdim,
|
||||
dstate,
|
||||
chunk_size,
|
||||
batch,
|
||||
seqlen,
|
||||
nheads // ngroups,
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
x.stride(3),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
B.stride(-1),
|
||||
states.stride(0),
|
||||
states.stride(1),
|
||||
states.stride(2),
|
||||
states.stride(3),
|
||||
states.stride(4),
|
||||
dt.stride(0),
|
||||
dt.stride(2),
|
||||
dt.stride(1),
|
||||
dt.stride(3),
|
||||
dA_cumsum.stride(0),
|
||||
dA_cumsum.stride(2),
|
||||
dA_cumsum.stride(1),
|
||||
dA_cumsum.stride(3),
|
||||
*(
|
||||
(seq_idx.stride(0), seq_idx.stride(1))
|
||||
if seq_idx is not None
|
||||
else (0, 0)
|
||||
),
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
)
|
||||
return states
|
||||
|
||||
|
||||
def chunk_state_varlen(
|
||||
B, x, dt, dA_cumsum, cu_seqlens, chunk_states, initial_states=None
|
||||
):
|
||||
total_seqlen, nheads, headdim = x.shape
|
||||
_, nchunks, chunk_size = dt.shape
|
||||
_, ngroups, dstate = B.shape
|
||||
batch = cu_seqlens.shape[0] - 1
|
||||
cu_seqlens = cu_seqlens.contiguous()
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (total_seqlen, ngroups, dstate)
|
||||
assert dt.shape == (nheads, nchunks, chunk_size)
|
||||
assert dA_cumsum.shape == dt.shape
|
||||
assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
|
||||
|
||||
if initial_states is not None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
|
||||
states = torch.empty(
|
||||
batch,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
dtype=chunk_states.dtype,
|
||||
device=chunk_states.device,
|
||||
)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(headdim, META["BLOCK_SIZE_M"])
|
||||
* triton.cdiv(dstate, META["BLOCK_SIZE_N"]),
|
||||
batch,
|
||||
nheads,
|
||||
)
|
||||
with torch.cuda.device(x.device.index):
|
||||
_chunk_state_varlen_kernel[grid](
|
||||
x,
|
||||
B,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
chunk_states,
|
||||
cu_seqlens,
|
||||
states,
|
||||
initial_states,
|
||||
headdim,
|
||||
dstate,
|
||||
chunk_size,
|
||||
total_seqlen,
|
||||
nheads // ngroups,
|
||||
x.stride(0),
|
||||
x.stride(1),
|
||||
x.stride(2),
|
||||
B.stride(0),
|
||||
B.stride(1),
|
||||
B.stride(2),
|
||||
dt.stride(1),
|
||||
dt.stride(0),
|
||||
dt.stride(2),
|
||||
dA_cumsum.stride(1),
|
||||
dA_cumsum.stride(0),
|
||||
dA_cumsum.stride(2),
|
||||
chunk_states.stride(0),
|
||||
chunk_states.stride(1),
|
||||
chunk_states.stride(2),
|
||||
chunk_states.stride(3),
|
||||
states.stride(0),
|
||||
states.stride(1),
|
||||
states.stride(2),
|
||||
states.stride(3),
|
||||
*(
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
initial_states.stride(3),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0, 0)
|
||||
),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
)
|
||||
return states
|
||||
262
python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py
Normal file
262
python/sglang/srt/layers/attention/mamba/ops/ssd_combined.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_combined.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_combined.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
from .ssd_bmm import _bmm_chunk_fwd
|
||||
from .ssd_chunk_scan import _chunk_scan_fwd
|
||||
from .ssd_chunk_state import _chunk_cumsum_fwd, _chunk_state_fwd, chunk_state_varlen
|
||||
from .ssd_state_passing import _state_passing_fwd
|
||||
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
def is_int_pow_2(n):
|
||||
return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def _mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
cu_seqlens=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
state_dtype=None,
|
||||
out=None,
|
||||
):
|
||||
assert is_int_pow_2(chunk_size), "chunk_size must be integer power of 2"
|
||||
batch, seqlen, nheads, headdim = x.shape
|
||||
_, _, ngroups, dstate = B.shape
|
||||
assert nheads % ngroups == 0
|
||||
assert B.shape == (batch, seqlen, ngroups, dstate)
|
||||
assert dt.shape == (batch, seqlen, nheads)
|
||||
assert A.shape == (nheads,)
|
||||
assert C.shape == B.shape
|
||||
if z is not None:
|
||||
assert z.shape == x.shape
|
||||
if D is not None:
|
||||
assert D.shape == (nheads, headdim) or D.shape == (nheads,)
|
||||
if seq_idx is not None:
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
if B.stride(-1) != 1:
|
||||
B = B.contiguous()
|
||||
if C.stride(-1) != 1:
|
||||
C = C.contiguous()
|
||||
if (
|
||||
x.stride(-1) != 1 and x.stride(1) != 1
|
||||
): # Either M or K dimension should be contiguous
|
||||
x = x.contiguous()
|
||||
if (
|
||||
z is not None and z.stride(-1) != 1 and z.stride(1) != 1
|
||||
): # Either M or K dimension should be contiguous
|
||||
z = z.contiguous()
|
||||
if D is not None and D.stride(-1) != 1:
|
||||
D = D.contiguous()
|
||||
if initial_states is not None:
|
||||
if cu_seqlens is None:
|
||||
assert initial_states.shape == (batch, nheads, headdim, dstate)
|
||||
else:
|
||||
assert initial_states.shape == (
|
||||
len(cu_seqlens) - 1,
|
||||
nheads,
|
||||
headdim,
|
||||
dstate,
|
||||
)
|
||||
|
||||
# This function executes 5 sub-functions for computing mamba
|
||||
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
|
||||
# which has a minimal implementation to understand the below operations
|
||||
# - as explained by the blog, mamba is a special case of causal attention
|
||||
# - the idea is to chunk the attention matrix and compute each
|
||||
# submatrix separately using different optimizations.
|
||||
# - see the blog and paper for a visualization of the submatrices
|
||||
# which we refer to in the comments below
|
||||
|
||||
# 1. Compute chunked cumsum of A * dt
|
||||
# - here dt may go through a softplus activation
|
||||
dA_cumsum, dt = _chunk_cumsum_fwd(
|
||||
dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit
|
||||
)
|
||||
|
||||
# 2. Compute the state for each intra-chunk
|
||||
# (right term of low-rank factorization of off-diagonal blocks; B terms)
|
||||
states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
|
||||
|
||||
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
|
||||
# (middle term of factorization of off-diag blocks; A terms)
|
||||
# - for handling chunked prefill, this requires i) initial_states
|
||||
# ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
|
||||
# - When a new seq_idx is detected, we will stop passing the prev_state
|
||||
# and switch accordingly to the init_state corresponding to the new seq_idx.
|
||||
# - We will also make sure that the dA_cumsum is taken only from the start of the
|
||||
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
|
||||
# - this will ensure that states will be updated with the rightmost flushed seq_idx
|
||||
# of the previous chunk. This implies that the first chunk of states is either 0
|
||||
# or equal to init_states of the first example.
|
||||
states, final_states = _state_passing_fwd(
|
||||
rearrange(states, "... p n -> ... (p n)"),
|
||||
dA_cumsum,
|
||||
initial_states=(
|
||||
rearrange(initial_states, "... p n -> ... (p n)")
|
||||
if initial_states is not None
|
||||
else None
|
||||
),
|
||||
seq_idx=seq_idx,
|
||||
chunk_size=chunk_size,
|
||||
out_dtype=state_dtype if state_dtype is not None else C.dtype,
|
||||
is_cont_batched=cu_seqlens is not None,
|
||||
chunk_offsets=chunk_offsets,
|
||||
)
|
||||
states, final_states = (
|
||||
rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]
|
||||
)
|
||||
|
||||
# 4. Compute batched matrix multiply for C_j^T B_i terms
|
||||
CB = _bmm_chunk_fwd(C, B, chunk_size, seq_idx=seq_idx, output_dtype=torch.float32)
|
||||
|
||||
# 5. Scan and compute the diagonal blocks, taking into
|
||||
# account past causal states.
|
||||
# - if initial states are provided, then states information will be
|
||||
# augmented with initial_states.
|
||||
# - to do this properly, we need to account for example changes in
|
||||
# the continuous batch, therefore we introduce pseudo chunks, which is
|
||||
# a chunk that is split up each time an example changes.
|
||||
# - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had
|
||||
# a seq_idx change, in which case we take states information from
|
||||
# init_states.
|
||||
out_x = _chunk_scan_fwd(
|
||||
CB,
|
||||
x,
|
||||
dt,
|
||||
dA_cumsum,
|
||||
C,
|
||||
states,
|
||||
D=D,
|
||||
z=z,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
initial_states=initial_states,
|
||||
out=out,
|
||||
)
|
||||
if cu_seqlens is None:
|
||||
return out_x, dt, dA_cumsum, states, final_states
|
||||
else:
|
||||
assert (
|
||||
batch == 1
|
||||
), "passing cu_seqlens to get the varlen states is only supported if batch dimension is 1"
|
||||
varlen_states = chunk_state_varlen(
|
||||
B.squeeze(0),
|
||||
x.squeeze(0),
|
||||
dt.squeeze(0),
|
||||
dA_cumsum.squeeze(0),
|
||||
cu_seqlens,
|
||||
states.squeeze(0),
|
||||
initial_states=initial_states,
|
||||
)
|
||||
return out_x, dt, dA_cumsum, states, final_states, varlen_states
|
||||
|
||||
|
||||
def mamba_chunk_scan_combined(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=None,
|
||||
z=None,
|
||||
dt_bias=None,
|
||||
initial_states=None,
|
||||
seq_idx=None,
|
||||
chunk_indices=None,
|
||||
chunk_offsets=None,
|
||||
cu_seqlens=None,
|
||||
dt_softplus=False,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=None,
|
||||
return_final_states=False,
|
||||
return_varlen_states=False,
|
||||
state_dtype=None,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
x: (batch, seqlen, nheads, headdim)
|
||||
dt: (batch, seqlen, nheads)
|
||||
A: (nheads)
|
||||
B: (batch, seqlen, ngroups, dstate)
|
||||
C: (batch, seqlen, ngroups, dstate)
|
||||
chunk_size: int
|
||||
D: (nheads, headdim) or (nheads,)
|
||||
z: (batch, seqlen, nheads, headdim)
|
||||
dt_bias: (nheads,)
|
||||
initial_states: (batch, nheads, headdim, dstate)
|
||||
seq_idx: (batch, seqlen)
|
||||
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
|
||||
dt_softplus: Whether to apply softplus to dt
|
||||
out: Preallocated output tensor
|
||||
state_dtype: The data type of the ssm state
|
||||
"""
|
||||
|
||||
if not return_varlen_states:
|
||||
cu_seqlens = None
|
||||
else:
|
||||
assert (
|
||||
cu_seqlens is not None
|
||||
), "cu_seqlens must be provided if return_varlen_states is True"
|
||||
out_x, dt_out, dA_cumsum, states, final_states, *rest = (
|
||||
_mamba_chunk_scan_combined_fwd(
|
||||
x,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
chunk_size,
|
||||
D=D,
|
||||
z=z,
|
||||
dt_bias=dt_bias,
|
||||
initial_states=initial_states,
|
||||
seq_idx=seq_idx,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
cu_seqlens=cu_seqlens,
|
||||
dt_softplus=dt_softplus,
|
||||
dt_limit=dt_limit,
|
||||
out=out,
|
||||
state_dtype=state_dtype,
|
||||
)
|
||||
)
|
||||
if not return_varlen_states:
|
||||
if not return_final_states:
|
||||
return
|
||||
else:
|
||||
return final_states
|
||||
else:
|
||||
varlen_states = rest[0]
|
||||
return (
|
||||
(varlen_states)
|
||||
if not return_final_states
|
||||
else (final_states, varlen_states)
|
||||
)
|
||||
@@ -0,0 +1,275 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/mamba/ops/ssd_state_passing.py
|
||||
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py
|
||||
|
||||
# ruff: noqa: E501
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_SIZE": 64}),
|
||||
# triton.Config({"BLOCK_SIZE": 128}),
|
||||
# triton.Config({"BLOCK_SIZE": 256}),
|
||||
# triton.Config({"BLOCK_SIZE": 512}),
|
||||
# triton.Config({"BLOCK_SIZE": 1024}),
|
||||
# triton.Config({"BLOCK_SIZE": 2048}),
|
||||
# ],
|
||||
# key=["dim"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
states_ptr,
|
||||
out_ptr,
|
||||
final_states_ptr,
|
||||
dA_cs_ptr,
|
||||
initstates_ptr,
|
||||
seq_idx_ptr,
|
||||
chunk_offsets_ptr,
|
||||
chunk_meta_num,
|
||||
# Matrix dimensions
|
||||
dim,
|
||||
nchunks,
|
||||
seqlen,
|
||||
chunk_size,
|
||||
# Strides
|
||||
stride_states_batch,
|
||||
stride_states_chunk,
|
||||
stride_states_head,
|
||||
stride_states_dim,
|
||||
stride_out_batch,
|
||||
stride_out_chunk,
|
||||
stride_out_head,
|
||||
stride_out_dim,
|
||||
stride_final_states_batch,
|
||||
stride_final_states_head,
|
||||
stride_final_states_dim,
|
||||
stride_dA_cs_batch,
|
||||
stride_dA_cs_chunk,
|
||||
stride_dA_cs_head,
|
||||
stride_dA_cs_csize,
|
||||
stride_initstates_batch,
|
||||
stride_initstates_head,
|
||||
stride_initstates_dim,
|
||||
stride_seq_idx_batch,
|
||||
stride_seq_idx_seqlen,
|
||||
# Meta-parameters
|
||||
HAS_INITSTATES: tl.constexpr,
|
||||
HAS_SEQ_IDX: tl.constexpr,
|
||||
IS_CONT_BATCHED: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr = 16,
|
||||
):
|
||||
pid_b = tl.program_id(axis=1)
|
||||
pid_h = tl.program_id(axis=2)
|
||||
pid_m = tl.program_id(axis=0)
|
||||
states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
|
||||
dA_cs_ptr += (
|
||||
pid_b * stride_dA_cs_batch
|
||||
+ pid_h * stride_dA_cs_head
|
||||
+ (chunk_size - 1) * stride_dA_cs_csize
|
||||
)
|
||||
out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
|
||||
final_states_ptr += (
|
||||
pid_b * stride_final_states_batch + pid_h * stride_final_states_head
|
||||
)
|
||||
if HAS_INITSTATES:
|
||||
initstates_ptr += pid_h * stride_initstates_head
|
||||
if not IS_CONT_BATCHED:
|
||||
initstates_ptr += pid_b * stride_initstates_batch
|
||||
|
||||
if HAS_SEQ_IDX:
|
||||
seq_idx_ptr += pid_b * stride_seq_idx_batch
|
||||
|
||||
offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
states_ptrs = states_ptr + offs_m * stride_states_dim
|
||||
out_ptrs = out_ptr + offs_m * stride_out_dim
|
||||
final_states_ptrs = final_states_ptr + offs_m * stride_final_states_dim
|
||||
|
||||
# - states will be the past state of the sequence that continues on the current check
|
||||
if not HAS_INITSTATES:
|
||||
states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)
|
||||
else:
|
||||
initstates_ptr += offs_m * stride_initstates_dim
|
||||
initstates_ptrs = initstates_ptr
|
||||
# - for cont batches, for the first chunk mean it will be the first batch's
|
||||
# init state
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
out_ptrs += stride_out_chunk
|
||||
prev_seq_idx_chunk_end = 0
|
||||
logical_chunk_idx = 0
|
||||
for c in range(nchunks):
|
||||
new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
|
||||
dA_cs = tl.load(dA_cs_ptr).to(tl.float32)
|
||||
scale_mask = True
|
||||
if HAS_SEQ_IDX:
|
||||
# - the seq to pass forward is the one that is flushed to the right
|
||||
# boundary.
|
||||
# - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk.
|
||||
seq_idx_chunk_end = tl.load(
|
||||
seq_idx_ptr
|
||||
+ (min((c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen
|
||||
)
|
||||
if HAS_INITSTATES:
|
||||
if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end:
|
||||
# this means in the current chunk the rightmost flushed seq
|
||||
# has changed.
|
||||
# - so we do not propagate the state from previous chunk
|
||||
# - but rather we load that sequence's init state
|
||||
initstates_ptrs = (
|
||||
initstates_ptr + seq_idx_chunk_end * stride_initstates_batch
|
||||
)
|
||||
|
||||
# - update state with seq_idx_new's init state
|
||||
states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(
|
||||
tl.float32
|
||||
)
|
||||
|
||||
# - we need to consider the cumsum only of the last sequence in the chunk
|
||||
# - find its starting position (given by c_off of the logical chunk index)
|
||||
# - and subtract the cumsum just before that position from the total cumsum
|
||||
# - first, update the logical chunk index (add the number of sequences in the current physical chunk):
|
||||
# sequence index at the start of the current chunk
|
||||
seq_idx_chunk_start = tl.load(
|
||||
seq_idx_ptr
|
||||
+ min(c * chunk_size, seqlen) * stride_seq_idx_seqlen
|
||||
)
|
||||
logical_chunk_idx += seq_idx_chunk_end - seq_idx_chunk_start
|
||||
# - load the chunk offset:
|
||||
c_off = tl.load(
|
||||
chunk_offsets_ptr + logical_chunk_idx,
|
||||
mask=logical_chunk_idx < chunk_meta_num,
|
||||
other=0,
|
||||
)
|
||||
# - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything
|
||||
if c_off > 0:
|
||||
# - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset
|
||||
dA_cs_boundary = tl.load(
|
||||
dA_cs_ptr
|
||||
- (chunk_size - 1) * stride_dA_cs_csize
|
||||
+ (c_off - 1) * stride_dA_cs_csize,
|
||||
mask=(c_off - 1) > -1 and c_off < chunk_size,
|
||||
other=0.0,
|
||||
)
|
||||
dA_cs -= dA_cs_boundary
|
||||
|
||||
# - increment logical chunk index for every physical chunk
|
||||
logical_chunk_idx += 1
|
||||
else:
|
||||
scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end
|
||||
prev_seq_idx_chunk_end = seq_idx_chunk_end
|
||||
|
||||
scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0)
|
||||
states = scale * states + new_states
|
||||
if c < nchunks - 1:
|
||||
tl.store(out_ptrs, states, mask=offs_m < dim)
|
||||
else:
|
||||
tl.store(final_states_ptrs, states, mask=offs_m < dim)
|
||||
states_ptrs += stride_states_chunk
|
||||
dA_cs_ptr += stride_dA_cs_chunk
|
||||
out_ptrs += stride_out_chunk
|
||||
|
||||
|
||||
def _state_passing_fwd(
|
||||
states,
|
||||
dA_cumsum,
|
||||
initial_states=None,
|
||||
seq_idx=None,
|
||||
chunk_size=None,
|
||||
out_dtype=None,
|
||||
is_cont_batched=False,
|
||||
chunk_offsets=None,
|
||||
):
|
||||
batch, nchunks, nheads, dim = states.shape
|
||||
if chunk_size is None:
|
||||
chunk_size = dA_cumsum.shape[-1]
|
||||
else:
|
||||
assert chunk_size == dA_cumsum.shape[-1]
|
||||
assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
|
||||
if initial_states is not None:
|
||||
if is_cont_batched:
|
||||
# - if cu_seqlens is provided, then the initial states
|
||||
# are used for continuous batching. In which case we
|
||||
# require seq_idx to be provided
|
||||
assert (
|
||||
seq_idx is not None
|
||||
), "seq_idx must be provided for continuous batching"
|
||||
# - we also need chunk_offsets to be provided, to account
|
||||
# for computation of dA_cumsum from the start of the
|
||||
# sequence
|
||||
assert (
|
||||
chunk_offsets is not None
|
||||
), "chunk_offsets must be provided for continuous batching"
|
||||
else:
|
||||
# - this is the regular batching case, where initial
|
||||
# states are used are for each example of the batch.
|
||||
assert initial_states.shape == (batch, nheads, dim)
|
||||
|
||||
if seq_idx is not None:
|
||||
seqlen = seq_idx.shape[-1]
|
||||
assert seq_idx.shape == (batch, seqlen)
|
||||
out_dtype = states.dtype if out_dtype is None else out_dtype
|
||||
out = torch.empty(
|
||||
(batch, nchunks, nheads, dim), device=states.device, dtype=out_dtype
|
||||
)
|
||||
final_states = torch.empty(
|
||||
(batch, nheads, dim), device=states.device, dtype=torch.float32
|
||||
)
|
||||
grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), batch, nheads)
|
||||
with torch.cuda.device(states.device.index):
|
||||
_state_passing_fwd_kernel[grid](
|
||||
states,
|
||||
out,
|
||||
final_states,
|
||||
dA_cumsum,
|
||||
initial_states,
|
||||
seq_idx,
|
||||
chunk_offsets,
|
||||
len(chunk_offsets) if chunk_offsets is not None else 0,
|
||||
dim,
|
||||
nchunks,
|
||||
seqlen if seq_idx is not None else 0,
|
||||
chunk_size,
|
||||
states.stride(0),
|
||||
states.stride(1),
|
||||
states.stride(2),
|
||||
states.stride(3),
|
||||
out.stride(0),
|
||||
out.stride(1),
|
||||
out.stride(2),
|
||||
out.stride(3),
|
||||
final_states.stride(0),
|
||||
final_states.stride(1),
|
||||
final_states.stride(2),
|
||||
dA_cumsum.stride(0),
|
||||
dA_cumsum.stride(2),
|
||||
dA_cumsum.stride(1),
|
||||
dA_cumsum.stride(3),
|
||||
*(
|
||||
(
|
||||
initial_states.stride(0),
|
||||
initial_states.stride(1),
|
||||
initial_states.stride(2),
|
||||
)
|
||||
if initial_states is not None
|
||||
else (0, 0, 0)
|
||||
),
|
||||
*(
|
||||
(seq_idx.stride(0), seq_idx.stride(1))
|
||||
if seq_idx is not None
|
||||
else (0, 0)
|
||||
),
|
||||
HAS_INITSTATES=initial_states is not None,
|
||||
HAS_SEQ_IDX=seq_idx is not None,
|
||||
IS_CONT_BATCHED=is_cont_batched,
|
||||
)
|
||||
return out, final_states
|
||||
@@ -1297,6 +1297,7 @@ class ModelRunner:
|
||||
return self.model_config.hf_config.architectures[0] in [
|
||||
"Qwen3NextForCausalLM",
|
||||
"Qwen3NextForCausalLMMTP",
|
||||
"FalconH1ForCausalLM",
|
||||
]
|
||||
|
||||
def set_num_token_hybrid(self):
|
||||
|
||||
576
python/sglang/srt/models/falcon_h1.py
Normal file
576
python/sglang/srt/models/falcon_h1.py
Normal file
@@ -0,0 +1,576 @@
|
||||
import enum
|
||||
import logging
|
||||
from typing import Any, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from sglang.srt.configs.falcon_h1 import FalconH1Config
|
||||
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
|
||||
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
is_dp_attention_enabled,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, is_cuda, make_layers
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
|
||||
class FalconH1MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
layer_id: int,
|
||||
mlp_multipliers: List[float],
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
reduce_results: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size,
|
||||
[intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("gate_up_proj", prefix),
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("down_proj", prefix),
|
||||
reduce_results=reduce_results,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.gate_multiplier, self.down_multiplier = mlp_multipliers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
forward_batch=None,
|
||||
use_reduce_scatter: bool = False,
|
||||
):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
|
||||
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(
|
||||
x,
|
||||
skip_all_reduce=use_reduce_scatter,
|
||||
)
|
||||
x = x * self.down_multiplier
|
||||
return x
|
||||
|
||||
|
||||
class FalconH1HybridAttentionDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconH1Config,
|
||||
layer_id: int,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
alt_stream: Optional[torch.cuda.Stream] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = config.num_attention_heads
|
||||
assert self.total_num_heads % self.attn_tp_size == 0
|
||||
self.num_heads = self.total_num_heads // self.attn_tp_size
|
||||
self.total_num_kv_heads = config.num_key_value_heads
|
||||
if self.total_num_kv_heads >= self.attn_tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % self.attn_tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.attn_tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
|
||||
self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
self.rope_scaling = getattr(config, "rope_scaling", None)
|
||||
self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
rope_scaling=self.rope_scaling,
|
||||
base=self.rope_theta,
|
||||
partial_rotary_factor=self.partial_rotary_factor,
|
||||
is_neox_style=True,
|
||||
dtype=torch.get_default_dtype(), # see impl of get_rope
|
||||
)
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
config.hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=self.attn_tp_rank,
|
||||
tp_size=self.attn_tp_size,
|
||||
)
|
||||
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=False,
|
||||
tp_rank=self.attn_tp_rank,
|
||||
tp_size=self.attn_tp_size,
|
||||
)
|
||||
|
||||
self.attn = RadixAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
layer_id=layer_id,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
|
||||
self.d_ssm = (
|
||||
int(config.mamba_expand * config.hidden_size)
|
||||
if config.mamba_d_ssm is None
|
||||
else config.mamba_d_ssm
|
||||
)
|
||||
|
||||
self.mamba = MambaMixer2(
|
||||
hidden_size=config.hidden_size,
|
||||
ssm_state_size=config.mamba_d_state,
|
||||
conv_kernel_size=config.mamba_d_conv,
|
||||
intermediate_size=self.d_ssm,
|
||||
use_conv_bias=config.mamba_conv_bias,
|
||||
use_bias=config.mamba_proj_bias,
|
||||
n_groups=config.mamba_n_groups,
|
||||
num_heads=config.mamba_n_heads,
|
||||
layer_id=layer_id,
|
||||
head_dim=config.mamba_d_head,
|
||||
rms_norm_eps=config.rms_norm_eps,
|
||||
chunk_size=config.mamba_chunk_size,
|
||||
activation=config.hidden_act,
|
||||
use_rms_norm=config.mamba_rms_norm,
|
||||
prefix=f"{prefix}.mixer",
|
||||
)
|
||||
|
||||
# FalconH1 all layers are sparse and have no nextn now
|
||||
self.is_layer_sparse = False
|
||||
is_previous_layer_sparse = False
|
||||
|
||||
self.layer_scatter_modes = LayerScatterModes.init_new(
|
||||
layer_id=layer_id,
|
||||
num_layers=config.num_hidden_layers,
|
||||
is_layer_sparse=self.is_layer_sparse,
|
||||
is_previous_layer_sparse=is_previous_layer_sparse,
|
||||
)
|
||||
|
||||
self.feed_forward = FalconH1MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
layer_id=layer_id,
|
||||
mlp_multipliers=config.mlp_multipliers,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("mlp", prefix),
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
|
||||
|
||||
self.layer_communicator = LayerCommunicator(
|
||||
layer_scatter_modes=self.layer_scatter_modes,
|
||||
input_layernorm=self.input_layernorm,
|
||||
post_attention_layernorm=self.pre_ff_layernorm,
|
||||
allow_reduce_scatter=True,
|
||||
)
|
||||
|
||||
self.alt_stream = alt_stream
|
||||
self.key_multiplier = config.key_multiplier
|
||||
|
||||
self.ssm_out_multiplier = config.ssm_out_multiplier
|
||||
self.ssm_in_multiplier = config.ssm_in_multiplier
|
||||
|
||||
self.attention_in_multiplier = config.attention_in_multiplier
|
||||
self.attn_out_multiplier = config.attention_out_multiplier
|
||||
|
||||
self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
|
||||
self.zxbcdt_multipliers = config.ssm_multipliers
|
||||
self._init_mup_vector()
|
||||
|
||||
def _init_mup_vector(self):
|
||||
"""
|
||||
Non learnable per-block scaling vector composed of element-wise
|
||||
multipliersapplied to each separate contiguous block of the output
|
||||
of the linear projection (in_proj) before further processing
|
||||
(gating, convolution, SSM):
|
||||
|
||||
- Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
|
||||
- X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
|
||||
- B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
|
||||
- C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
|
||||
→ zxbcdt_multipliers[3]
|
||||
- dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
|
||||
|
||||
where:
|
||||
- d_ssm: Dimension of state-space model latent
|
||||
- G: Number of groups (n_groups)
|
||||
- S: SSM state size per group
|
||||
- All indices are divided by tp_size to support tensor parallelism
|
||||
"""
|
||||
vector_shape = (
|
||||
2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
|
||||
) // self.tp_size
|
||||
mup_vector = torch.ones(1, vector_shape)
|
||||
# Z vector 0 -> d_ssm
|
||||
mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
|
||||
# X vector d_ssm -> 2 * d_ssm
|
||||
mup_vector[
|
||||
:, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
|
||||
] *= self.zxbcdt_multipliers[1]
|
||||
# B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
|
||||
mup_vector[
|
||||
:,
|
||||
(2 * self.d_ssm)
|
||||
// self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
|
||||
// self.tp_size,
|
||||
] *= self.zxbcdt_multipliers[2]
|
||||
# C vector 2 * d_ssm + (n_group * d_state)
|
||||
# -> 2 * d_ssm + 2 * (n_group * d_state)
|
||||
mup_vector[
|
||||
:,
|
||||
(2 * self.d_ssm + self.groups_time_state_size)
|
||||
// self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
|
||||
// self.tp_size,
|
||||
] *= self.zxbcdt_multipliers[3]
|
||||
# dt vector 2 * d_ssm + 2 * (n_group * d_state)
|
||||
# -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
|
||||
mup_vector[
|
||||
:,
|
||||
(2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
|
||||
] *= self.zxbcdt_multipliers[4]
|
||||
|
||||
self.register_buffer("mup_vector", mup_vector, persistent=False)
|
||||
|
||||
def self_attention(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
k = k * self.key_multiplier
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
forward_batch: ForwardBatch,
|
||||
**kwargs: Any,
|
||||
):
|
||||
hidden_states, residual = self.layer_communicator.prepare_attn(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
# Attention block
|
||||
attention_hidden_states = self.self_attention(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states * self.attention_in_multiplier,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
|
||||
|
||||
# Mamba block
|
||||
mamba_hidden_states = torch.empty_like(hidden_states)
|
||||
self.mamba(
|
||||
hidden_states * self.ssm_in_multiplier,
|
||||
mamba_hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
mup_vector=self.mup_vector,
|
||||
)
|
||||
mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
|
||||
|
||||
hidden_states = attention_hidden_states + mamba_hidden_states
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
|
||||
forward_batch
|
||||
)
|
||||
hidden_states = self.feed_forward(
|
||||
hidden_states, forward_batch, use_reduce_scatter
|
||||
)
|
||||
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
ALL_DECODER_LAYER_TYPES = {
|
||||
"falcon_h1": FalconH1HybridAttentionDecoderLayer,
|
||||
}
|
||||
|
||||
|
||||
class FalconH1Model(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconH1Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
alt_stream = torch.cuda.Stream() if _is_cuda else None
|
||||
self.embedding_multiplier = config.embedding_multiplier
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
enable_tp=not is_dp_attention_enabled(),
|
||||
)
|
||||
|
||||
def get_layer(idx: int, prefix: str):
|
||||
layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
|
||||
return layer_class(
|
||||
config,
|
||||
idx,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix,
|
||||
alt_stream=alt_stream,
|
||||
)
|
||||
|
||||
self.layers = make_layers(
|
||||
config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
|
||||
)
|
||||
|
||||
self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.infer_count = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
# mamba_cache_params: MambaCacheParams,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
# pass a sequence index tensor, that is required for
|
||||
# proper continuous batching computation including
|
||||
# chunked prefill
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds * self.embedding_multiplier
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
|
||||
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(
|
||||
layer_id=i,
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
residual=residual,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if not forward_batch.forward_mode.is_idle():
|
||||
if residual is None:
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, _ = self.final_layernorm(hidden_states, residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HybridLayerType(enum.Enum):
|
||||
full_attention = "attention"
|
||||
swa_attention = "swa_attention"
|
||||
linear_attention = "linear_attention"
|
||||
mamba2 = "mamba"
|
||||
|
||||
|
||||
class FalconH1ForCausalLM(nn.Module):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: FalconH1Config,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.pp_group = get_pp_group()
|
||||
assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
|
||||
self.quant_config = quant_config
|
||||
self.model = FalconH1Model(
|
||||
config, quant_config, prefix=add_prefix("model", prefix)
|
||||
)
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
quant_config=quant_config,
|
||||
org_num_embeddings=config.vocab_size,
|
||||
prefix=add_prefix("lm_head", prefix),
|
||||
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
|
||||
)
|
||||
self.lm_head = self.lm_head.float()
|
||||
self.lm_head_multiplier = config.lm_head_multiplier
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config, logit_scale=self.lm_head_multiplier
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
):
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
|
||||
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
)
|
||||
|
||||
def get_embed_and_head(self):
|
||||
return self.model.embed_tokens.weight, self.lm_head.weight
|
||||
|
||||
def set_embed_and_head(self, embed, head):
|
||||
del self.model.embed_tokens.weight
|
||||
del self.lm_head.weight
|
||||
self.model.embed_tokens.weight = embed
|
||||
self.lm_head.weight = head
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load_weights(
|
||||
self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
|
||||
) -> Set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
if ".self_attn." in name:
|
||||
name = name.replace(".self_attn", "")
|
||||
|
||||
if "A_log" in name:
|
||||
name = name.replace("A_log", "A")
|
||||
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Skip layers on other devices.
|
||||
# if is_pp_missing_parameter(name, self):
|
||||
# continue
|
||||
if name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader")
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# if is_pp_missing_parameter(name, self):
|
||||
# continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
EntryClass = FalconH1ForCausalLM
|
||||
Reference in New Issue
Block a user