model: Support Hybrid Mamba2 NemotronHForCausalLM (nvidia/NVIDIA-Nemotron-Nano-9B-v2) (#10909)
Signed-off-by: Netanel Haber <nhaber@nvidia.com>
This commit is contained in:
@@ -1,7 +1,14 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# evade circular imports
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
ATTENTION_BACKENDS = {}
|
||||
|
||||
|
||||
@@ -166,36 +173,41 @@ def create_dual_chunk_flash_attn_backend(runner):
|
||||
return DualChunkFlashAttentionBackend(runner)
|
||||
|
||||
|
||||
def attn_backend_wrapper(runner, full_attn_backend):
|
||||
def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
|
||||
"""
|
||||
Wrapper for special models like hybrid GDN, so we don't
|
||||
need to change the code of the original attention backend.
|
||||
"""
|
||||
assert not (
|
||||
runner.is_hybrid_gdn and runner.use_mla_backend
|
||||
runner.hybrid_gdn_config is not None and runner.use_mla_backend
|
||||
), "hybrid_gdn can only be used with non-MLA models."
|
||||
|
||||
# wrap for hybrid GDN models
|
||||
if runner.is_hybrid_gdn:
|
||||
if cfg := runner.mambaish_config:
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
GDNAttnBackend,
|
||||
HybridLinearAttnBackend,
|
||||
Mamba2AttnBackend,
|
||||
)
|
||||
from sglang.srt.utils import is_blackwell, is_npu
|
||||
|
||||
if is_blackwell():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "triton"
|
||||
or runner.server_args.attention_backend == "trtllm_mha"
|
||||
), "triton or trtllm_mha backend are the only supported backends on Blackwell GPUs for hybrid GDN models, use --attention-backend triton or --attention-backend trtllm_mha to specify the backend."
|
||||
if is_npu():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "ascend"
|
||||
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
|
||||
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
|
||||
from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
|
||||
HybridLinearAttnBackend,
|
||||
MambaAttnBackend,
|
||||
)
|
||||
|
||||
linear_attn_backend = MambaAttnBackend(runner)
|
||||
full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
|
||||
if runner.hybrid_gdn_config is not None:
|
||||
if is_blackwell():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "triton"
|
||||
), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
|
||||
if is_npu():
|
||||
assert (
|
||||
runner.server_args.attention_backend == "ascend"
|
||||
), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
|
||||
logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
|
||||
linear_attn_backend = GDNAttnBackend(runner)
|
||||
elif runner.mamba2_config is not None:
|
||||
linear_attn_backend = Mamba2AttnBackend(runner)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Expected hybrid GDN or NemotronH models, but got unknown model."
|
||||
)
|
||||
full_attn_layers = cfg.full_attention_layer_ids
|
||||
return HybridLinearAttnBackend(
|
||||
full_attn_backend, linear_attn_backend, full_attn_layers
|
||||
)
|
||||
|
||||
@@ -181,6 +181,45 @@ def _layer_norm_fwd(
|
||||
return out, mean, rstd
|
||||
|
||||
|
||||
def rms_norm_gated(
|
||||
*,
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
z=None,
|
||||
eps=1e-6,
|
||||
group_size=None,
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
class LayerNormFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
@@ -195,32 +234,16 @@ class LayerNormFn(torch.autograd.Function):
|
||||
norm_before_gate=True,
|
||||
is_rms_norm=False,
|
||||
):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
|
||||
x_shape_og = x.shape
|
||||
# reshape input data into 2D tensor
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if x.stride(-1) != 1:
|
||||
x = x.contiguous()
|
||||
if z is not None:
|
||||
assert z.shape == x_shape_og
|
||||
z = z.reshape(-1, z.shape[-1])
|
||||
if z.stride(-1) != 1:
|
||||
z = z.contiguous()
|
||||
weight = weight.contiguous()
|
||||
if bias is not None:
|
||||
bias = bias.contiguous()
|
||||
y, mean, rstd = _layer_norm_fwd(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
return rms_norm_gated(
|
||||
x=x,
|
||||
weight=weight,
|
||||
bias=bias,
|
||||
eps=eps,
|
||||
z=z,
|
||||
group_size=group_size,
|
||||
norm_before_gate=norm_before_gate,
|
||||
is_rms_norm=is_rms_norm,
|
||||
)
|
||||
return y.reshape(x_shape_og)
|
||||
|
||||
|
||||
def layernorm_fn(
|
||||
@@ -238,14 +261,6 @@ def layernorm_fn(
|
||||
)
|
||||
|
||||
|
||||
def rmsnorm_fn(
|
||||
x, weight, bias, z=None, eps=1e-6, group_size=None, norm_before_gate=True
|
||||
):
|
||||
return LayerNormFn.apply(
|
||||
x, weight, bias, z, eps, group_size, norm_before_gate, True
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -284,6 +299,7 @@ class LayerNorm(torch.nn.Module):
|
||||
group_size=self.group_size,
|
||||
eps=self.eps,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
is_rms_norm=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -315,7 +331,7 @@ class RMSNorm(torch.nn.Module):
|
||||
|
||||
def forward(self, x, z=None):
|
||||
"""If z is not None, we do norm(x) * silu(z) if norm_before_gate, else norm(x * silu(z))"""
|
||||
return rmsnorm_fn(
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight,
|
||||
self.bias,
|
||||
@@ -323,4 +339,5 @@ class RMSNorm(torch.nn.Module):
|
||||
eps=self.eps,
|
||||
group_size=self.group_size,
|
||||
norm_before_gate=self.norm_before_gate,
|
||||
is_rms_norm=True,
|
||||
)
|
||||
|
||||
@@ -14,14 +14,21 @@ from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import (
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
||||
PAD_SLOT_ID,
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
|
||||
from sglang.srt.layers.attention.mamba.mamba2_metadata import (
|
||||
ForwardMetadata,
|
||||
Mamba2Metadata,
|
||||
)
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, MambaPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.models.qwen3_next import fused_gdn_gating
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import is_cuda, is_npu
|
||||
|
||||
@@ -47,18 +54,10 @@ elif is_npu():
|
||||
causal_conv1d_update = causal_conv1d_update_npu
|
||||
|
||||
|
||||
@dataclass
|
||||
class ForwardMetadata:
|
||||
query_start_loc: Optional[torch.Tensor]
|
||||
mamba_cache_indices: torch.Tensor
|
||||
|
||||
|
||||
class MambaAttnBackend(AttentionBackend):
|
||||
"""Attention backend using Mamba kernel."""
|
||||
|
||||
class MambaAttnBackendBase(AttentionBackend):
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__()
|
||||
self.pad_slot_id = -1 # Default pad slot id
|
||||
self.pad_slot_id = PAD_SLOT_ID
|
||||
self.device = model_runner.device
|
||||
self.req_to_token_pool: HybridReqToTokenPool = model_runner.req_to_token_pool
|
||||
self.forward_metadata: ForwardMetadata = None
|
||||
@@ -67,7 +66,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
|
||||
self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
def _forward_metadata(self, forward_batch: ForwardBatch):
|
||||
bs = forward_batch.batch_size
|
||||
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
@@ -97,11 +96,43 @@ class MambaAttnBackend(AttentionBackend):
|
||||
mamba_cache_indices = self.req_to_token_pool.get_mamba_indices(
|
||||
forward_batch.req_pool_indices
|
||||
)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
return ForwardMetadata(
|
||||
query_start_loc=query_start_loc,
|
||||
mamba_cache_indices=mamba_cache_indices,
|
||||
)
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
self.forward_metadata = self._forward_metadata(forward_batch)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
self.forward_metadata = self._capture_metadata(
|
||||
bs, req_pool_indices, forward_mode
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
self.forward_metadata = self._replay_metadata(
|
||||
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
assert (
|
||||
max_num_tokens % max_bs == 0
|
||||
@@ -127,15 +158,8 @@ class MambaAttnBackend(AttentionBackend):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInput],
|
||||
def _capture_metadata(
|
||||
self, bs: int, req_pool_indices: torch.Tensor, forward_mode: ForwardMode
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
@@ -149,18 +173,15 @@ class MambaAttnBackend(AttentionBackend):
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
mamba_indices = self.req_to_token_pool.get_mamba_indices(req_pool_indices)
|
||||
self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
return ForwardMetadata(
|
||||
query_start_loc=self.query_start_loc_list[bs - 1],
|
||||
mamba_cache_indices=self.state_indices_list[bs - 1],
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
def _replay_metadata(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
@@ -200,7 +221,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
else:
|
||||
raise ValueError(f"Invalid forward mode: {forward_mode=}")
|
||||
|
||||
self.forward_metadata = ForwardMetadata(
|
||||
return ForwardMetadata(
|
||||
query_start_loc=self.query_start_loc_list[bs - 1],
|
||||
mamba_cache_indices=self.state_indices_list[bs - 1],
|
||||
)
|
||||
@@ -208,6 +229,10 @@ class MambaAttnBackend(AttentionBackend):
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return 1 # Mamba attn does not use seq lens to index kv cache
|
||||
|
||||
|
||||
class GDNAttnBackend(MambaAttnBackendBase):
|
||||
"""Attention backend using Mamba kernel."""
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
@@ -233,9 +258,9 @@ class MambaAttnBackend(AttentionBackend):
|
||||
dt_bias = kwargs["dt_bias"]
|
||||
layer_id = kwargs["layer_id"]
|
||||
|
||||
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
|
||||
layer_id
|
||||
)
|
||||
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
||||
conv_states = layer_cache.conv
|
||||
ssm_states = layer_cache.temporal
|
||||
query_start_loc = self.forward_metadata.query_start_loc
|
||||
cache_indices = self.forward_metadata.mamba_cache_indices
|
||||
|
||||
@@ -313,13 +338,13 @@ class MambaAttnBackend(AttentionBackend):
|
||||
query_start_loc = self.forward_metadata.query_start_loc
|
||||
cache_indices = self.forward_metadata.mamba_cache_indices
|
||||
|
||||
mamba_cache_params = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
||||
conv_states = mamba_cache_params.conv
|
||||
ssm_states = mamba_cache_params.temporal
|
||||
if is_target_verify:
|
||||
(
|
||||
conv_states,
|
||||
ssm_states,
|
||||
intermediate_state_cache,
|
||||
intermediate_conv_window_cache,
|
||||
) = self.req_to_token_pool.get_mamba_params(layer_id)
|
||||
assert isinstance(mamba_cache_params, MambaPool.SpeculativeState)
|
||||
intermediate_state_cache = mamba_cache_params.intermediate_ssm
|
||||
intermediate_conv_window_cache = mamba_cache_params.intermediate_conv_window
|
||||
has_initial_states = torch.ones(
|
||||
seq_len // forward_batch.spec_info.draft_token_num,
|
||||
dtype=torch.bool,
|
||||
@@ -327,9 +352,6 @@ class MambaAttnBackend(AttentionBackend):
|
||||
)
|
||||
conv_states_to_use = conv_states.clone()
|
||||
else:
|
||||
conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params(
|
||||
layer_id
|
||||
)
|
||||
has_initial_states = forward_batch.extend_prefix_lens > 0
|
||||
conv_states_to_use = conv_states
|
||||
|
||||
@@ -424,16 +446,100 @@ class MambaAttnBackend(AttentionBackend):
|
||||
return core_attn_out
|
||||
|
||||
|
||||
class Mamba2AttnBackend(MambaAttnBackendBase):
|
||||
"""Attention backend wrapper for Mamba2Mixer kernels."""
|
||||
|
||||
def __init__(self, model_runner: ModelRunner):
|
||||
super().__init__(model_runner)
|
||||
config = model_runner.mamba2_config
|
||||
assert config is not None
|
||||
self.mamba_chunk_size = config.mamba_chunk_size
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
metadata = self._forward_metadata(forward_batch)
|
||||
self.forward_metadata = Mamba2Metadata.prepare_mixed(
|
||||
metadata.query_start_loc,
|
||||
metadata.mamba_cache_indices,
|
||||
self.mamba_chunk_size,
|
||||
forward_batch,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
num_tokens: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
metadata = self._capture_metadata(bs, req_pool_indices, forward_mode)
|
||||
self.forward_metadata = Mamba2Metadata.prepare_decode(
|
||||
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
bs: int,
|
||||
req_pool_indices: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
metadata = self._replay_metadata(
|
||||
bs, req_pool_indices, forward_mode, spec_info, seq_lens_cpu
|
||||
)
|
||||
self.forward_metadata = Mamba2Metadata.prepare_decode(
|
||||
metadata.query_start_loc, metadata.mamba_cache_indices, seq_lens
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
mixer: MambaMixer2,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
layer_id: int,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
use_triton_causal_conv: bool = False,
|
||||
):
|
||||
assert isinstance(self.forward_metadata, Mamba2Metadata)
|
||||
layer_cache = self.req_to_token_pool.mamba2_layer_cache(layer_id)
|
||||
return mixer.forward(
|
||||
hidden_states=hidden_states,
|
||||
output=output,
|
||||
layer_cache=layer_cache,
|
||||
metadata=self.forward_metadata,
|
||||
mup_vector=mup_vector,
|
||||
use_triton_causal_conv=use_triton_causal_conv,
|
||||
)
|
||||
|
||||
def forward_decode(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
|
||||
)
|
||||
|
||||
def forward_extend(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
"Mamba2AttnBackend's forward is called directly instead of through HybridLinearAttnBackend, as it supports mixed prefill and decode"
|
||||
)
|
||||
|
||||
|
||||
class HybridLinearAttnBackend(AttentionBackend):
|
||||
"""Support different backends for prefill and decode."""
|
||||
"""Manages a full and linear attention backend"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
full_attn_backend: AttentionBackend,
|
||||
linear_attn_backend: AttentionBackend,
|
||||
linear_attn_backend: MambaAttnBackendBase,
|
||||
full_attn_layers: list[int],
|
||||
):
|
||||
self.full_attn_layers = full_attn_layers
|
||||
self.full_attn_backend = full_attn_backend
|
||||
self.linear_attn_backend = linear_attn_backend
|
||||
self.attn_backend_list = [full_attn_backend, linear_attn_backend]
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||
@@ -489,7 +595,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
return self.attn_backend_list[0].get_cuda_graph_seq_len_fill_value()
|
||||
return self.full_attn_backend.get_cuda_graph_seq_len_fill_value()
|
||||
|
||||
def forward_decode(
|
||||
self,
|
||||
@@ -503,10 +609,10 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
):
|
||||
layer_id = layer.layer_id if layer else kwargs["layer_id"]
|
||||
if layer_id in self.full_attn_layers:
|
||||
return self.attn_backend_list[0].forward_decode(
|
||||
return self.full_attn_backend.forward_decode(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
return self.attn_backend_list[1].forward_decode(
|
||||
return self.linear_attn_backend.forward_decode(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
|
||||
@@ -522,10 +628,10 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
):
|
||||
layer_id = layer.layer_id if layer else kwargs["layer_id"]
|
||||
if layer_id in self.full_attn_layers:
|
||||
return self.attn_backend_list[0].forward_extend(
|
||||
return self.full_attn_backend.forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
return self.attn_backend_list[1].forward_extend(
|
||||
return self.linear_attn_backend.forward_extend(
|
||||
q, k, v, layer, forward_batch, save_kv_cache, **kwargs
|
||||
)
|
||||
|
||||
@@ -568,20 +674,20 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
def update_mamba_state_after_mtp_verify(self, accepted_length, model):
|
||||
request_number = accepted_length.shape[0]
|
||||
|
||||
state_indices_tensor = self.attn_backend_list[
|
||||
1
|
||||
].forward_metadata.mamba_cache_indices[:request_number]
|
||||
state_indices_tensor = (
|
||||
self.linear_attn_backend.forward_metadata.mamba_cache_indices[
|
||||
:request_number
|
||||
]
|
||||
)
|
||||
|
||||
mamba_caches = self.attn_backend_list[
|
||||
1
|
||||
].req_to_token_pool.get_mamba_params_all_layers()
|
||||
mamba_caches = (
|
||||
self.linear_attn_backend.req_to_token_pool.get_speculative_mamba2_params_all_layers()
|
||||
)
|
||||
|
||||
(
|
||||
conv_states,
|
||||
ssm_states,
|
||||
intermediate_state_cache,
|
||||
intermediate_conv_window_cache,
|
||||
) = mamba_caches
|
||||
conv_states = mamba_caches.conv
|
||||
ssm_states = mamba_caches.temporal
|
||||
intermediate_state_cache = mamba_caches.intermediate_ssm
|
||||
intermediate_conv_window_cache = mamba_caches.intermediate_conv_window
|
||||
|
||||
# SSM state updates (chunked to reduce peak memory)
|
||||
valid_mask = accepted_length > 0
|
||||
|
||||
@@ -10,7 +10,7 @@ import torch
|
||||
from sgl_kernel import causal_conv1d_fwd
|
||||
from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
from .causal_conv1d_triton import PAD_SLOT_ID
|
||||
|
||||
|
||||
def causal_conv1d_fn(
|
||||
|
||||
@@ -6,11 +6,11 @@ from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
@triton.jit()
|
||||
def _causal_conv1d_fwd_kernel( # continuous batching
|
||||
@@ -672,7 +672,9 @@ def _causal_conv1d_update_kernel(
|
||||
+ (conv_state_batch_coord * stride_conv_state_seq)
|
||||
+ conv_state_token_offset * stride_conv_state_tok
|
||||
+ (idx_feats * stride_conv_state_dim)[None, :]
|
||||
+ ((idx_tokens + 1) * stride_conv_state_tok)[:, None]
|
||||
+ ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[
|
||||
:, None
|
||||
]
|
||||
) # [BLOCK_M, BLOCK_N]
|
||||
mask = (
|
||||
(conv_state_batch_coord < num_cache_lines)
|
||||
@@ -897,7 +899,10 @@ def causal_conv1d_update(
|
||||
stride_state_indices = (
|
||||
conv_state_indices.stride(0) if conv_state_indices is not None else 0
|
||||
)
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
if num_accepted_tokens is not None:
|
||||
state_len = width - 1 + (seqlen - 1) # effective state_len needed
|
||||
else:
|
||||
state_len = width - 1
|
||||
np2_statelen = triton.next_power_of_2(state_len)
|
||||
|
||||
def grid(META):
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.configs.mamba_utils import (
|
||||
Mamba2CacheParams,
|
||||
extra_groups_for_head_shards,
|
||||
)
|
||||
from sglang.srt.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.utils import divide
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d import (
|
||||
causal_conv1d_fn,
|
||||
causal_conv1d_update,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
||||
causal_conv1d_fn as causal_conv1d_fn_triton,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
|
||||
causal_conv1d_update as causal_conv1d_update_triton,
|
||||
)
|
||||
from sglang.srt.layers.attention.mamba.mamba2_metadata import Mamba2Metadata
|
||||
from sglang.srt.layers.attention.mamba.mixer2_rms_norm_gated import Mixer2RMSNormGated
|
||||
from sglang.srt.layers.attention.mamba.ops import (
|
||||
mamba_chunk_scan_combined,
|
||||
selective_state_update,
|
||||
@@ -28,7 +35,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.mem_cache.memory_pool import MambaPool
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
composed_weight_loader,
|
||||
sharded_weight_loader,
|
||||
@@ -97,110 +104,6 @@ def mamba_v2_sharded_weight_loader(
|
||||
return loader
|
||||
|
||||
|
||||
class Mixer2RMSNormGated(CustomOp):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
full_hidden_size: int,
|
||||
full_n_groups: int,
|
||||
use_rms_norm: bool = True,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.full_hidden_size = full_hidden_size
|
||||
self.group_size = full_hidden_size // full_n_groups
|
||||
self.per_rank_hidden_size = full_hidden_size // self.tp_size
|
||||
self.n_groups = full_hidden_size // self.group_size
|
||||
|
||||
self.variance_epsilon = eps
|
||||
self.use_rms_norm = use_rms_norm
|
||||
if self.use_rms_norm:
|
||||
# Register norm weight only if we're actually applying RMSNorm
|
||||
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
|
||||
set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
|
||||
else:
|
||||
# Avoid checkpoint mismatch by skipping unused parameter
|
||||
self.register_parameter("weight", None)
|
||||
assert (
|
||||
self.full_hidden_size % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide hidden size."
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
):
|
||||
# Three tensor-parallel cases:
|
||||
# 1. n_groups is 1
|
||||
# In this case we parallelize along the reduction dim.
|
||||
# Each rank computes a local sum of squares followed by AllReduce
|
||||
# 2. tp_size divides n_groups
|
||||
# Each rank only reduces within its local group(s).
|
||||
# No collective ops necessary.
|
||||
# 3. The general case can be pretty complicated so we AllGather
|
||||
# the input and then redundantly compute the RMSNorm.
|
||||
input_dtype = x.dtype
|
||||
x = x * nn.functional.silu(gate.to(torch.float32))
|
||||
if not self.use_rms_norm:
|
||||
return x.to(input_dtype)
|
||||
|
||||
if self.n_groups == 1:
|
||||
if self.tp_size > 1:
|
||||
# Compute local sum and then reduce to obtain global sum
|
||||
local_sums = x.pow(2).sum(dim=-1, keepdim=True)
|
||||
global_sums = tensor_model_parallel_all_reduce(local_sums)
|
||||
# Calculate the variance
|
||||
count = self.tp_size * x.shape[-1]
|
||||
variance = global_sums / count
|
||||
|
||||
else:
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.variance_epsilon)
|
||||
else:
|
||||
redundant_tp: bool = self.n_groups % self.tp_size != 0
|
||||
if redundant_tp:
|
||||
# To handle the general case, redundantly apply the variance
|
||||
x = tensor_model_parallel_all_gather(x, -1)
|
||||
|
||||
*prefix_dims, hidden_dim = x.shape
|
||||
group_count = hidden_dim // self.group_size
|
||||
x_grouped = x.view(*prefix_dims, group_count, self.group_size)
|
||||
variance = x_grouped.pow(2).mean(-1, keepdim=True)
|
||||
x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
|
||||
x = x_grouped.view(*prefix_dims, hidden_dim)
|
||||
|
||||
if redundant_tp:
|
||||
start = self.per_rank_hidden_size * self.tp_rank
|
||||
end = start + self.per_rank_hidden_size
|
||||
x = x[..., start:end]
|
||||
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
gate: torch.Tensor,
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
|
||||
input_dtype = x.dtype
|
||||
if not self.use_rms_norm:
|
||||
# Keep gate in float32 for numerical stability during silu
|
||||
return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
|
||||
|
||||
if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
|
||||
return self.forward_native(x, gate)
|
||||
|
||||
return layernorm_fn(
|
||||
x,
|
||||
self.weight.data,
|
||||
bias=None,
|
||||
z=gate,
|
||||
eps=self.variance_epsilon,
|
||||
norm_before_gate=False,
|
||||
)
|
||||
|
||||
|
||||
class MambaMixer2(torch.nn.Module):
|
||||
"""
|
||||
Compute ∆, A, B, C, and D the state space parameters and compute
|
||||
@@ -214,22 +117,14 @@ class MambaMixer2(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cache_params: Mamba2CacheParams,
|
||||
hidden_size: int,
|
||||
ssm_state_size: int,
|
||||
conv_kernel_size: int,
|
||||
intermediate_size: int,
|
||||
use_conv_bias: bool,
|
||||
use_bias: bool,
|
||||
chunk_size: int,
|
||||
layer_id: int,
|
||||
n_groups: int = 1,
|
||||
num_heads: int = 128,
|
||||
head_dim: int = 64,
|
||||
rms_norm_eps: float = 1e-5,
|
||||
activation: str = "silu",
|
||||
use_rms_norm: bool = True,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
# cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
@@ -252,6 +147,9 @@ class MambaMixer2(torch.nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
self.num_heads = num_heads = cache_params.shape.num_heads
|
||||
self.head_dim = cache_params.shape.head_dim
|
||||
|
||||
assert (
|
||||
num_heads % self.tp_size == 0
|
||||
), "Tensor parallel world size must divide num heads."
|
||||
@@ -261,57 +159,76 @@ class MambaMixer2(torch.nn.Module):
|
||||
"then num_groups must equal 1."
|
||||
)
|
||||
|
||||
self.ssm_state_size = ssm_state_size
|
||||
self.conv_kernel_size = conv_kernel_size
|
||||
assert (
|
||||
(n_groups % self.tp_size == 0) or self.tp_size == 1 or quant_config is None
|
||||
), (
|
||||
"Tensor parallel currently supported for quantized models only "
|
||||
"if tensor parallel world size divides num groups."
|
||||
)
|
||||
|
||||
self.ssm_state_size = cache_params.shape.ssm_state_size
|
||||
self.activation = activation
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.intermediate_size = intermediate_size
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
conv_kernel_size = cache_params.shape.conv_kernel
|
||||
self.intermediate_size = intermediate_size = (
|
||||
cache_params.shape.intermediate_size
|
||||
)
|
||||
self.n_groups = n_groups
|
||||
if n_groups % self.tp_size != 0:
|
||||
# - for TP we shard conv_dim by sharding on n_groups,
|
||||
# - but if n_groups cannot divide tp_size, we need to
|
||||
# extend some extra groups
|
||||
groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
|
||||
n_groups, self.tp_size
|
||||
)
|
||||
groups = extra_groups_for_head_shards(n_groups, self.tp_size)
|
||||
self.n_groups = n_groups + groups
|
||||
|
||||
self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
|
||||
self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
|
||||
self.conv_dim = cache_params.shape.conv_dim
|
||||
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
if n_groups % self.tp_size == 0:
|
||||
self.conv1d = MergedColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
],
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
if n_groups % self.tp_size != 0:
|
||||
self.in_proj = MergedColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_sizes=[
|
||||
intermediate_size,
|
||||
intermediate_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.groups_ssm_state_size,
|
||||
self.num_heads,
|
||||
],
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
else:
|
||||
# This is the n_groups == 1 case,
|
||||
# where we need to duplicate groups if TP>1.
|
||||
|
||||
self.conv1d = ColumnParallelLinear(
|
||||
input_size=conv_kernel_size,
|
||||
output_size=self.conv_dim,
|
||||
bias=use_conv_bias,
|
||||
quant_config=None,
|
||||
prefix=f"{prefix}.conv1d",
|
||||
)
|
||||
|
||||
self.in_proj = ColumnParallelLinear(
|
||||
input_size=hidden_size,
|
||||
output_size=intermediate_size + self.conv_dim + self.num_heads,
|
||||
bias=use_bias,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.in_proj",
|
||||
)
|
||||
|
||||
# - because in_proj is a concatenation of 3 weights, we
|
||||
# need to interleave them before sharding
|
||||
# - use the custom weight loader mamba_v2_sharded_weight_loader
|
||||
@@ -421,47 +338,27 @@ class MambaMixer2(torch.nn.Module):
|
||||
intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
|
||||
)
|
||||
|
||||
# The tuple is (conv_state, ssm_state)
|
||||
self.kv_cache = (torch.tensor([]), torch.tensor([]))
|
||||
|
||||
self.model_config = model_config
|
||||
self.prefix = prefix
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
):
|
||||
pass
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
layer_cache: MambaPool.State,
|
||||
metadata: Mamba2Metadata,
|
||||
mup_vector: Optional[torch.Tensor] = None,
|
||||
use_triton_causal_conv: bool = False,
|
||||
):
|
||||
# attn_backend_list[-1] gives access to MambaAttnBackend
|
||||
mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
|
||||
attn_metadata = mamba_backend.forward_metadata
|
||||
state_indices_tensor = attn_metadata.mamba_cache_indices
|
||||
chunk_size = self.chunk_size
|
||||
# metadata contains metadata necessary for the mamba2 triton
|
||||
# kernels to operate in continuous batching and in chunked prefill
|
||||
# modes; they are computed at top-level model forward since they
|
||||
# stay the same and reused for all mamba layers in the same iteration
|
||||
state_indices_tensor = metadata.mamba_cache_indices
|
||||
conv_state = layer_cache.conv
|
||||
ssm_state = layer_cache.temporal
|
||||
|
||||
conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
|
||||
self.layer_id
|
||||
)
|
||||
|
||||
assert (
|
||||
ssm_state.size(1) == self.ssm_state_size
|
||||
), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
|
||||
|
||||
query_start_loc = attn_metadata.query_start_loc
|
||||
|
||||
chunk_size = self.chunk_size
|
||||
|
||||
# TODO: properly support this
|
||||
prep_initial_states = False
|
||||
query_start_loc = metadata.query_start_loc
|
||||
|
||||
# 1. Gated MLP's linear projection
|
||||
projected_states, _ = self.in_proj(hidden_states)
|
||||
@@ -493,6 +390,38 @@ class MambaMixer2(torch.nn.Module):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
num_prefills = metadata.num_prefills # request count
|
||||
num_decodes = metadata.num_decodes # token count (=request)
|
||||
num_prefill_tokens = metadata.num_prefill_tokens # token count
|
||||
has_prefill = num_prefills > 0
|
||||
has_decode = num_decodes > 0
|
||||
num_actual_tokens = num_prefill_tokens + num_decodes
|
||||
assert num_actual_tokens == projected_states.shape[0]
|
||||
|
||||
# NOTE: V0 put prefill before decode
|
||||
# Separate prefill and decode by splitting varlen input
|
||||
# Split along token dimension
|
||||
hidden_states_B_C_p, hidden_states_B_C_d = torch.split(
|
||||
hidden_states_B_C,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
dt_p, dt_d = torch.split(
|
||||
dt,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
# Split along batch dimension
|
||||
state_indices_tensor_p, state_indices_tensor_d = torch.split(
|
||||
state_indices_tensor,
|
||||
[num_prefills, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
query_start_loc_p = query_start_loc[: num_prefills + 1] if has_prefill else None
|
||||
|
||||
# Preallocate output tensor to avoid memcpy cost for merging prefill
|
||||
# and decode outputs
|
||||
|
||||
preallocated_ssm_out = torch.empty(
|
||||
[
|
||||
projected_states.shape[0],
|
||||
@@ -501,128 +430,147 @@ class MambaMixer2(torch.nn.Module):
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
preallocated_ssm_out_p, preallocated_ssm_out_d = torch.split(
|
||||
preallocated_ssm_out,
|
||||
[num_prefill_tokens, num_decodes],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Process prefill requests
|
||||
if forward_batch.forward_mode.is_extend():
|
||||
if has_prefill:
|
||||
mixed_metadata = metadata.mixed_metadata
|
||||
assert mixed_metadata is not None
|
||||
# 2. Convolution sequence transformation
|
||||
# - "cache_indices" updates the conv_state cache in positions
|
||||
# pointed to by "state_indices_tensor"
|
||||
num_prefill_tokens = forward_batch.extend_num_tokens or 0
|
||||
has_initial_states = forward_batch.extend_prefix_lens > 0
|
||||
cache_indices = attn_metadata.mamba_cache_indices
|
||||
|
||||
x = hidden_states_B_C.transpose(
|
||||
has_initial_states_p = mixed_metadata.has_initial_states
|
||||
prep_initial_states = mixed_metadata.prep_initial_states
|
||||
cache_indices = state_indices_tensor_p
|
||||
x = hidden_states_B_C_p.transpose(
|
||||
0, 1
|
||||
) # this is the form that causal-conv see
|
||||
hidden_states_B_C = causal_conv1d_fn(
|
||||
ccfn = (
|
||||
causal_conv1d_fn
|
||||
if not use_triton_causal_conv
|
||||
else causal_conv1d_fn_triton
|
||||
)
|
||||
hidden_states_B_C_p = ccfn(
|
||||
x,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
activation=self.activation,
|
||||
conv_states=conv_state,
|
||||
has_initial_state=has_initial_states,
|
||||
has_initial_state=has_initial_states_p,
|
||||
cache_indices=cache_indices,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
|
||||
).transpose(0, 1)
|
||||
query_start_loc=query_start_loc_p,
|
||||
seq_lens_cpu=mixed_metadata.extend_seq_lens_cpu,
|
||||
).transpose(0, 1)[:num_prefill_tokens]
|
||||
|
||||
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
hidden_states_p, B_p, C_p = split_hidden_states_B_C_fn(hidden_states_B_C_p)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
initial_states = None
|
||||
|
||||
if has_initial_states is not None and prep_initial_states:
|
||||
if has_initial_states_p is not None and prep_initial_states:
|
||||
initial_states = torch.where(
|
||||
has_initial_states[:, None, None, None],
|
||||
ssm_state[state_indices_tensor],
|
||||
has_initial_states_p[:, None, None, None],
|
||||
ssm_state[state_indices_tensor_p],
|
||||
0,
|
||||
)
|
||||
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
varlen_state = mamba_chunk_scan_combined(
|
||||
hidden_states.view(
|
||||
hidden_states_p.view(
|
||||
1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
|
||||
),
|
||||
dt.unsqueeze(0),
|
||||
dt_p.unsqueeze(0),
|
||||
self.A,
|
||||
B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=chunk_size,
|
||||
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
|
||||
chunk_size=mixed_metadata.chunk_size,
|
||||
D=self.D,
|
||||
z=None,
|
||||
dt_bias=self.dt_bias,
|
||||
cu_seqlens=query_start_loc,
|
||||
seq_idx=mixed_metadata.seq_idx,
|
||||
chunk_indices=mixed_metadata.chunk_indices,
|
||||
chunk_offsets=mixed_metadata.chunk_offsets,
|
||||
cu_seqlens=query_start_loc_p,
|
||||
initial_states=initial_states,
|
||||
return_varlen_states=True,
|
||||
return_final_states=False,
|
||||
dt_softplus=True,
|
||||
dt_limit=(0.0, float("inf")),
|
||||
out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
|
||||
out=preallocated_ssm_out_p.view(
|
||||
1, num_prefill_tokens, -1, self.head_dim
|
||||
),
|
||||
state_dtype=ssm_state.dtype,
|
||||
)
|
||||
|
||||
# update ssm states
|
||||
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
|
||||
ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
|
||||
elif forward_batch.forward_mode.is_decode():
|
||||
num_decodes = len(query_start_loc) - 1
|
||||
ssm_state[state_indices_tensor_p] = varlen_state
|
||||
|
||||
# Process decode requests
|
||||
if has_decode:
|
||||
# 2. Convolution sequence transformation
|
||||
hidden_states_B_C = causal_conv1d_update(
|
||||
hidden_states_B_C,
|
||||
ccu = (
|
||||
causal_conv1d_update
|
||||
if not use_triton_causal_conv
|
||||
else causal_conv1d_update_triton
|
||||
)
|
||||
hidden_states_B_C_d = ccu(
|
||||
hidden_states_B_C_d,
|
||||
conv_state,
|
||||
conv_weights,
|
||||
self.conv1d.bias,
|
||||
self.activation,
|
||||
conv_state_indices=state_indices_tensor,
|
||||
conv_state_indices=state_indices_tensor_d,
|
||||
)
|
||||
|
||||
hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
|
||||
hidden_states_d, B_d, C_d = split_hidden_states_B_C_fn(hidden_states_B_C_d)
|
||||
|
||||
# 3. State Space Model sequence transformation
|
||||
n_groups = self.n_groups // self.tp_size
|
||||
A = (
|
||||
A_d = (
|
||||
self.A[:, None, ...][:, :, None]
|
||||
.expand(-1, self.head_dim, self.ssm_state_size)
|
||||
.to(dtype=torch.float32)
|
||||
)
|
||||
dt = dt[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_d = dt_d[:, :, None].expand(-1, -1, self.head_dim)
|
||||
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
|
||||
D = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B = B.view(-1, n_groups, B.shape[1] // n_groups)
|
||||
C = C.view(-1, n_groups, C.shape[1] // n_groups)
|
||||
hidden_states = hidden_states.view(
|
||||
D_d = self.D[:, None, ...].expand(-1, self.head_dim)
|
||||
B_d = B_d.view(-1, n_groups, B_d.shape[1] // n_groups)
|
||||
C_d = C_d.view(-1, n_groups, C_d.shape[1] // n_groups)
|
||||
hidden_states_d = hidden_states_d.view(
|
||||
-1, self.num_heads // self.tp_size, self.head_dim
|
||||
)
|
||||
|
||||
# - the hidden is reshaped into (bs, num_heads, head_dim)
|
||||
# - mamba_cache_params.ssm_state's slots will be selected
|
||||
# - layer_state.ssm_state's slots will be selected
|
||||
# using state_indices_tensor_d
|
||||
# NOTE: final output is an in-place update of out tensor
|
||||
selective_state_update(
|
||||
ssm_state.permute(0, 3, 2, 1),
|
||||
hidden_states,
|
||||
dt,
|
||||
A,
|
||||
B,
|
||||
C,
|
||||
D,
|
||||
ssm_state,
|
||||
hidden_states_d,
|
||||
dt_d,
|
||||
A_d,
|
||||
B_d,
|
||||
C_d,
|
||||
D_d,
|
||||
z=None,
|
||||
dt_bias=dt_bias,
|
||||
dt_softplus=True,
|
||||
state_batch_indices=state_indices_tensor,
|
||||
out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
|
||||
state_batch_indices=state_indices_tensor_d,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
)
|
||||
elif forward_batch.forward_mode.is_idle():
|
||||
preallocated_ssm_out = preallocated_ssm_out
|
||||
|
||||
# 4. gated MLP
|
||||
# GatedRMSNorm internally applying SiLU to the gate
|
||||
# SiLU is applied internally before normalization, unlike standard
|
||||
# norm usage
|
||||
hidden_states = self.norm(preallocated_ssm_out, gate)
|
||||
hidden_states = self.norm(preallocated_ssm_out, gate[:num_actual_tokens])
|
||||
|
||||
# 5. Final linear projection
|
||||
output[:], _ = self.out_proj(hidden_states)
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
|
||||
|
||||
@property
|
||||
def mamba_type(self) -> str:
|
||||
|
||||
211
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
Normal file
211
python/sglang/srt/layers/attention/mamba/mamba2_metadata.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright 2025 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
# Adapted from https://github.com/vllm-project/vllm/blob/2c58742dff8613a3bd7496f2008ce927e18d38d1/vllm/model_executor/layers/mamba/mamba2_metadata.py
|
||||
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class ForwardMetadata:
    """Keyword-only container for the metadata shared by a forward pass."""

    # NOTE(review): presumably the cumulative per-request token offsets with
    # a leading 0, i.e. shape (num_seqs + 1,) -- confirm against the caller.
    query_start_loc: torch.Tensor
    # NOTE(review): presumably one mamba state-cache slot index per request
    # in the batch -- confirm against the caller.
    mamba_cache_indices: torch.Tensor
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
class Mamba2Metadata(ForwardMetadata):
    """stable metadata across all mamba2 layers in the forward pass"""

    num_prefills: int
    num_prefill_tokens: int
    num_decodes: int

    @dataclass(kw_only=True, frozen=True)
    class MixedMetadata:
        """Extra, immutable metadata only needed when the batch has prefills."""

        has_initial_states: torch.Tensor
        prep_initial_states: bool

        chunk_size: int
        seq_idx: torch.Tensor
        # Both stay None unless at least one prefill request has an initial
        # state (see `prepare_mixed`).
        chunk_indices: torch.Tensor | None
        chunk_offsets: torch.Tensor | None

        extend_seq_lens_cpu: list[int]

    # `mixed_metadata` is used for extend/mixed requests
    mixed_metadata: MixedMetadata | None = None

    @staticmethod
    def _query_start_loc_to_chunk_indices_offsets(
        query_start_loc: torch.Tensor, chunk_size: int, total_seqlens: int
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            query_start_loc (torch.Tensor): 1D tensor of cumulative sequence
                lengths, shape (num_seqs + 1,).
                The first element should be 0. Each entry represents the starting
                index of a sequence in the flattened token array.
            chunk_size (int): The size of each physical mamba chunk
                (number of tokens per chunk).
            total_seqlens (int): The total number of tokens in the batch.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - chunk_indices (torch.Tensor): 1D tensor of indices
                    indicating the physical chunk for each logical chunk.
                - chunk_offsets (torch.Tensor): 1D tensor of offsets
                    indicating the starting index of each logical chunk within
                    its physical chunk.

        This function computes the chunk indices and offsets for the given
        query_start_loc and chunk_size. Both are tensors of integers with length N,
        where N is the number of logical (pseudo) chunks.
        A logical chunk is a sequence of tokens that are all part of the same
        sequence and are all in the same physical mamba chunk.
        In other words, a logical chunk changes every time we cross a sequence
        boundary or a physical mamba chunk boundary.
        Logical chunks are needed to handle batched requests with initial states
        (see _state_passing_fwd and _chunk_scan_fwd).
        The chunk_indices tensor contains the index of the physical chunk for each
        logical chunk.
        The chunk_offsets tensor contains the offset (AKA starting index) of the
        logical chunk in the physical chunk.

        Example:
        query_start_loc = [0, 5, 10]
        chunk_size = 8
        total_seqlens = 10
        -> chunk_indices = [0, 0, 1]
        -> chunk_offsets = [0, 5, 0]

        In this example, we have 2 sequences, each with 5 tokens. The physical
        chunk size is 8 tokens.
        We have three logical chunks:
        - the first logical chunk starts at token 0 in the first physical chunk
            and contains all 5 tokens from the first sequence
        - the second logical chunk starts at token 5 in the first physical chunk
            and contains first 3 tokens from the second sequence
        - the third logical chunk starts at token 0 in the second physical chunk
            and contains the remaining 2 tokens from the second sequence
        """
        device = query_start_loc.device

        # Copy the boundaries to the host once. Looping over tensor elements
        # and using them as slice bounds would otherwise convert a scalar
        # tensor to a Python int for every sequence (a host sync each time if
        # the tensor lives on an accelerator).
        cu_seqlens = query_start_loc[1:].tolist()  # remove prepended 0

        # outputs will have length expansion of chunks that do not divide
        # chunk_size
        N = math.ceil(total_seqlens / chunk_size) + sum(
            s % chunk_size > 0 for s in cu_seqlens[:-1]
        )
        chunk_indices = torch.arange(N, dtype=torch.int, device=device)
        chunk_offsets = torch.zeros((N,), dtype=torch.int, device=device)

        p = 0  # num of insertions
        for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]):
            # if does not divide chunk_size, then there is one chunk insertion
            p += s % chunk_size > 0

            # get the dimensions
            # - the + 1 for _e is to shift the boundary by one chunk
            # - this shifting is not needed if chunk_size divides e
            _s = s // chunk_size + p
            _e = e // chunk_size + p + (e % chunk_size > 0)

            # adjust indices and offsets
            chunk_indices[_s:_e] -= p
            chunk_offsets[_s] = s % chunk_size

        return chunk_indices, chunk_offsets

    @staticmethod
    def prepare_decode(
        query_start_loc: torch.Tensor,
        mamba_cache_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> "Mamba2Metadata":
        """This path is run during CUDA graph capture, i.e. decode only, so `num_prefills` is 0"""
        return Mamba2Metadata(
            query_start_loc=query_start_loc,
            mamba_cache_indices=mamba_cache_indices,
            num_decodes=len(seq_lens),
            num_prefills=0,
            num_prefill_tokens=0,
        )

    @classmethod
    def prepare_mixed(
        cls,
        query_start_loc: torch.Tensor,
        mamba_cache_indices: torch.Tensor,
        chunk_size: int,
        forward_batch: ForwardBatch,
    ) -> "Mamba2Metadata":
        """This path cannot run with CUDA graph, as it contains extend requests.

        Falls back to `prepare_decode` when the batch carries no extend
        tokens (`forward_batch.extend_num_tokens is None`).
        """
        if forward_batch.extend_num_tokens is None:
            return cls.prepare_decode(
                query_start_loc, mamba_cache_indices, forward_batch.seq_lens
            )
        num_prefills = len(forward_batch.extend_seq_lens)
        num_prefill_tokens = forward_batch.extend_num_tokens
        num_decodes = len(forward_batch.seq_lens) - num_prefills
        context_lens_tensor = forward_batch.extend_prefix_lens
        assert context_lens_tensor is not None
        # precompute flag to avoid device syncs later
        has_initial_states = context_lens_tensor > 0
        prep_initial_states = torch.any(has_initial_states[:num_prefills]).item()

        # Only the prefill part of the batch is described from here on.
        query_start_loc = query_start_loc[: num_prefills + 1]
        seq_idx = torch.repeat_interleave(
            torch.arange(
                num_prefills, dtype=torch.int32, device=query_start_loc.device
            ),
            query_start_loc.diff(),
            output_size=num_prefill_tokens,
        )
        seq_idx.unsqueeze_(0)

        # We compute metadata for chunked prefill once at the top level model
        # forward and reuse them in mamba layers. If not needed, they will be
        # ignored inside mamba kernels.
        chunk_offsets, chunk_indices = None, None
        if prep_initial_states:
            chunk_indices, chunk_offsets = (
                cls._query_start_loc_to_chunk_indices_offsets(
                    query_start_loc, chunk_size, num_prefill_tokens
                )
            )

        return Mamba2Metadata(
            query_start_loc=query_start_loc,
            mamba_cache_indices=mamba_cache_indices,
            num_prefills=num_prefills,
            num_prefill_tokens=num_prefill_tokens,
            num_decodes=num_decodes,
            mixed_metadata=cls.MixedMetadata(
                has_initial_states=has_initial_states,
                prep_initial_states=prep_initial_states,
                chunk_size=chunk_size,
                seq_idx=seq_idx,
                chunk_indices=chunk_indices,
                chunk_offsets=chunk_offsets,
                extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
            ),
        )
|
||||
@@ -1,81 +0,0 @@
|
||||
# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
|
||||
from sglang.srt.distributed.utils import divide
|
||||
|
||||
|
||||
class MambaStateShapeCalculator:
    """Per-rank cache-state shape helpers for linear-attention / mamba layers."""

    @classmethod
    def linear_attention_state_shape(
        cls,
        num_heads: int,
        tp_size: int,
        head_dim: int,
    ) -> tuple[tuple[int, int, int], ...]:
        """Return a 1-tuple holding the (heads_per_rank, head_dim, head_dim) shape."""
        heads_per_rank = num_heads // tp_size
        return ((heads_per_rank, head_dim, head_dim),)

    @classmethod
    def mamba1_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int]]:
        """Return (conv_state_shape, temporal_state_shape) for a mamba1 layer."""
        channels_per_rank = divide(intermediate_size, tp_world_size)
        # The conv state is laid out as (kernel - 1, channels).
        conv_shape = (conv_kernel - 1, channels_per_rank)
        temporal_shape = (channels_per_rank, state_size)
        return conv_shape, temporal_shape

    @classmethod
    def mamba2_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        n_groups: int,
        num_heads: int,
        head_dim: int,
        state_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Return (conv_state_shape, temporal_state_shape) for a mamba2 layer."""
        # if n_groups is not divisible by world_size, need to extend the shards
        # to ensure all groups needed by a head is sharded along with it
        padded_groups = n_groups + cls.extra_groups_for_head_shards(
            n_groups, tp_world_size
        )
        # heads and n_groups are TP-ed
        conv_dim = intermediate_size + 2 * padded_groups * state_size

        # contiguous along 'dim' axis
        conv_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))

        # These are not TP-ed as they depend on A, dt_bias, D
        # - they are typically small
        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
        temporal_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
        return conv_shape, temporal_shape

    @classmethod
    def short_conv_state_shape(
        cls,
        tp_world_size: int,
        intermediate_size: int,
        conv_kernel: int,
    ) -> tuple[tuple[int, int]]:
        """Return a 1-tuple holding the (kernel - 1, channels_per_rank) conv shape."""
        return ((conv_kernel - 1, divide(intermediate_size, tp_world_size)),)

    @classmethod
    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
        """Compute the increase in group numbers to account for
        replication in order to accompany the head shards."""
        # in the case ngroups % tp_size == 0, this will be zero;
        # for n_groups == 1, this is exactly tp_size - n_groups
        return 0 if ngroups % tp_size == 0 else tp_size - ngroups
|
||||
@@ -0,0 +1,120 @@
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.custom_op import CustomOp
|
||||
from sglang.srt.distributed.communication_op import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.attention.fla.layernorm_gated import rms_norm_gated
|
||||
from sglang.srt.model_loader.weight_utils import sharded_weight_loader
|
||||
from sglang.srt.utils.common import set_weight_attrs
|
||||
|
||||
|
||||
class Mixer2RMSNormGated(CustomOp):
    """SiLU-gated (grouped) RMSNorm that is aware of tensor parallelism.

    The gate is passed through SiLU and multiplied into the input first; the
    RMS normalization (if enabled) is applied to that product.
    """

    def __init__(
        self,
        full_hidden_size: int,
        full_n_groups: int,
        use_rms_norm: bool = True,
        eps: float = 1e-6,
    ):
        """
        Args:
            full_hidden_size: Hidden size before tensor-parallel sharding.
            full_n_groups: Number of normalization groups before sharding.
            use_rms_norm: When False only the SiLU gating is applied and no
                weight parameter is created.
            eps: Epsilon added to the variance before the reciprocal sqrt.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.full_hidden_size = full_hidden_size
        self.group_size = full_hidden_size // full_n_groups
        self.per_rank_hidden_size = full_hidden_size // self.tp_size
        self.n_groups = full_hidden_size // self.group_size

        self.variance_epsilon = eps
        self.use_rms_norm = use_rms_norm
        if not self.use_rms_norm:
            # Avoid checkpoint mismatch by skipping unused parameter
            self.register_parameter("weight", None)
        else:
            # Register norm weight only if we're actually applying RMSNorm
            self.weight = torch.nn.Parameter(torch.ones(self.per_rank_hidden_size))
            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
        assert (
            self.full_hidden_size % self.tp_size == 0
        ), "Tensor parallel world size must divide hidden size."

    def forward_native(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ):
        """Reference implementation covering every tensor-parallel layout.

        Three tensor-parallel cases:
          1. n_groups is 1: the reduction dim is parallelized, so each rank
             computes a local sum of squares followed by an AllReduce.
          2. tp_size divides n_groups: each rank only reduces within its
             local group(s); no collective ops are necessary.
          3. General case: AllGather the input, redundantly compute the
             RMSNorm, then keep only this rank's slice.
        """
        out_dtype = x.dtype
        eps = self.variance_epsilon
        gated = x * torch.nn.functional.silu(gate.to(torch.float32))
        if not self.use_rms_norm:
            return gated.to(out_dtype)

        if self.n_groups == 1:
            if self.tp_size > 1:
                # Compute local sum and then reduce to obtain global sum
                global_sq_sum = tensor_model_parallel_all_reduce(
                    gated.pow(2).sum(dim=-1, keepdim=True)
                )
                # Calculate the variance
                variance = global_sq_sum / (self.tp_size * gated.shape[-1])
            else:
                variance = gated.pow(2).mean(-1, keepdim=True)
            normed = gated * torch.rsqrt(variance + eps)
            return self.weight * normed.to(out_dtype)

        needs_gather: bool = self.n_groups % self.tp_size != 0
        if needs_gather:
            # To handle the general case, redundantly apply the variance
            gated = tensor_model_parallel_all_gather(gated, -1)

        full_shape = gated.shape
        grouped = gated.view(
            *full_shape[:-1], full_shape[-1] // self.group_size, self.group_size
        )
        variance = grouped.pow(2).mean(-1, keepdim=True)
        normed = (grouped * torch.rsqrt(variance + eps)).view(full_shape)

        if needs_gather:
            # Keep only the columns owned by this rank.
            lo = self.per_rank_hidden_size * self.tp_rank
            normed = normed[..., lo : lo + self.per_rank_hidden_size]

        return self.weight * normed.to(out_dtype)

    def forward_cuda(
        self,
        x: torch.Tensor,
        gate: torch.Tensor,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """CUDA path: use the fused `rms_norm_gated` kernel when possible.

        The fused kernel is only taken when there is exactly one group and
        tp_size divides it; otherwise `forward_native` handles the call.
        """
        if not self.use_rms_norm:
            # Keep gate in float32 for numerical stability during silu
            silu_gate = torch.nn.functional.silu(gate.to(torch.float32))
            return x * silu_gate.to(x.dtype)

        if (self.n_groups % self.tp_size) == 0 and self.n_groups == 1:
            return rms_norm_gated(
                x=x,
                weight=self.weight.data,
                bias=None,
                z=gate,
                eps=self.variance_epsilon,
                norm_before_gate=False,
                is_rms_norm=True,
            )

        return self.forward_native(x, gate)
|
||||
@@ -15,56 +15,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["chunk_size", "K", "IS_CAUSAL"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _bmm_chunk_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -16,66 +16,6 @@ from packaging import version
|
||||
TRITON_22 = version.parse(triton.__version__) >= version.parse("2.2.0")
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_scan_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -17,17 +17,6 @@ import triton.language as tl
|
||||
from .mamba_ssm import softplus
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_SIZE_H": 2}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 4}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 8}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 16}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 32}),
|
||||
# triton.Config({"BLOCK_SIZE_H": 64}),
|
||||
# ],
|
||||
# key=["chunk_size", "nheads"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_cumsum_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -120,56 +109,6 @@ def _chunk_cumsum_fwd_kernel(
|
||||
)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["hdim", "dstate", "chunk_size"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_state_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
@@ -320,56 +259,6 @@ def _chunk_state_fwd_kernel(
|
||||
tl.store(states_ptrs, states, mask=c_mask)
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
|
||||
# num_stages=3,
|
||||
# num_warps=8,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=4,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=5,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# triton.Config(
|
||||
# {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
|
||||
# num_stages=4,
|
||||
# num_warps=2,
|
||||
# ),
|
||||
# ],
|
||||
# key=["hdim", "dstate", "chunk_size"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _chunk_state_varlen_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -13,17 +13,6 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
# @triton.autotune(
|
||||
# configs=[
|
||||
# triton.Config({"BLOCK_SIZE": 64}),
|
||||
# triton.Config({"BLOCK_SIZE": 128}),
|
||||
# triton.Config({"BLOCK_SIZE": 256}),
|
||||
# triton.Config({"BLOCK_SIZE": 512}),
|
||||
# triton.Config({"BLOCK_SIZE": 1024}),
|
||||
# triton.Config({"BLOCK_SIZE": 2048}),
|
||||
# ],
|
||||
# key=["dim"],
|
||||
# )
|
||||
@triton.jit
|
||||
def _state_passing_fwd_kernel(
|
||||
# Pointers to matrices
|
||||
|
||||
@@ -85,7 +85,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
|
||||
get_attention_tp_size()
|
||||
)
|
||||
if model_runner.is_hybrid_gdn:
|
||||
if model_runner.hybrid_gdn_config is not None:
|
||||
# For hybrid linear models, layer_id = 0 may not be full attention
|
||||
self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user