diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 9730df726..a676573f2 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -13,7 +13,7 @@ from sglang.srt.layers.attention.fla.fused_recurrent import ( from sglang.srt.layers.attention.fla.fused_sigmoid_gating_recurrent import ( fused_sigmoid_gating_delta_rule_update, ) -from sglang.srt.layers.attention.mamba.causal_conv1d import ( +from sglang.srt.layers.attention.mamba.causal_conv1d_triton import ( causal_conv1d_fn, causal_conv1d_update, ) @@ -195,7 +195,9 @@ class MambaAttnBackend(AttentionBackend): dt_bias = kwargs["dt_bias"] layer_id = kwargs["layer_id"] - conv_states, ssm_states = self.req_to_token_pool.get_mamba_params(layer_id) + conv_states, ssm_states, *rest = self.req_to_token_pool.get_mamba_params( + layer_id + ) query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices @@ -277,12 +279,9 @@ class MambaAttnBackend(AttentionBackend): ( conv_states, ssm_states, - mixed_qkv_cache, intermediate_state_cache, + intermediate_conv_window_cache, ) = self.req_to_token_pool.get_mamba_params(layer_id) - mixed_qkv_cache[cache_indices] = mixed_qkv.view( - (-1,) + mixed_qkv_cache.shape[1:] - ).clone() has_initial_states = torch.ones( seq_len // forward_batch.spec_info.draft_token_num, dtype=torch.bool, @@ -295,16 +294,38 @@ class MambaAttnBackend(AttentionBackend): ) has_initial_states = forward_batch.extend_prefix_lens > 0 conv_states_to_use = conv_states - mixed_qkv = causal_conv1d_fn( - mixed_qkv.transpose(0, 1), - conv_weights, - bias, - activation=activation, - conv_states=conv_states_to_use, - has_initial_state=has_initial_states, - cache_indices=cache_indices, - query_start_loc=query_start_loc, - ).transpose(0, 1)[:seq_len] + + if is_target_verify: + batch_size = seq_len // forward_batch.spec_info.draft_token_num + draft_token_num = forward_batch.spec_info.draft_token_num + mixed_qkv_reshaped = ( + mixed_qkv.view(batch_size, draft_token_num, -1) + .transpose(1, 2) + .contiguous() + ) + mixed_qkv_processed = causal_conv1d_update( + mixed_qkv_reshaped, + conv_states_to_use, + conv_weights, + bias, + activation, + conv_state_indices=cache_indices[:batch_size], + intermediate_conv_window=intermediate_conv_window_cache, + ) + mixed_qkv = ( + mixed_qkv_processed.transpose(1, 2).contiguous().view(seq_len, -1) + ) + else: + mixed_qkv = causal_conv1d_fn( + mixed_qkv.transpose(0, 1), + conv_weights, + bias, + activation=activation, + conv_states=conv_states_to_use, + has_initial_state=has_initial_states, + cache_indices=cache_indices, + query_start_loc=query_start_loc, + ).transpose(0, 1)[:seq_len] key_split_dim = key_dim // attn_tp_size value_split_dim = value_dim // attn_tp_size @@ -507,26 +528,6 @@ class HybridLinearAttnBackend(AttentionBackend): def update_mamba_state_after_mtp_verify(self, accepted_length, model): request_number = accepted_length.shape[0] - # QQ: step = spec num_draft token num - num_draft_tokens = ( - self.attn_backend_list[1] - .req_to_token_pool.mamba_pool.mamba_cache[2] - .shape[2] - ) - query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype) - query_start_loc = torch.cat( - [ - torch.zeros( - 1, - dtype=query_start_loc.dtype, - device=query_start_loc.device, - ), - query_start_loc, - ] - ) - mask = torch.arange(num_draft_tokens, 
device=accepted_length.device).unsqueeze( - 0 - ) < accepted_length.unsqueeze(1) state_indices_tensor = self.attn_backend_list[ 1 @@ -536,46 +537,48 @@ class HybridLinearAttnBackend(AttentionBackend): 1 ].req_to_token_pool.get_mamba_params_all_layers() - conv_states, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches + ( + conv_states, + ssm_states, + intermediate_state_cache, + intermediate_conv_window_cache, + ) = mamba_caches - mixed_qkvs = mix_qkv_cache[:, state_indices_tensor][:, mask] - - mamba_map = self.attn_backend_list[1].req_to_token_pool.mamba_map - - has_initial_states = torch.ones( - request_number, dtype=torch.bool, device=accepted_length.device - ) - - # Batch SSM state updates (outside the loop for efficiency) + # SSM state updates (chunked to reduce peak memory) valid_mask = accepted_length > 0 - if intermediate_state_cache is not None: - last_steps = (accepted_length - 1).to(torch.int64) - valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) - ssm_states[:, valid_state_indices, :] = intermediate_state_cache[ - :, valid_state_indices, last_steps - ].to(ssm_states.dtype) + # Compute common indices once to avoid duplication + last_steps_all = (accepted_length - 1).to(torch.int64) + valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) + last_steps = last_steps_all[valid_mask].to(torch.int64) - # For loop conv state updates (can be optimized) - for i in range(len(model.model.layers)): - layer = model.model.layers[i] - if isinstance(layer, Qwen3HybridLinearDecoderLayer): - conv_weights = layer.linear_attn.conv1d.weight.view( - layer.linear_attn.conv1d.weight.size(0), - layer.linear_attn.conv1d.weight.size(2), - ) + if valid_state_indices.numel() > 0: + chunk = 256 + num_valid = valid_state_indices.numel() - layer_id = mamba_map[i] - conv_state = conv_states[layer_id] - mixed_qkv = mixed_qkvs[layer_id] + # SSM state updates + for i in range(0, num_valid, chunk): + idx = valid_state_indices[i : i + chunk] + steps = last_steps[i : i + chunk] + # per (cache line, step) + for j in range(idx.numel()): + ci = idx[j].item() + st = steps[j].item() + ssm_states[:, ci, :].copy_( + intermediate_state_cache[:, ci, st].to( + ssm_states.dtype, copy=False + ) + ) - _ = causal_conv1d_fn( - mixed_qkv.transpose(0, 1), - conv_weights, - layer.linear_attn.conv1d.bias, - activation=layer.linear_attn.activation, - conv_states=conv_state, - has_initial_state=has_initial_states, - cache_indices=state_indices_tensor, - query_start_loc=query_start_loc, - ) + # Conv window updates + for i in range(0, num_valid, chunk): + idx = valid_state_indices[i : i + chunk] + steps = last_steps[i : i + chunk] + for j in range(idx.numel()): + ci = idx[j].item() + st = steps[j].item() + conv_states[:, ci, :, :].copy_( + intermediate_conv_window_cache[:, ci, st].to( + conv_states.dtype, copy=False + ) + ) diff --git a/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py new file mode 100644 index 000000000..3c1bdec48 --- /dev/null +++ b/python/sglang/srt/layers/attention/mamba/causal_conv1d_triton.py @@ -0,0 +1,1052 @@ +# Copyright (c) 2024, Tri Dao. 
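+#
+# Triton implementations of the varlen causal conv1d prefill
+# (`causal_conv1d_fn`) and the single-/multi-token decode update
+# (`causal_conv1d_update`). The update kernel can additionally snapshot
+# the (K-1)-wide conv window after each draft token via
+# `intermediate_conv_window`, which the hybrid backend uses to roll back
+# conv state after MTP verify.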
+# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +# and https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/ops/causal_conv1d.py + +from typing import Optional, Union + +import numpy as np +import torch + +PAD_SLOT_ID = -1 +import triton +import triton.language as tl + + +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + batch_ptr, + token_chunk_offset_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + batch: tl.int32, # actually padded_batch + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = ( + KERNEL_WIDTH - 1 + ) # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.load(batch_ptr + tl.program_id(0)) + chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + # base of the sequence + x_base = ( + x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim + ) # [BLOCK_N,] + + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = ( + conv_states_ptr + + 
(conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 5: # STRATEGY1 + col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if ( + state_len <= seqlen + ): # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange( + 0, NP2_STATELEN + ) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] + & (idx_tokens_last < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = ( + conv_states_base[None, :] + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, 
mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + + tl.debug_barrier() # need this due to the bug in tl.where not enforcing this when data is the result of another tl.load + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = ( + x_base[None, :] + + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = ( + conv_states_base + + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[ + None, : + ] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = 
prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (sequence_start_index + token_offset + idx_token) * stride_o_token + + (idx_feats * stride_o_dim) + ) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch + sequences are concatenated from left to right for varlen + weight: (dim, width) + conv_states: (...,dim,width - 1) itype + updated inplace if provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that 
conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + out: same shape as `x` + """ + if isinstance(activation, bool) and activation: + activation = "silu" + + args = None + out = torch.empty_like(x) + if metadata is not None: + cu_seqlen = metadata.cu_seqlen + nums_dict = metadata.nums_dict + # x = metadata.x + args = nums_dict + batch_ptr = metadata.batch_ptr + token_chunk_offset_ptr = metadata.token_chunk_offset_ptr + else: + seqlens = np.diff(query_start_loc.to("cpu")) + args = seqlens + MAX_NUM_PROGRAMS = 1024 + + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device + ) # tracking which seq-idx the Triton program is handling + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device + ) # tracking BLOCK_M-based index in the sequence the Triton program is handling + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + padded_batch = query_start_loc.size(0) - 1 + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. 
computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ) + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + # assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch,) + assert ( + conv_states is not None + ), "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + if metadata is None: + + def num_program(META, seqlens): + tot = 0 + + mlist = [] + offsetlist = [] # type: ignore + + nums = -(-seqlens // META["BLOCK_M"]) + + tot = nums.sum().item() + mlist = np.repeat(np.arange(len(nums)), nums) + for idx, num in enumerate(nums): + offsetlist.extend( + range(num) + ) # chunk-idx if a sequence is split into multiple chunks + + if META["batch_ptr"].nelement() < len(mlist): + newlen = len(mlist) + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= len(mlist): + META["batch_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(mlist)) + ) + META["token_chunk_offset_ptr"][0 : len(mlist)].copy_( + torch.from_numpy(np.array(offsetlist)) + ) + + META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) + META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to( + META["x_ptr"].device + ) + return tot + + else: + + def num_program(META, nums_dict): + tot = nums_dict[META["BLOCK_M"]]["tot"] + + mlist = nums_dict[META["BLOCK_M"]]["mlist"] + mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] + + offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] + + if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: + META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]][ + "token_chunk_offset_ptr" + ] + else: + if META["batch_ptr"].nelement() < mlist_len: + newlen = mlist_len + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= mlist_len: + META["batch_ptr"][0:mlist_len].copy_(mlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) + return tot + + def grid(META): + return ( + num_program(META, args), + triton.cdiv(dim, META["BLOCK_N"]), + ) + + if batch_ptr.device != x.device: + batch_ptr = batch_ptr.to(x.device) + token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + batch_ptr, + token_chunk_offset_ptr, + out, + 
# Matrix dimensions + padded_batch, + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + # launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + intermediate_conv_window_ptr, + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_inter_seq: tl.constexpr, + stride_inter_step: tl.constexpr, + stride_inter_dim: tl.constexpr, + stride_inter_win: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, + SAVE_INTERMEDIATE: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load( + conv_state_indices_ptr + idx_seq * stride_state_indices + ).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. 
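+        #
+        # Slot 0 of the stored buffer above holds history2, so the window
+        # of the accepted state begins at slot num_accepted - 1: e.g.
+        # accepting 2 tokens leaves [history3, ..., historyM, draft1,
+        # draft2], which begins at slot 1. Hence the -1 below.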
+ conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # The conv_state updates works in a sliding window manner, + # at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + 1) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + (idx_seq * stride_x_seq) + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = ( + x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] + ) # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] + & (idx_tokens - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = ( + conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to( + tl.float32 + ) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, 
other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.static_range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & ( + idx_feats < dim + ) # token-index # feature-index + o_ptrs = ( + o_ptr + + (idx_seq) * stride_o_seq + + idx_token * stride_o_token + + (idx_feats * stride_o_dim) + ) + + tl.store(o_ptrs, acc, mask=mask_1d) + + if SAVE_INTERMEDIATE: + # Save the window state after consuming this token + # Layout: [seq(cache line), step, dim, win(K-1)] + base_ptr = ( + intermediate_conv_window_ptr + + conv_state_batch_coord * stride_inter_seq + + idx_token * stride_inter_step + + idx_feats * stride_inter_dim + ) + if KERNEL_WIDTH >= 2: + tl.store(base_ptr + 0 * stride_inter_win, col0, mask=mask_w) + if KERNEL_WIDTH >= 3: + tl.store(base_ptr + 1 * stride_inter_win, col1, mask=mask_w) + if KERNEL_WIDTH >= 4: + tl.store(base_ptr + 2 * stride_inter_win, col2, mask=mask_w) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + intermediate_conv_window: Optional[torch.Tensor] = None, + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=False, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + [shape=2: single token prediction] + [shape=3: single or multiple tokens prediction] + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. 
+ pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch,) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + stride_x_seq, stride_x_dim, stride_x_token = x.stride() # X (batch, dim, seqlen) + + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = ( + conv_state_indices.stride(0) if conv_state_indices is not None else 0 + ) + state_len = width - 1 + (seqlen - 1) # effective state_len needed + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + # prepare intermediate buffer strides if provided + if intermediate_conv_window is not None: + stride_inter_seq, stride_inter_step, stride_inter_dim, stride_inter_win = ( + intermediate_conv_window.stride(0), + intermediate_conv_window.stride(1), + intermediate_conv_window.stride(2), + intermediate_conv_window.stride(3), + ) + else: + stride_inter_seq = stride_inter_step = stride_inter_dim = stride_inter_win = 0 + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + intermediate_conv_window if intermediate_conv_window is not None else x, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_inter_seq, + stride_inter_step, + stride_inter_dim, + stride_inter_win, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + 
USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + SAVE_INTERMEDIATE=intermediate_conv_window is not None, + ) + if unsqueeze: + out = out.squeeze(-1) + return out diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 6cc66ba1a..80c549033 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -125,16 +125,6 @@ class MambaPool: device=device, ) if speculative_num_draft_tokens is not None: - mixed_qkv_cache = torch.empty( - size=( - num_mamba_layers, - size + 1, - speculative_num_draft_tokens, - conv_state_shape[0], - ), - dtype=conv_dtype, - device="cuda", - ) # Cache intermediate SSM states per draft token during target verify # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V] intermediate_ssm_state_cache = torch.empty( @@ -149,11 +139,24 @@ class MambaPool: dtype=ssm_dtype, device="cuda", ) + # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify + # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1] + intermediate_conv_window_cache = torch.empty( + size=( + num_mamba_layers, + size + 1, + speculative_num_draft_tokens, + conv_state_shape[0], + conv_state_shape[1], + ), + dtype=conv_dtype, + device="cuda", + ) self.mamba_cache = ( conv_state, temporal_state, - mixed_qkv_cache, intermediate_ssm_state_cache, + intermediate_conv_window_cache, ) else: self.mamba_cache = (conv_state, temporal_state) diff --git a/python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py deleted file mode 100644 index bf8d462aa..000000000 --- a/python/sglang/srt/speculative/eagle_target_verify_cuda_graph_runner.py +++ /dev/null @@ -1,195 +0,0 @@ -import bisect -from typing import TYPE_CHECKING, Callable - -import torch -import torch.nn.functional as F - -from sglang.srt.layers.attention.fla.fused_recurrent import ( - fused_recurrent_gated_delta_rule_update, -) -from sglang.srt.layers.attention.mamba.causal_conv1d import causal_conv1d_fn -from sglang.srt.model_executor.cuda_graph_runner import ( - CUDA_GRAPH_CAPTURE_FAILED_MSG, - CudaGraphRunner, - get_batch_sizes_to_capture, - get_global_graph_memory_pool, - model_capture_mode, - set_global_graph_memory_pool, -) -from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer - -if TYPE_CHECKING: - from sglang.srt.speculative.eagle_worker import EAGLEWorker - - -class MambaStateUpdateCudaGraphRunner: - def __init__(self, eagle_worker: "EAGLEWorker"): - self.eagle_worker = eagle_worker - model_runner = eagle_worker.target_worker.model_runner - self.model_runner = model_runner - self.attn_backend = model_runner.attn_backend.attn_backend_list[1] - self.req_to_token_pool = self.attn_backend.req_to_token_pool - - self.graphs = {} - self.output_buffers = {} - self.graph_input_buffer = None - self.stream = torch.cuda.Stream() - self.model = model_runner.model - - self.enable_profile_cuda_graph = ( - model_runner.server_args.enable_profile_cuda_graph - ) - self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) - self.max_bs = self.capture_bs[-1] - - self.init_cuda_graph_state() - # Capture - try: - with model_capture_mode(): - self.capture() - except RuntimeError as e: - raise Exception( - f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" - ) - - def init_cuda_graph_state(self): - self.mamba_cache = self.req_to_token_pool.mamba_pool.mamba_cache - 
self.num_tokens_per_bs = self.max_accepted_tokens = self.mamba_cache[2].shape[2] - num_mamba_layers = self.mamba_cache[0].shape[0] - conv_dtype = torch.bfloat16 - conv_shape = self.mamba_cache[0].shape[2] - total_token_number = self.max_accepted_tokens * self.max_bs - self.mixed_qkv_cache = torch.empty( - size=( - num_mamba_layers, - total_token_number, - conv_shape, - ), - dtype=conv_dtype, - device="cuda", - ) - self.query_start_loc = torch.zeros( - (self.max_bs + 1,), dtype=torch.int32, device="cuda" - ) - self.state_indices = torch.zeros( - (self.max_bs + 1,), dtype=torch.int32, device="cuda" - ) - self.has_initial_states = torch.ones( - self.max_bs, dtype=torch.bool, device="cuda" - ) - - def capture(self): - CudaGraphRunner.capture(self) - - def capture_one_batch_size(self, bs: int, forward: Callable): - """ - Capture CUDA Graph for a typical workload - """ - graph = torch.cuda.CUDAGraph() - stream = self.stream - total_token_number = bs * self.max_accepted_tokens - mixed_qkvs = self.mixed_qkv_cache[:, :total_token_number] - - query_start_loc = self.query_start_loc[: bs + 1] - state_indices = self.state_indices[:bs] - has_initial_states = self.has_initial_states[:bs] - - mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers() - conv_states = mamba_caches[0] - mamba_map = self.req_to_token_pool.mamba_map - - def run_once(): - for i in range(len(self.model.model.layers)): - layer = self.model.model.layers[i] - if not isinstance(layer, Qwen3HybridLinearDecoderLayer): - continue - conv_weights = layer.linear_attn.conv1d.weight.view( - layer.linear_attn.conv1d.weight.size(0), - layer.linear_attn.conv1d.weight.size(2), - ) - layer_id = mamba_map[i] - - causal_conv1d_fn( - mixed_qkvs[layer_id].transpose(0, 1), - conv_weights, - layer.linear_attn.conv1d.bias, - activation=layer.linear_attn.activation, - conv_states=conv_states[layer_id], - has_initial_state=has_initial_states, - cache_indices=state_indices, - query_start_loc=query_start_loc, - ) - - return None - - for _ in range(2): - torch.cuda.synchronize() - self.model_runner.tp_group.barrier() - - run_once() - - with torch.cuda.graph( - graph, pool=get_global_graph_memory_pool(), stream=stream - ): - out = run_once() - - set_global_graph_memory_pool(graph.pool()) - return graph, out - - def can_run(self, accepted_length): - bs = accepted_length.shape[0] - return bs <= self.max_bs - - def replay_repare(self, accepted_length): - request_number = accepted_length.shape[0] - # QQ: step = spec num_draft token num - num_draft_tokens = self.req_to_token_pool.mamba_pool.mamba_cache[2].shape[2] - query_start_loc = accepted_length.cumsum(-1, dtype=accepted_length.dtype) - query_start_loc = torch.cat( - [ - torch.zeros( - 1, - dtype=query_start_loc.dtype, - device=query_start_loc.device, - ), - query_start_loc, - ] - ) - mask = torch.arange(num_draft_tokens, device=accepted_length.device).unsqueeze( - 0 - ) < accepted_length.unsqueeze(1) - - state_indices_tensor = self.attn_backend.forward_metadata.mamba_cache_indices[ - :request_number - ] - mamba_caches = self.req_to_token_pool.get_mamba_params_all_layers() - - _, ssm_states, mix_qkv_cache, intermediate_state_cache = mamba_caches - mixed_qkvs = mamba_caches[2][:, state_indices_tensor][:, mask] - self.mixed_qkv_cache[:, : mixed_qkvs.shape[1]].copy_(mixed_qkvs) - self.query_start_loc[: request_number + 1] = query_start_loc - self.query_start_loc[request_number + 1 :] = self.query_start_loc[ - request_number - ] - self.state_indices[:request_number] = state_indices_tensor - 
self.state_indices[request_number:] = -1 - valid_mask = accepted_length > 0 - if intermediate_state_cache is not None: - last_steps = (accepted_length - 1).to(torch.int64) - valid_state_indices = state_indices_tensor[valid_mask].to(torch.int64) - - ssm_states[:, valid_state_indices, :] = intermediate_state_cache[ - :, valid_state_indices, last_steps - ].to(ssm_states.dtype) - - def replay(self, accepted_length): - # batch_size and num_seqs can be different in case there are finished examples - # in the batch, which will not be counted as num_seqs - raw_bs = accepted_length.shape[0] - index = bisect.bisect_left(self.capture_bs, raw_bs) - - bs = self.capture_bs[index] - - self.replay_repare(accepted_length) - # Replay - self.graphs[bs].replay() diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 3ec32a0a2..f454971ca 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -407,15 +407,6 @@ class EAGLEWorker(TpModelWorker): f"Capture draft extend cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB." ) - if self.target_worker.model_runner.is_hybrid_gdn: - from sglang.srt.speculative.eagle_target_verify_cuda_graph_runner import ( - MambaStateUpdateCudaGraphRunner, - ) - - self.cuda_graph_runner_for_target_verify = MambaStateUpdateCudaGraphRunner( - self - ) - @property def draft_model_runner(self): return self.model_runner @@ -848,12 +839,9 @@ class EAGLEWorker(TpModelWorker): ) + 1 ) - if self.cuda_graph_runner_for_target_verify.can_run(accepted_length): - self.cuda_graph_runner_for_target_verify.replay(accepted_length) - else: - self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( - accepted_length, self.target_worker.model_runner.model - ) + self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify( + accepted_length, self.target_worker.model_runner.model + ) if batch.return_logprob: self.add_logprob_values(batch, res, logits_output)
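Reviewer note: the two new pieces above fit together as follows. During target verify, `causal_conv1d_update` processes all draft tokens in one launch and snapshots the (K-1)-wide conv window after each token into `intermediate_conv_window`; after verification, `update_mamba_state_after_mtp_verify` copies the snapshot at step `accepted_length - 1` back into the persistent conv state. Below is a minimal single-layer sketch with synthetic sizes; the shipped code keeps a leading num-layers dimension, sizes the cache lines from the model config, and restores states with chunked per-line copies, so the names (`window`, `cache_idx`) and shapes here are illustrative only.

import torch

from sglang.srt.layers.attention.mamba.causal_conv1d_triton import (
    causal_conv1d_update,
)

# Synthetic sizes: 2 sequences, width-4 kernel, 3 draft tokens per verify.
batch, dim, K, steps, lines = 2, 64, 4, 3, 8
dev, dt = "cuda", torch.bfloat16

x = torch.randn(batch, dim, steps, device=dev, dtype=dt)  # (batch, dim, seqlen)
weight = torch.randn(dim, K, device=dev, dtype=dt)        # (dim, width)
# The update kernel rolls a window of K - 1 + (seqlen - 1) slots, so give
# each cache line at least that many token slots.
conv_state = torch.zeros(lines, dim, K - 1 + steps - 1, device=dev, dtype=dt)
# One (K-1)-wide window snapshot per cache line and per draft step.
window = torch.empty(lines, steps, dim, K - 1, device=dev, dtype=dt)
cache_idx = torch.tensor([0, 1], device=dev, dtype=torch.int32)

# Target verify: convolve all draft tokens, saving the window after each.
out = causal_conv1d_update(
    x,
    conv_state,
    weight,
    bias=None,
    activation="silu",
    conv_state_indices=cache_idx,
    intermediate_conv_window=window,  # enables the SAVE_INTERMEDIATE path
)

# After verify: keep the window as of the last accepted draft token.
accepted_length = torch.tensor([2, 3], device=dev)
valid = accepted_length > 0
idx = cache_idx[valid].to(torch.int64)
last = (accepted_length[valid] - 1).to(torch.int64)
# Restore the leading K-1 slots that the next forward reads; the pool code
# instead copies whole cache lines, whose widths match there.
conv_state[idx, :, : K - 1] = window[idx, last].to(conv_state.dtype)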