diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
index 19e9bd516..86dd77f37 100644
--- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -362,6 +362,7 @@ class MambaAttnBackend(AttentionBackend):
             has_initial_state=has_initial_states,
             cache_indices=cache_indices,
             query_start_loc=query_start_loc,
+            seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
         ).transpose(0, 1)[:seq_len]

         key_split_dim = key_dim // attn_tp_size
diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d.py
index d004337ff..d9f63641d 100644
--- a/python/sglang/srt/layers/attention/mamba/causal_conv1d.py
+++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d.py
@@ -23,6 +23,7 @@ def causal_conv1d_fn(
     conv_states: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
+    **kwargs,
 ):
     """
     x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen
diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
index 3c1bdec48..8c9d8bd7b 100644
--- a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
+++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py
@@ -2,7 +2,7 @@
 # Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py
 # and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py

-from typing import Optional, Union
+from typing import List, Optional, Union

 import numpy as np
 import torch
@@ -22,11 +22,8 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     cache_indices_ptr,  # conv_state_indices_ptr
     has_initial_states_ptr,
     query_start_loc_ptr,
-    batch_ptr,
-    token_chunk_offset_ptr,
     o_ptr,  # (dim, seqlen) - actually pointing to x_ptr
     # Matrix dimensions
-    batch: tl.int32,  # actually padded_batch
     dim: tl.constexpr,
     seqlen: tl.int32,  # cu_seqlen
     num_cache_lines: tl.constexpr,  # added to support vLLM larger cache lines
@@ -69,11 +66,11 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     # rather than mixing sequences - to make updating initial_states across sequences efficiently

     # single-sequence id
-    idx_seq = tl.load(batch_ptr + tl.program_id(0))
-    chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0))
+    idx_seq = tl.program_id(0)
+    chunk_offset = tl.program_id(1)

     # BLOCK_N elements along the feature-dimension (channel)
-    idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)
+    idx_feats = tl.program_id(2) * BLOCK_N + tl.arange(0, BLOCK_N)

     if idx_seq == pad_slot_id:
         return
@@ -86,6 +83,9 @@ def _causal_conv1d_fwd_kernel(  # continuous batching
     token_offset = BLOCK_M * chunk_offset
     segment_len = min(BLOCK_M, seqlen - token_offset)

+    if segment_len <= 0:
+        return
+
     # base of the sequence
     x_base = (
         x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim
@@ -382,12 +382,13 @@ def causal_conv1d_fn(
     bias: Union[torch.Tensor, None],
     conv_states: torch.Tensor,
     query_start_loc: torch.Tensor,
+    seq_lens_cpu: List[int],
     cache_indices: Optional[torch.Tensor] = None,
     has_initial_state: Optional[torch.Tensor] = None,
     activation: Optional[str] = "silu",
     pad_slot_id: int = PAD_SLOT_ID,
-    metadata=None,
     validate_data=False,
+    **kwargs,
 ):
     """support varlen + continuous batching when x is 2D tensor
@@ -413,6 +414,8 @@ def causal_conv1d_fn(
         [length(query_start_loc)-1 == batch]
         for example: query_start_loc = torch.Tensor([0,10,16,17]),
         x.shape=(dim,17)
+    seq_lens_cpu: (batch) int32
+        The sequence lengths of the sequences in the batch
     cache_indices: (batch) int32
         indicates the corresponding state index,
         like so: conv_state = conv_states[cache_indices[batch_id]]
@@ -434,26 +437,7 @@ def causal_conv1d_fn(
     if isinstance(activation, bool) and activation:
         activation = "silu"

-    args = None
     out = torch.empty_like(x)
-    if metadata is not None:
-        cu_seqlen = metadata.cu_seqlen
-        nums_dict = metadata.nums_dict
-        # x = metadata.x
-        args = nums_dict
-        batch_ptr = metadata.batch_ptr
-        token_chunk_offset_ptr = metadata.token_chunk_offset_ptr
-    else:
-        seqlens = np.diff(query_start_loc.to("cpu"))
-        args = seqlens
-        MAX_NUM_PROGRAMS = 1024
-
-        batch_ptr = torch.full(
-            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
-        )  # tracking which seq-idx the Triton program is handling
-        token_chunk_offset_ptr = torch.full(
-            (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device
-        )  # tracking BLOCK_M-based index in the sequence the Triton program is handling

     is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1)
     dim, cu_seqlen = x.shape
@@ -461,7 +445,6 @@ def causal_conv1d_fn(
     state_len = width - 1
     np2_statelen = triton.next_power_of_2(state_len)

-    padded_batch = query_start_loc.size(0) - 1
     stride_x_seq = 0
     stride_x_dim = x.stride(0)
     stride_x_token = x.stride(1)
@@ -501,6 +484,7 @@ def causal_conv1d_fn(
         assert query_start_loc is not None
         assert query_start_loc.dim() == 1
         assert x.stride(0) == 1 or x.stride(1) == 1
+        padded_batch = query_start_loc.size(0) - 1
         if bias is not None:
             assert bias.dim() == 1
             assert dim == bias.size(0)
@@ -516,78 +500,14 @@ def causal_conv1d_fn(
         assert (dim, width) == weight.shape
         assert is_channel_last, "Need to run in channel-last layout"

-    if metadata is None:
-
-        def num_program(META, seqlens):
-            tot = 0
-
-            mlist = []
-            offsetlist = []  # type: ignore
-
-            nums = -(-seqlens // META["BLOCK_M"])
-
-            tot = nums.sum().item()
-            mlist = np.repeat(np.arange(len(nums)), nums)
-            for idx, num in enumerate(nums):
-                offsetlist.extend(
-                    range(num)
-                )  # chunk-idx if a sequence is split into multiple chunks
-
-            if META["batch_ptr"].nelement() < len(mlist):
-                newlen = len(mlist) + 1
-                META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-
-            if META["batch_ptr"].nelement() >= len(mlist):
-                META["batch_ptr"][0 : len(mlist)].copy_(
-                    torch.from_numpy(np.array(mlist))
-                )
-                META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(
-                    torch.from_numpy(np.array(offsetlist))
-                )
-
-            META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device)
-            META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(
-                META["x_ptr"].device
-            )
-            return tot
-
-    else:
-
-        def num_program(META, nums_dict):
-            tot = nums_dict[META["BLOCK_M"]]["tot"]
-
-            mlist = nums_dict[META["BLOCK_M"]]["mlist"]
-            mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"]
-
-            offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"]
-
-            if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None:
-                META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"]
-                META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][
-                    "token_chunk_offset_ptr"
-                ]
-            else:
-                if META["batch_ptr"].nelement() < mlist_len:
-                    newlen = mlist_len + 1
-                    META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-                    META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID)
-
-                if META["batch_ptr"].nelement() >= mlist_len:
-                    META["batch_ptr"][0:mlist_len].copy_(mlist)
-                    META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist)
-            return tot
-
     def grid(META):
+        max_seq_len = max(seq_lens_cpu)
         return (
-            num_program(META, args),
+            len(seq_lens_cpu),  # batch_size
+            (max_seq_len + META["BLOCK_M"] - 1) // META["BLOCK_M"],
             triton.cdiv(dim, META["BLOCK_N"]),
         )

-    if batch_ptr.device != x.device:
-        batch_ptr = batch_ptr.to(x.device)
-        token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device)
-
     _causal_conv1d_fwd_kernel[grid](
         # Pointers to matrices
         x,
@@ -597,11 +517,8 @@ def causal_conv1d_fn(
         cache_indices,
         has_initial_state,
         query_start_loc,
-        batch_ptr,
-        token_chunk_offset_ptr,
         out,
         # Matrix dimensions
-        padded_batch,
         dim,
         cu_seqlen,
         num_cache_lines,